# mmverify.py -- Proof verifier for the Metamath language
# Copyright (C) 2002 Raph Levien raph (at) acm (dot) org
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License

# The original program was tested with python 2.2.1.  It does not run on
# 1.5.2.  The later "wch" additions use the built-in set type, which needs
# Python 2.4 or later.

# Before using this program, any compressed proofs must be expanded with the
# Metamath program, e.g.:
#   $ ./metamath 'r set.mm' 'sa p *' 'w s set.mm' q > /dev/null
# To run the program, type
#   $ python mmverify.py < set.mm > set.log
# and set.log will have the verification results.

# (nm 27-Jun-05) mmverify.py requires that a $f hypothesis must not occur
# after a $e hypothesis in the same scope, even though this is allowed by
# the Metamath spec.  This is not a serious limitation since it can be
# met by rearranging the hypothesis order.

#
# wch A: Changed Frame.e from a list of statements to a list of (statement, label) pairs
# wch B: write out proof as a string of labels
# wch C: Changed Frame.f from a list of (var, kind) to a list of (var, kind, label)
#

import GUI
import string
import cStringIO # wch added

# Trace level for diagnostic printing; 0 is silent.  (The upstream default
# was 1; wch lowered it to 0.)  Thresholds used in this file:
#   >= 1   print 'verifying <label>' for every $p statement
#   >= 10  print each proof step as it is looked up
#   >= 12  print the referenced assertion and the proof stack after each step
#   >= 15  print $f statements as read and the substitution being built
#   >= 16  print details of each disjoint-variable check
#   >= 18  print the tuple returned by FrameStack.make_assertion
#   >= 20  print every apply_subst call
verbosity = 0

class toks:
    """Whitespace tokenizer over the text of a Metamath database.

    The text is consumed one line at a time.  The comment delimiters
    ``$(`` and ``$)`` are padded with spaces before splitting, so they
    always come out as separate tokens even when written directly next to
    other characters.
    """

    def __init__(self, lines):
        """Prepare to tokenize ``lines``, the database text as one string."""
        # cStringIO exists only on Python 2, where it accepts plain byte
        # strings; fall back to io.StringIO when it is not available.
        try:
            from cStringIO import StringIO
        except ImportError:
            from io import StringIO
        self.linesBuffer = StringIO(lines)
        # Tokens of the current line, stored reversed so that pop() hands
        # them out in source order.
        self.tokbuf = []

    def read(self):
        """Return the next token, or None once the input is exhausted."""
        while not self.tokbuf:
            line = self.linesBuffer.readline()
            if line == "":
                # readline() returns "" only at end of input; a blank
                # line still carries its newline character.
                return None
            line = line.replace('$(', ' $( ').replace('$)', ' $) ')
            self.tokbuf = line.split()
            self.tokbuf.reverse()
        return self.tokbuf.pop()

class Frame:
    """Declarations made directly inside one scope of the database.

    Attributes (each starts as its own empty list and is filled by
    FrameStack):
      c -- constant symbols
      v -- variable symbols
      d -- disjoint-variable pairs, stored as (x, y) tuples with x < y
      f -- floating hypotheses as (var, kind, label) tuples
      e -- essential hypotheses as (statement, label) tuples
    """
    def __init__(self):
        # A fresh list per attribute, so no two attributes (and no two
        # frames) ever share storage.
        for attr in ('c', 'v', 'd', 'f', 'e'):
            setattr(self, attr, [])

class FrameStack:
    """Stack of Frame objects, one per currently open scope.

    The last element of ``self.stack`` is the innermost scope.  The add_*
    methods record a declaration in the innermost frame; the lookup_*
    methods search every frame, innermost first.

    Declaration errors are raised as ValueError.  (The string exceptions
    used previously are rejected with a TypeError by Python 2.6 and later,
    which hid the actual message.)
    """

    def __init__(self):
        self.stack = []

    def push(self):
        """Open a new, empty innermost scope."""
        self.stack.append(Frame())

    def pop(self):
        """Discard the innermost scope."""
        self.stack.pop()

    def add_c(self, tok):
        """Declare constant ``tok`` in the innermost scope.

        Raises ValueError if ``tok`` is already a constant or a variable
        of that scope.
        """
        frame = self.stack[-1]
        if tok in frame.c:
            raise ValueError('const already defined in scope')
        if tok in frame.v:
            raise ValueError('const already defined as var in scope')
        frame.c.append(tok)

    def add_v(self, tok):
        """Declare variable ``tok`` in the innermost scope.

        Raises ValueError if ``tok`` is already a variable or a constant
        of that scope.
        """
        frame = self.stack[-1]
        if tok in frame.v:
            raise ValueError('var already defined in scope')
        if tok in frame.c:
            raise ValueError('var already defined as const in scope')
        frame.v.append(tok)

    def add_f(self, var, kind, label):  # wch C: added label
        """Record the floating hypothesis ``label: kind var``.

        Raises ValueError if ``var`` is not a visible variable, ``kind``
        is not a visible constant, or ``var`` already has a $f in the
        innermost scope.
        """
        if not self.lookup_v(var):
            raise ValueError('var in $f not defined: ' + var)
        if not self.lookup_c(kind):
            raise ValueError('const in $f not defined: ' + kind)
        frame = self.stack[-1]
        for (v, k, l) in frame.f:
            if v == var:
                raise ValueError('var in $f already defined in scope')
        frame.f.append((var, kind, label))

    def add_e(self, stat, label):  # wch A: added label
        """Record essential hypothesis ``stat`` (a token list) under ``label``."""
        self.stack[-1].e.append((stat, label))

    def add_d(self, stat):
        """Record every pair of distinct symbols in ``stat`` as disjoint.

        Pairs are stored once each, as (x, y) with x < y.
        """
        frame = self.stack[-1]
        for i in range(len(stat)):
            for j in range(i + 1, len(stat)):
                x, y = stat[i], stat[j]
                if x > y:
                    x, y = y, x
                if x != y and (x, y) not in frame.d:
                    frame.d.append((x, y))

    def lookup_c(self, tok):
        """Return 1 if ``tok`` is a constant in any open scope, else None."""
        for frame in reversed(self.stack):
            if tok in frame.c:
                return 1
        return None

    def lookup_v(self, tok):
        """Return 1 if ``tok`` is a variable in any open scope, else None."""
        for frame in reversed(self.stack):
            if tok in frame.v:
                return 1
        return None

    def lookup_f(self, var):
        """Return the kind from the innermost $f for ``var``, or None."""
        for frame in reversed(self.stack):
            for (v, k, l) in frame.f:
                if v == var:
                    return k
        return None

    def lookup_d(self, x, y):
        """Return 1 if ``x`` and ``y`` are declared disjoint, else None."""
        # Pairs are stored with the smaller symbol first.
        if x > y:
            x, y = y, x
        for frame in reversed(self.stack):
            if (x, y) in frame.d:
                return 1
        return None

    def _mandatory_vars(self, hyps, stat):
        """Return the variables used in ``hyps`` and ``stat``, in order of
        first appearance."""
        mand_vars = []
        for expr in list(hyps) + [stat]:
            for tok in expr:
                if self.lookup_v(tok) and tok not in mand_vars:
                    mand_vars.append(tok)
        return mand_vars

    def _mandatory_fhyps(self, mand_vars):
        """Return the (var, kind, label) $f entry for each of ``mand_vars``.

        Frames and their $f lists are scanned from the most recent entry
        backwards so that the innermost $f for a variable wins; the result
        is then put back into declaration order.  ``mand_vars`` itself is
        not modified.
        """
        pending = list(mand_vars)
        found = []
        for frame in reversed(self.stack):
            for entry in reversed(frame.f):
                if entry[0] in pending:
                    found.append(entry)
                    pending.remove(entry[0])
        found.reverse()
        return found

    def get_fhypLabels(self, ehyps, stat):  # wch: added this method
        """Return the labels of the $f hypotheses for the variables used
        in ``ehyps`` (a list of statements) and ``stat``, in declaration
        order."""
        mand_vars = self._mandatory_vars(ehyps, stat)
        return [lab for (v, k, lab) in self._mandatory_fhyps(mand_vars)]

    def get_ehypLabels(self):  # wch A: added this method
        """Return the labels of all $e hypotheses in the open scopes,
        outermost scope first."""
        ehypLabels = []
        for fr in self.stack:
            for (e_statement, e_label) in fr.e:
                ehypLabels.append(e_label)
        return ehypLabels

    def make_assertion(self, stat):
        """Build the assertion tuple for statement ``stat``.

        Returns ``(dm, mand_hyps, hyps, stat)`` where
          dm        -- disjoint pairs whose two members are both variables
                       used by the assertion
          mand_hyps -- (kind, var) for each such variable, in $f
                       declaration order
          hyps      -- statements of all $e hypotheses in the open scopes
          stat      -- the statement itself, unchanged
        """
        hyps = []
        for fr in self.stack:
            for (e_statement, e_label) in fr.e:
                hyps.append(e_statement)
        mand_vars = self._mandatory_vars(hyps, stat)
        dm = []
        for fr in self.stack:
            for (x, y) in fr.d:
                if x in mand_vars and y in mand_vars and (x, y) not in dm:
                    dm.append((x, y))
        mand_hyps = [(k, v) for (v, k, l) in self._mandatory_fhyps(mand_vars)]
        if verbosity >= 18:
            # One pre-formatted string prints identically whether print is
            # a statement or a function.
            print('ma: %s' % ((dm, mand_hyps, hyps, stat),))
        return (dm, mand_hyps, hyps, stat)

class MMCompile:
    """Reads a Metamath token stream, verifies its proofs and records labels.

    Attributes:
      fs        -- FrameStack holding the currently open scopes
      labels    -- label -> ('$f', kind, var) | ('$e', stat) |
                   ('$a' or '$p', assertion tuple from make_assertion)
      labelDict -- label -> printable record (wch addition):
                   ('$f', comment, "kind var"),
                   ('$e', comment, statement string),
                   ('$a', comment, fhypLabels, ehypLabels, statement, None),
                   ('$p', comment, fhypLabels, ehypLabels, statement, proof)
      comment   -- text of the most recent comment not yet attached to a
                   statement, or None
      symbolSet -- every symbol seen in a $f, $e, $a or $p statement

    Malformed input and failed proofs are reported with ValueError.  (The
    string exceptions used previously are rejected with a TypeError by
    Python 2.6 and later, which hid the actual message.)
    """

    def __init__(self):
        self.fs = FrameStack()
        self.labels = {}
        self.labelDict = {}  # wch added
        self.comment = None  # wch added
        self.symbolSet = set()  # wch added

    def readc(self, toks):
        """Return the next non-comment token from ``toks``, or None at EOF.

        Each ``$( ... $)`` comment skipped on the way is stored in
        ``self.comment`` as its tokens joined by single spaces.

        Raises ValueError if the input ends inside a comment.
        """
        while 1:
            tok = toks.read()
            if tok is None:
                return None
            if tok != '$(':
                return tok
            words = []
            while 1:
                tok = toks.read()
                if tok is None:
                    raise ValueError('EOF inside $( comment')
                if tok == '$)':
                    break
                words.append(tok)
            # Joining leaves no trailing blank, which is what the former
            # (discarded) rstrip() call was meant to achieve.
            self.comment = " ".join(words)

    def readstat(self, toks):
        """Read tokens up to the closing ``$.`` and return them as a list.

        Raises ValueError if the input ends first.
        """
        stat = []
        while 1:
            tok = self.readc(toks)
            if tok is None:
                raise ValueError('EOF before $.')
            if tok == '$.':
                return stat
            stat.append(tok)

    def _record_assertion(self, kind, label, stat, proofString):
        """Store the $a/$p statement ``stat`` under ``label``.

        ``kind`` is '$a' or '$p'; ``proofString`` is the proof as a string
        of labels, or None for an axiom.  Consumes the pending comment.
        """
        passertion = self.fs.make_assertion(stat)
        self.labels[label] = (kind, passertion)
        ehyps = passertion[2]
        fhypLabels = self.fs.get_fhypLabels(ehyps, stat)
        ehypLabels = self.fs.get_ehypLabels()
        statString = " ".join(stat)
        self.labelDict[label] = (kind, self.comment,
                fhypLabels, ehypLabels, statString, proofString)
        self.updateSymbolSet(stat)
        self.comment = None

    def read(self, toks):
        """Process statements from ``toks`` until ``$}`` or end of input.

        A new scope is pushed for the duration of the call and popped at
        the end; a nested ``${`` is handled by a recursive call.  Every
        $p statement is verified as it is read.

        Raises ValueError on malformed statements or a failed proof.
        """
        self.fs.push()
        label = None
        while 1:
            tok = self.readc(toks)
            if tok is None or tok == '$}':
                self.comment = None
                break
            elif tok == '$c':
                for sym in self.readstat(toks):
                    self.fs.add_c(sym)
                self.comment = None
            elif tok == '$v':
                for sym in self.readstat(toks):
                    self.fs.add_v(sym)
                self.comment = None
            elif tok == '$f':
                stat = self.readstat(toks)
                if len(stat) != 2:
                    raise ValueError('$f must have length 2')
                # Check the label before touching the frame so a rejected
                # statement leaves no trace behind.
                if not label:
                    raise ValueError('$f must have label')
                if verbosity >= 15:
                    print('%s $f %s %s $.' % (label, stat[0], stat[1]))
                self.fs.add_f(stat[1], stat[0], label)  # wch C: added label
                self.labels[label] = ('$f', stat[0], stat[1])
                self.labelDict[label] = ('$f', self.comment,
                        stat[0] + " " + stat[1])
                self.updateSymbolSet(stat)
                self.comment = None
                label = None
            elif tok == '$a':
                stat = self.readstat(toks)
                if not label:
                    raise ValueError('$a must have label')
                self._record_assertion('$a', label, stat, None)
                label = None
            elif tok == '$e':
                stat = self.readstat(toks)
                if not label:
                    raise ValueError('$e must have label')
                self.fs.add_e(stat, label)  # wch A: added label
                self.labels[label] = ('$e', stat)
                statString = " ".join(stat)
                self.labelDict[label] = ('$e', self.comment, statString)
                self.updateSymbolSet(stat)
                self.comment = None
                label = None
            elif tok == '$p':
                stat = self.readstat(toks)
                if '$=' not in stat:
                    raise ValueError('$p must contain proof after $=')
                i = stat.index('$=')
                proof = stat[i + 1:]
                stat = stat[:i]
                if not label:
                    raise ValueError('$p must have label')
                if verbosity >= 1:
                    print('verifying %s' % (label,))
                # Verify first: the theorem is registered only after its
                # own proof has been accepted.
                self.verify(stat, proof)
                # wch B: the proof is written out as a string of labels
                self._record_assertion('$p', label, stat, " ".join(proof))
                label = None
            elif tok == '$d':
                stat = self.readstat(toks)
                self.fs.add_d(stat)
                self.comment = None
            elif tok == '${':
                self.comment = None
                self.read(toks)
            elif tok[0] != '$':
                # Any other non-$ token is the label of the next statement.
                label = tok
            else:
                print('tok: %s' % (tok,))
        self.fs.pop()

    def apply_subst(self, stat, subst):
        """Return a copy of ``stat`` in which every token that is a key of
        ``subst`` is replaced by the token list mapped to it."""
        result = []
        for tok in stat:
            if tok in subst:
                result.extend(subst[tok])
            else:
                result.append(tok)
        if verbosity >= 20:
            print('apply_subst %s = %s' % ((stat, subst), result))
        return result

    def find_vars(self, stat):
        """Return the variables occurring in ``stat``, each listed once in
        order of first appearance."""
        vars = []
        for x in stat:
            if x not in vars and self.fs.lookup_v(x):
                vars.append(x)
        return vars

    def verify(self, stat, proof):
        """Check that ``proof`` (a list of labels) proves ``stat``.

        The proof is executed on a stack: $f and $e labels push their
        statement; $a and $p labels pop their hypotheses, check them
        against the substitution fixed by the $f entries and push the
        substituted conclusion.

        Returns None on success.  Raises ValueError when the proof is
        invalid and KeyError when it names an unknown label.
        """
        stack = []
        for label in proof:
            step = self.labels[label]
            if verbosity >= 10:
                print('%s : %s' % (label, step))
            if step[0] == '$f':
                stack.append([step[1], step[2]])
            elif step[0] in ('$a', '$p'):
                (distinct, mand_var, hyp, result) = step[1]
                if verbosity >= 12:
                    print('%s' % ((distinct, mand_var, hyp, result),))
                npop = len(mand_var) + len(hyp)
                sp = len(stack) - npop
                if sp < 0:
                    raise ValueError('stack underflow')
                # The $f entries determine what each variable stands for.
                subst = {}
                for (k, v) in mand_var:
                    entry = stack[sp]
                    if not entry or entry[0] != k:
                        print('%s %s' % ((k, v), entry))
                        raise ValueError(
                            "stack entry doesn't match mandatory var hyp")
                    subst[v] = entry[1:]
                    sp += 1
                if verbosity >= 15:
                    print('subst: %s' % (subst,))
                # Variables substituted into a disjoint pair must differ
                # and must themselves be declared disjoint.
                for x, y in distinct:
                    if verbosity >= 16:
                        print('dist %s %s %s %s' % (x, y, subst[x], subst[y]))
                    x_vars = self.find_vars(subst[x])
                    y_vars = self.find_vars(subst[y])
                    if verbosity >= 16:
                        print('V(x) = %s' % (x_vars,))
                        print('V(y) = %s' % (y_vars,))
                    for gam in x_vars:
                        for delt in y_vars:
                            if gam == delt:
                                raise ValueError("disjoint violation " + gam)
                            lo, hi = gam, delt
                            if lo > hi:
                                lo, hi = hi, lo
                            if not self.fs.lookup_d(lo, hi):
                                raise ValueError(
                                    "disjoint violation " + lo + ", " + hi)
                for h in hyp:
                    entry = stack[sp]
                    subst_h = self.apply_subst(h, subst)
                    if entry != subst_h:
                        print('st: %s' % (entry,))
                        print('hy: %s' % (subst_h,))
                        raise ValueError("stack entry doesn't match hypothesis")
                    sp += 1
                del stack[len(stack) - npop:]
                stack.append(self.apply_subst(result, subst))
            elif step[0] == '$e':
                stack.append(step[1])
            if verbosity >= 12:
                print('st: %s' % (stack,))
        if not stack:
            raise ValueError('stack is empty at end')
        if len(stack) != 1:
            raise ValueError('stack has >1 entry at end')
        if stack[0] != stat:
            raise ValueError("assertion proved doesn't match")

    def dump(self):
        """Print the ``labels`` dictionary."""
        print('%s' % (self.labels,))

    def writeLabelDict(self, basename):  # wch added this method
        """Write ``labelDict`` to ``<app_dirname>/<basename>.mmo``.

        One ``(label, record)`` tuple is written per line and echoed to
        standard output.  ``app_dirname`` comes from GUI.application().
        """
        app_dirname = GUI.application().app_dirname
        outFileName = app_dirname + "/" + basename + ".mmo"
        output = open(outFileName, "w")
        # try/finally so the file is closed even if a write fails.
        try:
            print('%s outFileName' % (outFileName,))
            for line in self.labelDict.items():
                text = str(line)
                print(text)
                output.write(text + "\n")
        finally:
            output.close()

    def updateSymbolSet(self, statement):  # wch added this method
        """Add every symbol of ``statement`` to ``symbolSet``."""
        self.symbolSet.update(statement)

"""
import sys
mmCompile = MMCompile()
#@@@ mmCompile.read(toks(sys.stdin)) wch: replaced with following 2 statements
file = open("set.mms")
mmCompile.read(toks(file))
print len(mmCompile.labels), "len(mmCompile.labels)" #@@@ wch: added
##mmCompile.dump()
mmCompile.writeLabelDict("set.mmo") #@@@ wch: added
"""
