#
#  MMProof.py
#

class MMProofError(Exception):
    """Raised when a proof cannot be replayed on the proof stack."""


class MMProof(object):
    """Replay a Metamath proof (a whitespace-separated label list) on a stack.

    ``labelDict[label]`` is a sequence whose ``[0]`` is the statement type.
    For ``$f``/``$e`` labels ``[2]`` is the statement text; for ``$a``/``$p``
    labels ``[2:]`` unpacks as ``(mand_var, hyp, statement, proof)``.

    ``recalculate`` fills three parallel lists, one entry per proof step
    (steps are numbered from 1, list indices from 0):

    * ``targetStack`` -- ``[step, label, type, statement]`` as produced by
      that step, after substitution.
    * ``sourceStack`` -- ``[step, label, type, statement]`` of the hypothesis
      that later consumed the step.  The last entry is relabeled with
      ``statementLabel``/``statementType``.  Entries never consumed keep
      ``None`` label/type fields.
    * ``stepHypsStack`` -- the step numbers popped by that step.
    """

    def __init__(self, statementLabel, statementType, statementStr, refs,
                labelDict, **kwds):
        """Store the statement and its proof, then replay the proof.

        Args:
            statementLabel: label of the statement being proved.
            statementType: its statement type (e.g. ``'$p'``).
            statementStr: the statement text the proof should produce.
            refs: string of labels making up the proof, or ``None`` (stub:
                nothing is replayed).
            labelDict: label -> statement data (see the class docstring).
            **kwds: forwarded to the superclass constructor.
        """
        super(MMProof, self).__init__(**kwds)
        self.statementLabel = statementLabel
        self.statementType = statementType
        self.statementStr = statementStr
        self.targetRefs = refs
        self.labelDict = labelDict
        self.pleasePrint = False  # wch: debug output, for investigation
        self.recalculate()

    def resetResults(self):
        """Clear all result lists before a replay."""
        # The six lists below are reset here but not filled by recalculate.
        self.targetTypes = []
        self.targetExpressions = []
        self.targetHyps = []
        self.sourceTypes = []
        self.sourceExpressions = []
        self.sourceHyps = []
        self.sourceStack = []
        self.targetStack = []
        self.stepHypsStack = []

    def recalculate(self):
        """Replay ``self.targetRefs`` and rebuild the result lists.

        Raises:
            KeyError: a proof label, or a hypothesis it names, is missing
                from ``labelDict``.
            MMProofError: the stack underflows, or a stack entry's type code
                does not match a mandatory variable hypothesis.

        Problems at the end of the proof are only printed to stdout:
        a final stack without exactly one entry, or a final entry that
        differs from ``statementStr``.
        """
        self.resetResults()
        if self.targetRefs is None:
            return  # wch: stub -- no proof to replay
        stack = []
        subst = {}  # substitution of the latest $a/$p step (debug output)
        step = 0
        for targetRef in self.targetRefs.split():
            statementData = self.labelDict[targetRef]
            statementType = statementData[0]
            if statementType in ('$f', '$e'):
                # Hypotheses are pushed as-is and consume nothing.
                step += 1
                stackItem = [step, targetRef, statementType, statementData[2]]
                stack.append(stackItem)
                self.sourceStack.append([step, None, None, None])
                self.targetStack.append(stackItem)
                self.stepHypsStack.append([])
            elif statementType in ('$a', '$p'):
                step += 1
                subst = self._applyAssertion(stack, step, targetRef,
                                             statementType, statementData)
            # Labels of any other statement type are skipped.
        self._checkFinalStack(stack, subst)

    def _applyAssertion(self, stack, step, targetRef, statementType,
                        statementData):
        """Pop an assertion's hypotheses and push its substituted conclusion.

        Returns the substitution (variable token -> token list) that was used.
        """
        if self.pleasePrint:
            print("statementData %s" % (statementData,))
        (mand_var, hyp, sourceStatementStr, _proofStr) = statementData[2:]
        npop = len(mand_var) + len(hyp)
        base = len(stack) - npop  # index of the first entry to pop
        if base < 0:
            raise MMProofError('stack underflow')
        sp = base
        subst = {}
        # The mandatory variable hypotheses are matched first, in order.
        # Each one binds its variable to the tokens after the entry's type code.
        for varLabel in mand_var:
            varData = self.labelDict[varLabel]
            varStatementStr = varData[2]
            varStatementList = varStatementStr.split()
            varType, varValue = varStatementList[0], varStatementList[1]
            entryList = stack[sp][3].split()
            if entryList[0] != varType:
                raise MMProofError(
                    "stack entry doesn't match mandatory var hyp")
            subst[varValue] = entryList[1:]
            self._setSource(stack[sp][0], varLabel, varData[0],
                            varStatementStr)
            sp += 1
        # TODO(wch): disjoint-variable ($d) checks are held back for now.
        # The draft below needs `distinct`, `find_vars` and `self.fs`, none of
        # which this class defines:
        #   for x, y in distinct:
        #       x_vars = self.find_vars(subst[x])
        #       y_vars = self.find_vars(subst[y])
        #       for gam in x_vars:
        #           for delt in y_vars:
        #               if gam == delt:
        #                   raise "disjoint violation " + gam
        #               x, y = gam, delt
        #               if x > y:
        #                   x, y = y, x
        #               if not self.fs.lookup_d(x, y):
        #                   raise "disjoint violation " + x + ", " + y
        for h in hyp:
            entry = stack[sp][3]
            hStatementData = self.labelDict[h]
            hStatementStr = hStatementData[2]
            subst_h = self.apply_subst(hStatementStr, subst)
            if entry != subst_h and self.pleasePrint:
                # A mismatch is only reported; the original raise is disabled.
                print('st: %s' % (entry,))
                print('hy: %s %s' % (subst_h, hStatementStr))
            self._setSource(stack[sp][0], h, hStatementData[0], hStatementStr)
            sp += 1
        stepHyps = []
        for stackItem in stack[base:]:
            if self.pleasePrint:
                print("%s stackItem" % (stackItem,))
            stepHyps.append(stackItem[0])
        self.stepHypsStack.append(stepHyps)
        del stack[base:]
        targetStatementStr = self.apply_subst(sourceStatementStr, subst)
        stackItem = [step, targetRef, statementType, targetStatementStr]
        stack.append(stackItem)
        self.sourceStack.append([step, None, None, sourceStatementStr])
        self.targetStack.append(stackItem)
        return subst

    def _setSource(self, hStep, label, statementType, statementStr):
        """Record the hypothesis that consumed proof step ``hStep``."""
        entry = self.sourceStack[hStep - 1]  # steps count from 1, list from 0
        entry[1] = label
        entry[2] = statementType
        entry[3] = statementStr

    def _checkFinalStack(self, stack, subst):
        """Report end-of-proof problems and relabel the final step."""
        if self.pleasePrint:
            print("subst= %s" % (subst,))
        if not stack:
            # Nothing was pushed (e.g. an empty proof string).
            print('stack is empty at end')
            return
        if len(stack) != 1:
            #raise 'stack has >1 entry at end'
            print('stack has >1 entry at end')
        if stack[0][3] != self.statementStr:
            print("assertion proved doesn't match %s %s"
                  % (stack[0][3], self.statementStr))
        self.sourceStack[-1][1] = self.statementLabel
        self.sourceStack[-1][2] = self.statementType
        if self.pleasePrint:
            for item in self.sourceStack + self.targetStack + self.stepHypsStack:
                print(item)

    def apply_subst(self, stat, subst):
        """Substitute variables in a statement string.

        Args:
            stat: statement text with whitespace-separated tokens.
            subst: maps a variable token to its replacement token list.

        Returns:
            The substituted statement, tokens joined by single spaces.
        """
        result = []
        for tok in stat.split():
            if tok in subst:
                result.extend(subst[tok])
            else:
                result.append(tok)
        return " ".join(result)
