# This is STMcmc, for super tree mcmc.
# Started 18 March 2011, first commit 22 March 2011.
import p4.pf as pf
import p4.func
from p4.var import var
import math
import random
import string
import sys
import time
import copy
import os
import pickle
import glob
import numpy as np
from p4.p4exceptions import P4Error
from p4.treepartitions import TreePartitions
from p4.constraints import Constraints
from p4.tree import Tree
import datetime
import itertools
from scipy.optimize import minimize
import logging
import bitarray
def choose(n, k):
    """
    Return the binomial coefficient C(n, k), or 0 when k is not in 0..n.

    A fast way to calculate binomial coefficients by Andrew Dalke
    (contrib).  It multiplies only min(k, n - k) numerator and
    denominator terms, in exact integer arithmetic, and does a single
    floor division at the end.
    """
    if not 0 <= k <= n:
        return 0
    ntok = 1
    ktok = 1
    # range, not the Python 2-only xrange, which raises NameError here.
    for t in range(1, min(k, n - k) + 1):
        ntok *= n
        ktok *= t
        n -= 1
    return ntok // ktok
# def nSplits(n):
#     mySum = 0
#     for k in range(2, n-1):
#         mySum += choose(n-1, k)
#     return mySum
def bForN(n):
    """Return log of the product (2k - 5) for k = 4..n, i.e. log((2n - 5)!!).

    (2n - 5)!! = 3 * 5 * ... * (2n - 5) is the number of unrooted binary
    trees on n labelled taxa.  For n <= 3 the product is empty and 0.0 is
    returned.  Working in logs avoids overflow.  The max diff (in
    log(result)) between this and the non-log version is about 2.5e-10
    for n up to 10000.
    """
    logProduct = 0.0
    oddFactor = 3          # 2k - 5 for k = 4
    for _ in range(4, n + 1):
        logProduct += math.log(oddFactor)
        oddFactor += 2
    return logProduct
def BS2009_Eqn30_ZTApprox(n, beta, cT):
    """Return the log of the approximate normalising constant Z_T (BS2009 eqn 30).

    n is the number of taxa, beta scales the distance, and cT is the cherry
    count.  With eps = exp(-2 beta) and lambda = cT / (2n):

        log A = log(1 + (2n - 3) eps + 2 (n^2 - 4n - 6) eps^2 + 6 cT eps^2)
        log B = -2 beta (n - 3) + lambda (exp(2 beta) - 1) + bForN(n)

    log B is considered only when beta < 0.5 log((n - 3) / lambda).  The
    larger of the two is returned; otherwise log A is returned.

    The log version differs from the non-log version (in log(result)) by at
    most 6.82e-13 for n up to 150, over beta 0.001 -- 1000 and cT 2 -- n/2.

    Raises ZeroDivisionError when cT is 0 (lambda is zero), and ValueError
    (math domain error) when n <= 3.
    """
    lam = cT / (2.0 * n)
    threshold = 0.5 * math.log((n - 3.) / lam)
    eps = math.exp(-2. * beta)
    bigA = (1 + (((2. * n) - 3.) * eps)
            + (2. * ((n * n) - (4. * n) - 6.) * eps * eps))
    logA = math.log(bigA + 6 * cT * eps * eps)
    if not beta < threshold:
        return logA
    logB = -(2. * beta) * (n - 3.) + (lam * (math.exp(2. * beta) - 1.))
    logB += bForN(n)
    return logA if logA > logB else logB
def popcountA(k, nBits):
    """Count the set bits of k among its lowest nBits bit positions.

    Scanning stops early once the probe bit exceeds k, because no higher
    bit can then be set.  Returns 0 for k <= 0.
    """
    nSet = 0
    bitPos = 0
    while bitPos < nBits:
        probe = 1 << bitPos
        if probe > k:
            break
        if probe & k:
            nSet += 1
        bitPos += 1
    return nSet
def bitReduce(bk, txBits, lLen, sLen, allOnes):
    """Compress split key bk down to the taxa selected by txBits.

    The bits of bk at the positions where txBits is set (scanning the
    lowest lLen positions) are packed into consecutive low bits.  If the
    packed key has bit 0 set, it is complemented against allOnes and the
    count of ones becomes sLen minus the count.  The result therefore never
    has bit 0 set.

    Returns:
        (reducedKey, onesCount)
    """
    packed = 0
    nOnes = 0
    outPos = 0
    for inPos in range(lLen):
        mask = 1 << inPos
        if not mask & txBits:
            continue          # taxon not in the reduced set
        if mask & bk:
            packed += 1 << outPos
            nOnes += 1
        outPos += 1
    if packed & 1:
        return allOnes ^ packed, sLen - nOnes
    return packed, nOnes
if 0:  # hand check of bitReduce(); change to 'if 1:' to run it
    lLen, sLen = 5, 4
    sk = 6   # always at least 2 bits, even
    txBits = 30
    allOnes = 15
    splitStr = p4.func.getSplitStringFromKey
    print("     sk: %3i  %s" % (sk, splitStr(sk, lLen)))
    print("taxBits: %3i  %s" % (txBits, splitStr(txBits, lLen)))
    rsk, popcount = bitReduce(sk, txBits, lLen, sLen, allOnes)
    print("    rsk: %3i  %s" % (rsk, splitStr(rsk, sLen)))
    print("   popcount %i" % popcount)
    # Expected output:
    #     sk:   6  .**..
    #     taxBits:  30  .****
    #     rsk:  12  ..**
    #     popcount 2
def maskedSymmetricDifference(skk, skSet, taxBits, longLen, shortLen, allOnes):
    """RF distance between a supertree reduced to an input tree's taxa and
    that input tree, plus the number of cherries in the reduced supertree.

    Args:
        skk: integer split keys of the supertree's internal splits, over
            longLen taxa.
        skSet: collection of reduced split keys of the input tree.
        taxBits: key with a bit set for each taxon of the input tree.
        longLen: number of taxa in the supertree.
        shortLen: number of taxa in the input tree.
        allOnes: passed to bitReduce() for complementing; presumably the
            key with the low shortLen bits set.

    Returns:
        (rf, nCherries).  rf is the size of the symmetric difference
        between the non-trivial reduced supertree splits and skSet.
        nCherries counts the reduced splits with exactly 2 ones, plus those
        with shortLen - 2 ones.
    """
    upperPop = shortLen - 1
    reducedSplits = set()
    for sk in skk:
        reducedSk, pop = bitReduce(sk, taxBits, longLen, shortLen, allOnes)
        # Drop splits that leave 0 or 1 taxa on one side; they are trivial
        # on the input tree's taxa.
        if 1 < pop < upperPop:
            reducedSplits.add(reducedSk)
    rf = len(reducedSplits.symmetric_difference(skSet))
    nCherries = 0
    for sk in reducedSplits:
        pop = popcountA(sk, shortLen)
        # The tests are not exclusive: when shortLen == 4, one split passes both.
        nCherries += (pop == 2) + (pop == shortLen - 2)
    return rf, nCherries
def slowQuartetDistance(st, inputTree):
    """Return the quartet distance between st, pruned to inputTree's taxa, and inputTree.

    Works on a copy of st.  Leaves whose names are not in
    inputTree.taxNames are removed, then
    topologyDistance(inputTree, metric='scqdist') is returned.
    """
    pruned = st.dupe()
    strangers = [leaf for leaf in pruned.iterLeavesNoRoot()
                 if leaf.name not in inputTree.taxNames]
    for leaf in strangers:
        pruned.removeNode(leaf)
    return pruned.topologyDistance(inputTree, metric='scqdist')
class STChain(object):
    def __init__(self, aSTMcmc, chNum):
        """Set up one chain of an STMcmc run.

        Args:
            aSTMcmc: the owning STMcmc object.  Its tree, its model settings
                (modelName, stRFCalc, beta, spaQ) and its input trees are
                read here.
            chNum: chain number.  Chain 0 starts from a copy of
                aSTMcmc.tree.  Other chains start from a randomized topology
                unless var.mcmc_sameBigTToStartOnAllChains is set.

        Two copies of the starting supertree are kept: self.curTree and
        self.propTree.  The initial log likelihood is computed on propTree,
        by the calculator that modelName selects (and, for the SR2008_rf
        models, stMcmc.stRFCalc), and is then copied to curTree.

        Raises:
            P4Error: for an unknown modelName.  Also raised for SPA models
                with var.stmcmc_useFastSpa when the fastspa and bitarray
                likelihoods differ by more than 1e-13.
        """
        gm = ['STChain.__init__()']
        self.stMcmc = aSTMcmc
        self.chNum = chNum    # Does not change.  Used with fastspa only
        self.tempNum = chNum  # 'temp'erature, not 'temp'orary;  changes when swapped.
        if chNum == 0:
            self.curTree = aSTMcmc.tree.dupe()
            self.propTree = aSTMcmc.tree.dupe()
        else:      # heated chains are randomized, unless var.mcmc_sameBigTToStartOnAllChains is set.
            if var.mcmc_sameBigTToStartOnAllChains:  # False by default
                self.curTree = aSTMcmc.tree.dupe()
                self.propTree = aSTMcmc.tree.dupe()
            else:
                # Random topology without branch lengths; propTree starts
                # as a copy of curTree.
                rTree = aSTMcmc.tree.dupe()
                rTree.randomizeTopology(randomBrLens = False)
                rTree.stripBrLens()
                rTree.setPreAndPostOrder()
                self.curTree = rTree
                self.propTree = rTree.dupe()
        self.logProposalRatio = 0.0
        self.logPriorRatio = 0.0
        self.frrf = None
        #self.nInTreeSplits = 0
        # Each branch below computes propTree.logLike and then copies it to
        # curTree, because both trees start with the same topology.
        if self.stMcmc.modelName.startswith('SR2008_rf'):
            self.curTree.beta = self.stMcmc.beta
            self.propTree.beta = self.stMcmc.beta
            # NOTE(review): an stRFCalc value not listed here falls through
            # without computing propTree.logLike.
            if self.stMcmc.stRFCalc == 'purePython1':
                self.getTreeLogLike_ppy1()
            elif self.stMcmc.stRFCalc == 'fastReducedRF':
                self.startFrrf()
                self.getTreeLogLike_fastReducedRF()
            elif self.stMcmc.stRFCalc == 'bitarray':
                self.setupBitarrayCalcs()
                self.getTreeLogLike_bitarray()
            self.curTree.logLike = self.propTree.logLike
        elif self.stMcmc.modelName.startswith('SPA'):
            # spaQ is held as a length-1 numpy array on each tree.
            self.curTree.spaQ = np.array([self.stMcmc.spaQ])
            self.propTree.spaQ = np.array([self.stMcmc.spaQ])
            # for t in self.stMcmc.trees:
            #    self.nInTreeSplits += len(t.splSet)
            # print "Got nInTreeSplits %s" % self.nInTreeSplits
            self.setupBitarrayCalcs()
            self.getTreeLogLike_spa_bitarray()
            if var.stmcmc_useFastSpa:
                # Cross-check the fastspa result against the bitarray result.
                #print("Here E.  bitarray propTree.logLike is %f" % self.propTree.logLike)
                fspaLike = self.stMcmc.fspa.calcLogLike(self.chNum)
                diff = math.fabs(self.propTree.logLike - fspaLike)
                #print("Got fspaLike %f, diff %g" % (fspaLike, diff))
                if diff > 1e-13:
                    gm.append("bad fastspa likelihood calc, %f vs %f, diff %f" % (self.propTree.logLike, fspaLike, diff))
                    raise P4Error(gm)
            self.curTree.logLike = self.propTree.logLike
        elif self.stMcmc.modelName.startswith('QPA'):
            # Unlike SPA, spaQ is a plain scalar here.
            self.curTree.spaQ = self.stMcmc.spaQ
            self.propTree.spaQ = self.stMcmc.spaQ
            # Each set of 4 taxa has 3 possible quartet topologies.
            self.nPossibleQuartets = choose(self.stMcmc.tree.nTax, 4) * 3
            self.getTreeLogLike_qpa_slow()
            self.curTree.logLike = self.propTree.logLike
        else:
            gm.append('Unknown modelName %s' % self.stMcmc.modelName)
            raise P4Error(gm)
        if 0:
            print("STChain init()")
            self.curTree.draw()
            print("logLike is %f" % self.curTree.logLike)
    def getTreeLogLike_qpa_slow(self):
        """QPA (quartet) log likelihood of self.propTree, the slow way.

        Integer split keys are built for the internal non-root nodes of
        propTree by OR-ing the children's keys, and stored in propTree.skk.
        Every quartet those splits display is collected in propTree.qSet as
        a 4-tuple of taxon bits: each pair is sorted, and the pair with the
        smaller first element comes first.  With Q = propTree.spaQ:
          - each displayed quartet has probability q = Q / nQuartets;
          - each other possible quartet has probability
            r = (1 - Q) / (nPossibleQuartets - nQuartets).
        If the supertree displays no quartets, r = 1 / nPossibleQuartets.
        propTree.logLike is the sum of log q or log r over every quartet in
        every input tree's qSet.

        Raises:
            P4Error: if propTree.spaQ is not in (0, 1].
        """
        gm = ["STChain.getTreeLogLike_qpa_slow()"]
        pTree = self.propTree
        if pTree.spaQ > 1. or pTree.spaQ <= 0.0:
            gm.append("bad propTree.spaQ value %f" % pTree.spaQ)
            raise P4Error(gm)

        # Split keys for internal nodes: the OR of the children's keys.
        for node in pTree.iterInternalsPostOrder():
            if node == pTree.root:
                break
            key = node.leftChild.stSplitKey
            kid = node.leftChild.sibling
            while kid:
                key |= kid.stSplitKey
                kid = kid.sibling
            node.stSplitKey = key
        pTree.skk = [node.stSplitKey for node in pTree.iterInternalsNoRoot()]

        # Every split displays the quartets made of two taxa from each side.
        quartets = pTree.qSet = set()
        for sk in pTree.skk:
            inside = [bit for bit in pTree.taxBits if (sk & bit)]
            outside = [bit for bit in pTree.taxBits if not (sk & bit)]
            for a, b in itertools.combinations(outside, 2):
                outPair = (b, a) if a > b else (a, b)
                for c, d in itertools.combinations(inside, 2):
                    inPair = (d, c) if c > d else (c, d)
                    if outPair[0] < inPair[0]:
                        quartets.add(outPair + inPair)
                    else:
                        quartets.add(inPair + outPair)

        nQ = len(quartets)
        pTree.nQuartets = nQ
        if nQ:
            q = pTree.spaQ / nQ
            R = 1. - pTree.spaQ
            r = R / (self.nPossibleQuartets - nQ)
            logq = math.log(q)
        else:
            # Nothing is displayed, so logq is never needed.
            logq = None
            R = 1.
            r = R / self.nPossibleQuartets
        logr = math.log(r)

        pTree.logLike = 0.0
        for inTree in self.stMcmc.trees:
            for quartet in inTree.qSet:
                pTree.logLike += logq if quartet in quartets else logr
    def getTreeLogLike_spa_bitarray(self):
        """SPA log likelihood of self.propTree, using bitarray splits.

        For each input tree ``it``, every internal supertree split (n.ss,
        set up by setupBitarrayCalcs) is first oriented so that the bit for
        it.firstTax is set, then masked to it's taxa (it.baTaxBits).  Masked
        splits with fewer than 2 or more than it.nTax - 2 ones are trivial
        and are ignored.  The rest, de-duplicated by their bytes, are the
        S_st splits of the reduced supertree.

        With Q = propTree.spaQ[0], S = 2**(it.nTax - 1) - (it.nTax + 1)
        possible non-trivial splits, q = Q / S_st and
        r = (1 - Q) / (S - S_st):
          - each internal split of ``it`` that is present in the reduced
            supertree adds log(q).  With stMcmc.useSplitSupport and a
            branch support s, it adds log(r + s (q - r)) instead;
          - each absent split adds log(r).
        When S_st is 0, r = 1 / S.  The total is left in
        self.propTree.logLike.

        Raises:
            P4Error: if propTree.spaQ[0] is not in (0, 1], or if the
                optional slow cross-check disagrees.
        """
        gm = ["STChain.getTreeLogLike_spa_bitarray"]
        # spaQ is held as a length-1 numpy array; use its scalar.
        spaQ = self.propTree.spaQ[0]
        if spaQ > 1. or spaQ <= 0.0:
            # Format the scalar.  "%f" % <ndarray> relies on NumPy's
            # deprecated conversion of a size-1 array to a float.
            gm.append("bad propTree.spaQ value %f" % spaQ)
            raise P4Error(gm)

        # Developer cross-check against a slow, tree-pruning calculation.
        slowCheck = False
        if slowCheck:
            print("\n", "-" * 30)
            print("Super tree: ", end=' ')
            self.propTree.write()
            slowCheckLogLike = 0.0
            for it in self.stMcmc.trees:
                it.makeSplitKeys()
                it.inbb = [n.br for n in it.iterInternalsNoRoot()]

        self.propTree.logLike = 0.0
        for it in self.stMcmc.trees:
            if slowCheck:
                # Prune a copy of the supertree down to this input tree's taxa.
                stDupe = self.propTree.dupe()
                toRemove = [n for n in stDupe.iterLeavesNoRoot()
                            if n.name not in it.taxNames]
                for n in toRemove:
                    stDupe.removeNode(n)
                stDupe.taxNames = it.taxNames
                stDupe.makeSplitKeys(makeNodeForSplitKeyDict=True)

            # Mask each supertree split to this input tree's taxa.  Only
            # masked splits with 2 .. it.nTax - 2 ones are informative.
            upperGood = it.nTax - 2
            nonRedundantStSplitDict = {}
            for n in self.propTree.iterInternalsNoRoot():
                ss = n.ss
                # Choose spl or its complement spl2, so that the
                # it.firstTax bit is set.
                if ss.spl[it.firstTax]:
                    ss.theSpl = ss.spl
                else:
                    ss.theSpl = ss.spl2
                ss.maskedSplitWithTheFirstTaxOne = ss.theSpl & it.baTaxBits
                ss.onesCount = ss.maskedSplitWithTheFirstTaxOne.count()
                if 2 <= ss.onesCount <= upperGood:
                    # Different supertree splits can reduce to the same
                    # masked split; keying on its bytes de-duplicates them.
                    ss.bytes = ss.maskedSplitWithTheFirstTaxOne.tobytes()
                    nonRedundantStSplitDict[ss.bytes] = ss

            # S_st is the number of splits in the reduced supertree.
            S_st = len(nonRedundantStSplitDict)
            if slowCheck:
                slowCheckS_st = len([n for n in stDupe.iterInternalsNoRoot()])
                assert S_st == slowCheckS_st
            # S is the number of possible splits in an it-sized tree.
            S = 2 ** (it.nTax - 1) - (it.nTax + 1)
            if S_st:
                q = spaQ / S_st
                R = 1. - spaQ
                r = R / (S - S_st)
                logq = math.log(q)
            else:
                # No supertree split survives the masking, so q is undefined.
                # Every input-tree split is then absent and scores log(r).
                R = 1.
                r = R / S
            logr = math.log(r)

            for n in it.internals:
                ret = nonRedundantStSplitDict.get(n.stSplitKeyBytes)
                if ret:
                    if self.stMcmc.useSplitSupport and n.br.support is not None:
                        self.propTree.logLike += math.log(
                            r + (n.br.support * (q - r)))
                    else:
                        self.propTree.logLike += logq
                else:
                    # Absent splits always score log(r).  Weighting them by
                    # (1 - support) would assume that the complementary
                    # probability belongs to a supertree split, which need
                    # not hold (and q is undefined when S_st is zero).
                    self.propTree.logLike += logr

            if slowCheck:
                for inb in it.inbb:
                    splitString = p4.func.getSplitStringFromKey(
                        inb.splitKey, it.nTax)
                    print("    %s " % splitString, end=' ')
                    ret = stDupe.nodeForSplitKeyDict.get(inb.splitKey)
                    if ret:
                        if self.stMcmc.useSplitSupport and inb.support is not None:
                            thisLogQc = math.log(r + (inb.support * (q - r)))
                            print("qc %.3f" % thisLogQc)
                            slowCheckLogLike += thisLogQc
                        else:
                            print("q %.3f" % q)
                            slowCheckLogLike += logq
                    else:
                        # Same rule as the fast calculation above.
                        print("r %.3f" % r)
                        slowCheckLogLike += logr

        if slowCheck:
            myDiff = self.propTree.logLike - slowCheckLogLike
            if math.fabs(myDiff) > 1.e-12:
                gm.append("Bad like calc. slowCheck %f, bitarray %f, diff %g" % (
                    slowCheckLogLike, self.propTree.logLike, myDiff))
                raise P4Error(gm)
    def setupBitarrayCalcs(self):
        """Give self.propTree (the supertree, 'bigT') bitarray split info.

        Each node visited before the root in iterPostOrder() gets
        n.stSplitKey, a bitarray over self.stMcmc.taxNames.  A leaf has
        only its own taxon's bit set; an internal node gets the OR of its
        children's keys.  Each internal non-root node then gets n.ss, a
        BigTSplitStuff holding the split (ss.spl) and its complement
        (ss.spl2).  The root gets an empty BigTSplitStuff.  For SPA models
        with var.stmcmc_useFastSpa, the tree and its splits are also passed
        to self.stMcmc.fspa.
        """
        # Prepare self.propTree (ie bigT).  First make n.stSplitKeys.  These
        # are temporary; the info is held more permanently in n.ss, a BigTSplitStuff object
        for n in self.propTree.iterPostOrder():
            if n == self.propTree.root:
                break
            if n.isLeaf:
                # self.stMcmc.tBits is a scratch template: set this taxon's
                # bit, copy it into a new bitarray, then clear the bit again.
                # (Assumes tBits is otherwise all False -- TODO confirm.)
                spot = self.stMcmc.taxNames.index(n.name)
                self.stMcmc.tBits[spot] = True
                n.stSplitKey = bitarray.bitarray(self.stMcmc.tBits)
                self.stMcmc.tBits[spot] = False
            else:
                # Copy first, so the in-place OR does not change the child's key.
                n.stSplitKey = n.leftChild.stSplitKey.copy()
                p = n.leftChild.sibling
                while p:
                    n.stSplitKey |= p.stSplitKey    # "or", in-place
                    p = p.sibling
        # Next transfer the internal node split keys to BigTSplitStuff objects
        for n in self.propTree.iterInternalsNoRoot():
            n.ss = BigTSplitStuff()
            n.ss.spl = n.stSplitKey
            n.ss.spl2 = n.ss.spl.copy()
            n.ss.spl2.invert()
        # This next one will be empty, not used immediately, but will
        # be used after supertree rearrangements.
        self.propTree.root.ss = BigTSplitStuff()
        if self.stMcmc.modelName.startswith('SPA') and var.stmcmc_useFastSpa:
            self.stMcmc.fspa.setBigT(len(self.propTree.nodes), self.propTree.nTax, 
                                     self.propTree.postOrder, self.propTree.spaQ)
            for nNum in self.propTree.postOrder:
                # -10000 appears to mark the end of the used part of postOrder.
                if nNum == -10000:
                    break
                n = self.propTree.nodes[nNum]
                # The root and leaves have no n.ss split here, so '0' is sent.
                if n == self.propTree.root or n.isLeaf:
                    theSpl = '0'
                    theSpl2 = '0'
                else:
                    theSpl = n.ss.spl.to01()
                    theSpl2 = n.ss.spl2.to01()
                self.stMcmc.fspa.setBigTNoSpl(self.chNum, nNum, theSpl, theSpl2)
    def refreshBitarrayPropTree(self):
        """Recompute the bitarray splits of self.propTree after a topology change.

        Leaf stSplitKeys are left as they are.  Each internal node visited
        before the root in iterPostOrder() gets its key rebuilt as the OR of
        its children's keys.  The key is then stored in the node's existing
        n.ss as spl, with its complement as spl2.  No new BigTSplitStuff
        objects are created, so every internal non-root node must already
        have n.ss.  For SPA with var.stmcmc_useFastSpa, the splits are
        re-sent to self.stMcmc.fspa.
        """
        pTree = self.propTree
        for node in pTree.iterPostOrder():
            if node == pTree.root:
                break
            if node.isLeaf:
                continue
            # Copy first, so the in-place ORs do not change the child's key.
            key = node.leftChild.stSplitKey.copy()
            kid = node.leftChild.sibling
            while kid:
                key |= kid.stSplitKey
                kid = kid.sibling
            node.stSplitKey = key

        for node in pTree.iterInternalsNoRoot():
            ss = node.ss
            ss.spl = node.stSplitKey
            ss.spl2 = ss.spl.copy()
            ss.spl2.invert()

        if not (self.stMcmc.modelName.startswith('SPA') and var.stmcmc_useFastSpa):
            return
        for nNum in pTree.postOrder:
            if nNum == -10000:   # appears to mark the end of postOrder
                break
            node = pTree.nodes[nNum]
            if node == pTree.root or node.isLeaf:
                splStr, spl2Str = '0', '0'
            else:
                splStr, spl2Str = node.ss.spl.to01(), node.ss.spl2.to01()
            self.stMcmc.fspa.setBigTNoSpl(self.chNum, nNum, splStr, spl2Str)
    def startFrrf(self):
        # if using self.stMcmc.stRFCalc= 'fastReducedRF'
        self.frrf = self.stMcmc.Frrf(len(self.stMcmc.taxNames))
        self.bigTr = self.frrf.setBigT(
            len(self.propTree.nodes), self.propTree.nTax, self.propTree.postOrder)
        for n in self.propTree.nodes:
            if n.parent:
                self.bigTr.setParent(n.nodeNum, n.parent.nodeNum)
            if n.leftChild:
                self.bigTr.setLeftChild(n.nodeNum, n.leftChild.nodeNum)
            else:
                self.bigTr.setNodeTaxNum(
                    n.nodeNum, self.stMcmc.taxNames.index(n.name))
            if n.sibling:
                self.bigTr.setSibling(n.nodeNum, n.sibling.nodeNum)
        if 1:
            for t in self.stMcmc.trees:
                tr = self.frrf.appendInTree(len(t.nodes), t.nTax, t.postOrder)
                for n in t.nodes:
                    if n.parent:
                        tr.setParent(n.nodeNum, n.parent.nodeNum)
                    if n.leftChild:
                        tr.setLeftChild(n.nodeNum, n.leftChild.nodeNum)
                    else:
                        tr.setNodeTaxNum(
                            n.nodeNum, self.stMcmc.taxNames.index(n.name))
                    if n.sibling:
                        tr.setSibling(n.nodeNum, n.sibling.nodeNum)
        self.frrf.setInTreeTaxBits()
        self.frrf.setInTreeInternalBits()
        self.frrf.maybeFlipInTreeBits()
        self.frrf.setBigTInternalBits()
        # self.frrf.dump()
    def getTreeLogLike_ppy1(self):
        """Pure-Python ('purePython1') RF-based log likelihood of self.propTree.

        Integer split keys of propTree's internal splits are stored in
        propTree.skk.  For each input tree, maskedSymmetricDifference() gives
        the RF distance d of the reduced supertree and its cherry count.
        'SR2008_rf_ia' subtracts beta * d.  'SR2008_rf_aZ*' additionally
        subtracts the BS2009 eqn 30 approximation of log Z_T, which is
        subtracted first.  The total is left in propTree.logLike.

        Raises:
            P4Error: for a modelName that is not one of the SR2008_rf models
                handled here.  This is only checked when there are input
                trees.
        """
        gm = ['STChain.getTreeLogLike_pp1']
        pTree = self.propTree
        pTree.makeSplitKeys()
        pTree.skk = [node.br.splitKey for node in pTree.iterInternalsNoRoot()]
        pTree.logLike = 0.0
        modelName = self.stMcmc.modelName
        for inTree in self.stMcmc.trees:
            if not modelName.startswith('SR2008_rf'):
                raise P4Error(
                    "STChain.getTreeLogLike_ppy1() unknown model '%s'" % modelName)
            rfDist, nCherries = maskedSymmetricDifference(
                pTree.skk, inTree.skSet, inTree.taxBits,
                self.stMcmc.nTax, inTree.nTax, inTree.allOnes)
            assert rfDist is not None
            betaDist = pTree.beta * rfDist
            if modelName == 'SR2008_rf_ia':
                pTree.logLike -= betaDist
            elif modelName.startswith('SR2008_rf_aZ'):
                logZT = BS2009_Eqn30_ZTApprox(inTree.nTax, pTree.beta, nCherries)
                pTree.logLike -= logZT
                pTree.logLike -= betaDist
            else:
                gm.append("Unknown modelName %s" % modelName)
                raise P4Error(gm)
    def getTreeLogLike_fastReducedRF(self):
        """Log likelihood of self.propTree, computed by the fastReducedRF (frrf) helper.

        propTree's parent / leftChild / sibling pointers are re-sent to the
        frrf big tree (self.bigTr), and its internal bits are refreshed.
        Leaf taxon numbers are not re-sent; startFrrf() set them.
        self.propTree.logLike is then set to -beta * (symmetric difference)
        for 'SR2008_rf_ia', or to frrf.getLogLike(beta) for
        'SR2008_rf_aZ*'.

        Raises:
            P4Error: for any other modelName, so that a stale logLike from an
                earlier call is never kept.  Also raised when the optional slow
                cross-check disagrees.
        """
        gm = ['STChain.getTreeLogLike_fastReducedRF()']
        slowCheck = False
        if slowCheck:
            # Reference value from the pure-Python calculation.
            self.getTreeLogLike_ppy1()
            savedLogLike = self.propTree.logLike
        self.frrf.wipeBigTPointers()
        for n in self.propTree.nodes:
            if n.parent:
                self.bigTr.setParent(n.nodeNum, n.parent.nodeNum)
            if n.leftChild:
                self.bigTr.setLeftChild(n.nodeNum, n.leftChild.nodeNum)
            if n.sibling:
                self.bigTr.setSibling(n.nodeNum, n.sibling.nodeNum)
        self.frrf.setBigTInternalBits()
        modelName = self.stMcmc.modelName
        if modelName == 'SR2008_rf_ia':
            sd = self.frrf.getSymmDiff()
            self.propTree.logLike = -sd * self.propTree.beta
        elif modelName.startswith('SR2008_rf_aZ'):
            self.propTree.logLike = self.frrf.getLogLike(self.propTree.beta)
        else:
            gm.append("Unknown modelName %s" % modelName)
            raise P4Error(gm)
        if slowCheck:
            # The two calculations may add terms in a different order, so
            # compare with a tolerance rather than exact equality.
            if math.fabs(self.propTree.logLike - savedLogLike) > 1.e-12:
                gm.append("Slow likelihood %f" % savedLogLike)
                gm.append("Fast likelihood %f" % self.propTree.logLike)
                raise P4Error(gm)
    def getTreeLogLike_bitarray(self):
        """RF-based log likelihood of self.propTree, using bitarray splits ('bitarray').

        For each input tree t, every internal supertree split (n.ss, from
        setupBitarrayCalcs) is oriented so that the bit for t.firstTax is
        set, then masked to t's taxa (t.baTaxBits).  Masked splits with
        2 .. t.nTax - 2 ones are the non-trivial splits of the reduced
        supertree.  The RF distance is the size of the symmetric difference
        between their bytes and t.splSet.  propTree.logLike accumulates
        -beta * RF for 'SR2008_rf_ia'.  For 'SR2008_rf_aZ*' it also subtracts
        the BS2009 eqn 30 approximation of log Z_T, which is subtracted first.

        Raises:
            P4Error: for any other modelName, or if the optional purePython1
                cross-check disagrees.
        """
        # gm was previously never defined, so an unknown model raised
        # NameError instead of P4Error.
        gm = ['STChain.getTreeLogLike_bitarray()']
        self.propTree.logLike = 0.0
        slowCheck = False
        if slowCheck:
            self.propTree.makeSplitKeys()
            self.propTree.skk = [
                n.br.splitKey for n in self.propTree.iterInternalsNoRoot()]
        for t in self.stMcmc.trees:
            # No need to consider (masked) splits with fewer than two 1s or
            # more than t.nTax - 2 1s.  The nTax depends on the input tree.
            upperGood = t.nTax - 2
            # Usable masked splits, keyed by their bytes.  Different
            # supertree splits can reduce to the same masked split.
            usablesDict = {}
            for n in self.propTree.iterInternalsNoRoot():
                ss = n.ss
                # Choose spl or its complement spl2, so that the t.firstTax
                # bit is set.
                if ss.spl[t.firstTax]:
                    ss.theSpl = ss.spl
                else:
                    ss.theSpl = ss.spl2
                ss.maskedSplitWithTheFirstTaxOne = ss.theSpl & t.baTaxBits
                ss.onesCount = ss.maskedSplitWithTheFirstTaxOne.count()
                if 2 <= ss.onesCount <= upperGood:
                    ss.bytes = ss.maskedSplitWithTheFirstTaxOne.tobytes()
                    usablesDict[ss.bytes] = ss
            splSet = set(usablesDict)   # bytes, for the RF calculation
            thisBaRF = len(splSet.symmetric_difference(t.splSet))
            if slowCheck:  # with purePython1
                thisPPyRF, thisPPyNCherries = maskedSymmetricDifference(
                    self.propTree.skk, t.skSet, t.taxBits, self.stMcmc.nTax,
                    t.nTax, t.allOnes)
                if thisBaRF != thisPPyRF:
                    raise P4Error("bitarray and purePython1 RF calcs differ.")
            beta_distance = self.propTree.beta * thisBaRF
            if self.stMcmc.modelName == 'SR2008_rf_ia':
                self.propTree.logLike -= beta_distance
            elif self.stMcmc.modelName.startswith('SR2008_rf_aZ'):
                # Count a cherry for two 1s, and another for two 0s
                # (onesCount == t.nTax - 2).  When t.nTax == 4 one split
                # matches both tests.
                nCherries = 0
                for ba in splSet:
                    onesCount = usablesDict[ba].onesCount
                    if onesCount == 2:
                        nCherries += 1
                    if onesCount == upperGood:
                        nCherries += 1
                if slowCheck:
                    if nCherries != thisPPyNCherries:
                        raise P4Error(
                            "bitarray and purePython1 nCherries calcs differ.")
                log_approxZT = BS2009_Eqn30_ZTApprox(
                    t.nTax, self.propTree.beta, nCherries)
                self.propTree.logLike -= log_approxZT
                self.propTree.logLike -= beta_distance
            else:
                gm.append("Unknown model %s" % self.stMcmc.modelName)
                raise P4Error(gm)
    def proposePolytomy(self, theProposal):
        """Propose adding or deleting one internal edge of self.propTree.

        - A star tree (one internal node) can only gain an edge.
        - A fully resolved tree (nTax - 2 internal nodes) can only lose
          one.  If it has no candidate nodes for deletion, the proposal is
          aborted (theProposal.doAbort = True) and curTree's
          _nInternalNodes is copied from propTree.
        - Otherwise adding or deleting is chosen with equal probability,
          falling back to adding when there is nothing to delete.
        """
        theProposal.doAbort = False
        if self.propTree.nInternalNodes == 1:
            # Star tree: there is no edge to delete.
            self.proposeAddEdge(theProposal)
            return
        if self.propTree.nInternalNodes == self.propTree.nTax - 2:
            # Fully resolved: there is nowhere to add an edge.
            candidateNodes = self._getCandidateNodesForDeleteEdge()
            if not candidateNodes:
                theProposal.doAbort = True
                self.curTree._nInternalNodes = self.propTree._nInternalNodes
                return
            self.proposeDeleteEdge(theProposal, candidateNodes)
            return
        if random.random() < 0.5:
            self.proposeAddEdge(theProposal)
            return
        candidateNodes = self._getCandidateNodesForDeleteEdge()
        if candidateNodes:
            self.proposeDeleteEdge(theProposal, candidateNodes)
        else:
            self.proposeAddEdge(theProposal)
    def proposeAddEdge(self, theProposal):
        """Propose resolving part of a polytomy in self.propTree by adding an edge.

        A polytomy (a non-root node with more than 2 children, or the root
        with more than 3) is chosen uniformly at random.  A new internal
        node is placed under it and a randomly chosen group of the
        polytomy's children is moved under the new node.  The new node is
        the first node in pTree.nodes that has neither a parent nor a
        leftChild.  Sets self.logProposalRatio (the log Hastings ratio,
        after Lewis et al.) and self.logPriorRatio.

        theProposal supplies polytomyUseResolutionClassPrior,
        polytomyPriorLogBigC and logBigT for the prior ratio.
        """
        dbug = False
        pTree = self.propTree

        allPolytomies = [n for n in pTree.iterInternalsNoRoot()
                         if n.getNChildren() > 2]
        if pTree.root.getNChildren() > 3:
            allPolytomies.append(pTree.root)
        theChosenPolytomy = random.choice(allPolytomies)

        # We want to choose one of the possible ways to add a node.  See
        # Lewis et al page 246, left top.  "The number of distinct ways of
        # dividing k edges into two groups, making sure that at least 3
        # edges are attached to each node afterwards, is 2^{k-1} - k - 1".
        # For a non-root polytomy, k is its number of children plus 1 for
        # the edge to its parent.  A root polytomy (4 or more children) is
        # treated like a non-root one by letting an arbitrary child take
        # the role of the parent.  So a 4-child root counts as a
        # parent-like node plus 3 children.
        nChildren = theChosenPolytomy.getNChildren()
        if theChosenPolytomy == pTree.root:
            nChildren -= 1
        k = nChildren + 1
        # All children (for the root too) are candidates for the new group.
        childrenNodeNums = pTree.getChildrenNums(theChosenPolytomy)
        # Keep this an int.  random.randrange() below rejects floats (a
        # TypeError since Python 3.12), which is what math.pow() returns.
        nPossibleWays = 2 ** (k - 1) - k - 1
        if dbug:
            print("These nodes are polytomies: %s" % [n.nodeNum for n in allPolytomies])
            print("We randomly choose to do node %i" % theChosenPolytomy.nodeNum)
            print("It has %i children, so k=%i, so there are %i possible ways to add a node." % (
                nChildren, k, nPossibleWays))

        # Choose the size of the new group with probability proportional to
        # the number of ways of making it.  Eg with nChildren=5 (so k=6)
        # there are nPossibleWays=25.  Groups of 2, 3 or 4 can be made in
        # [10, 10, 5] ways (p4.func.nChooseK()), with cumulative sums
        # [10, 20, 25], against which a random integer in 0-24 is compared.
        nChooseKs = [p4.func.nChooseK(nChildren, i) for i in range(2, nChildren)]
        cumSum = list(itertools.accumulate(nChooseKs))
        ran = random.randrange(nPossibleWays)
        for i, upper in enumerate(cumSum):
            if ran < upper:
                break
        nInNewGroup = i + 2
        # Which of the children go into the new group.
        newChildrenNodeNums = random.sample(childrenNodeNums, nInNewGroup)
        if dbug:
            print("The nChooseKs are %s" % nChooseKs)
            print("The cumSum is %s" % cumSum)
            print("Since there are nPossibleWays=%i, we choose a random number from 0-%i" % (
                nPossibleWays, nPossibleWays - 1))
            print("->We chose a random number: %i" % ran)
            print("So we choose the group at index %i, which means nInNewGroup=%i" % (i, nInNewGroup))
            print("So we make a new node with newChildrenNodeNums %s" % newChildrenNodeNums)

        # Insert newNode between theChosenPolytomy and firstNode (the first
        # of the new group), taking firstNode's place in the child list.
        firstNode = pTree.nodes[newChildrenNodeNums[0]]
        for newNode in pTree.nodes:
            if not newNode.parent and not newNode.leftChild:
                break
        newNode.parent = theChosenPolytomy
        newNode.leftChild = firstNode
        firstNode.parent = newNode
        if theChosenPolytomy.leftChild == firstNode:
            theChosenPolytomy.leftChild = newNode
        else:
            oldCh = theChosenPolytomy.leftChild
            while oldCh.sibling != firstNode:
                oldCh = oldCh.sibling
            oldCh.sibling = newNode
        if firstNode.sibling:
            newNode.sibling = firstNode.sibling
            firstNode.sibling = None
        pTree.setPreAndPostOrder()
        pTree._nInternalNodes += 1
        # Move the rest of the new group under newNode.
        for nodeNum in newChildrenNodeNums[1:]:
            n = pTree.pruneSubTreeWithoutParent(nodeNum)
            pTree.reconnectSubTreeWithoutParent(n, newNode)
        # Split keys are not updated here; constraint handling for this
        # move is currently disabled.
        if dbug:
            pTree.setPreAndPostOrder()
            pTree.draw()

        # Hastings ratio.  gamma_B corrects for the add/delete choice made
        # in proposePolytomy(): a star tree can only gain an edge, and a
        # fully-resolved tree can only lose one.  A fully-resolved tree has
        # nTax - 2 internal nodes, which is the test proposePolytomy() uses.
        # (len(pTree.nodes) - 2 cannot be used, because pTree.nodes also
        # holds the unused nodes.)
        nFullyResolved = pTree.nTax - 2
        if (self.curTree.nInternalNodes == 1) and (pTree.nInternalNodes < nFullyResolved):
            gamma_B = 0.5
        elif (pTree.nInternalNodes == nFullyResolved) and (self.curTree.nInternalNodes > 1):
            gamma_B = 2.0
        else:
            gamma_B = 1.0
        # n_e is the number of internal edges present before the move.
        n_e = float(self.curTree.nInternalNodes - 1)
        # n_p is the number of polytomies present before the move.
        n_p = float(len(allPolytomies))
        hastingsRatio = (gamma_B * n_p * float(nPossibleWays)) / (1.0 + n_e)
        if dbug:
            print("For the Hastings ratio ...")
            print("gamma_B is %.1f" % gamma_B)
            print("n_e is %.0f" % n_e)
            print("k is (still) %i, and (2^{k-1} - k - 1) = nPossibleWays is still %i" % (k, nPossibleWays))
            print("n_p = %.0f is the number of polytomies present before the move." % n_p)
            print("So the hastings ratio is %f" % hastingsRatio)
        self.logProposalRatio = math.log(hastingsRatio)

        # As in Lewis et al, the branch-length prior ratio and the Jacobian
        # cancel, so only the topology prior contributes here.
        if theProposal.polytomyUseResolutionClassPrior:
            # Gaining a node, m -> m + 1 internal nodes.  In logs, the prior
            # ratio is log T_{n,m} - (log C + log T_{n,m+1}).
            self.logPriorRatio = (theProposal.logBigT[self.curTree.nInternalNodes] -
                                  (theProposal.polytomyPriorLogBigC +
                                   theProposal.logBigT[pTree.nInternalNodes]))
        elif theProposal.polytomyPriorLogBigC:
            self.logPriorRatio = -theProposal.polytomyPriorLogBigC
        else:
            self.logPriorRatio = 0.0
    def _getCandidateNodesForDeleteEdge(self):
        """Return, as a new list, the nodes whose parent edge may be deleted.

        These are all the internal nodes of self.propTree except the root.
        No constraint filtering is applied.
        """
        return list(self.propTree.iterInternalsNoRoot())
    def proposeDeleteEdge(self, theProposal, candidateNodes):
        """Propose collapsing an internal edge of self.propTree.

        One of candidateNodes is chosen at random and removed.  Its children
        are re-attached to its parent in its place, creating or enlarging a
        polytomy there.  candidateNodes is assumed to hold non-root internal
        nodes, as returned by _getCandidateNodesForDeleteEdge().  Sets
        self.logProposalRatio (the log Hastings ratio, after Lewis et al.)
        and self.logPriorRatio.

        theProposal supplies polytomyUseResolutionClassPrior,
        polytomyPriorLogBigC and logBigT for the prior ratio.

        Raises P4Error if candidateNodes is empty.
        """
        dbug = False
        pTree = self.propTree
        if not candidateNodes:
            raise P4Error(
                "proposeDeleteEdge() could not find a good node to attempt to delete.")
        theChosenNode = random.choice(candidateNodes)
        if dbug:
            print("There are %i candidateNodes." % len(candidateNodes))
            print("node nums %s" % [n.nodeNum for n in candidateNodes])
            print("Randomly choose node %s" % theChosenNode.nodeNum)

        # Splice theChosenNode out: its children take its place in its
        # parent's child list, between its left sibling and its sibling.
        theNewParent = theChosenNode.parent
        theRightmostChild = theChosenNode.rightmostChild()
        theLeftSib = theChosenNode.leftSibling()
        if theLeftSib:
            theLeftSib.sibling = theChosenNode.leftChild
        else:
            theNewParent.leftChild = theChosenNode.leftChild
        for n in theChosenNode.iterChildren():
            n.parent = theNewParent
        theRightmostChild.sibling = theChosenNode.sibling
        # Presumably this returns the node to the unused pool (no parent,
        # no leftChild) that proposeAddEdge() draws from.
        theChosenNode.wipe()
        pTree.setPreAndPostOrder()
        pTree._nInternalNodes -= 1

        # Hastings ratio.  gamma_D corrects for the add/delete choice made
        # in proposePolytomy(): a fully-resolved tree can only lose an
        # edge, and a star tree can only gain one.  A fully-resolved tree
        # has nTax - 2 internal nodes, which is the test proposePolytomy()
        # uses.  (len(pTree.nodes) - 2 cannot be used, because pTree.nodes
        # also holds the unused nodes.)
        nFullyResolved = pTree.nTax - 2
        if (self.curTree.nInternalNodes == nFullyResolved) and pTree.nInternalNodes != 1:
            gamma_D = 0.5
        elif (self.curTree.nInternalNodes < nFullyResolved) and pTree.nInternalNodes == 1:
            gamma_D = 2.
        else:
            gamma_D = 1.
        # n_e is the number of internal edges before the move.
        n_e = float(self.curTree.nInternalNodes - 1)
        # nStar_p is the number of polytomies after the move.  A non-root
        # node is a polytomy with more than 2 children, the root with more
        # than 3.
        nStar_p = sum(1 for n in pTree.iterInternalsNoRoot() if n.getNChildren() > 2)
        if pTree.root.getNChildren() > 3:
            nStar_p += 1
        nStar_p = float(nStar_p)
        # kStar is the number of edges at the polytomy created (or enlarged)
        # by the move, counting the edge to its parent if it has one.
        kStar = theNewParent.getNChildren()
        if theNewParent.parent:
            kStar += 1
        hastingsRatio = (gamma_D * n_e) / \
            (nStar_p * (2 ** (kStar - 1) - kStar - 1))
        self.logProposalRatio = math.log(hastingsRatio)

        # As in Lewis et al, the branch-length prior ratio and the Jacobian
        # cancel, so only the topology prior contributes here.
        if theProposal.polytomyUseResolutionClassPrior:
            # Losing a node, m -> m - 1 internal nodes.  In logs, the prior
            # ratio is (log T_{n,m} + log C) - log T_{n,m-1}.
            self.logPriorRatio = ((theProposal.logBigT[self.curTree.nInternalNodes] +
                                   theProposal.polytomyPriorLogBigC) -
                                  theProposal.logBigT[pTree.nInternalNodes])
        elif theProposal.polytomyPriorLogBigC:
            self.logPriorRatio = theProposal.polytomyPriorLogBigC
        else:
            self.logPriorRatio = 0.0
    def propose(self, theProposal):
        """Apply theProposal to self.propTree and return the log acceptance ratio.

        The move is chosen by theProposal.name.  If the move sets
        theProposal.doAbort, 0.0 is returned straight away.  Otherwise the
        proposed tree's log likelihood is recalculated for
        self.stMcmc.modelName, and the return value is
        logLikeRatio + logProposalRatio + logPriorRatio.  The likelihood and
        prior ratios are heated when there are several chains, and when
        stMcmc.doHeatingHack is set.

        Raises P4Error for an unknown proposal name, model, or stRFCalc.
        """
        gm = ['STChain.propose()']

        # The nni and spr branches below do not set these ratios.  Without
        # this reset they would reuse whatever an earlier proposal (eg a
        # polytomy move) left behind, and logPriorRatio would be heated
        # again on every generation.
        self.logProposalRatio = 0.0
        self.logPriorRatio = 0.0

        def reflectIntoRange(x, lo, hi):
            # Linear reflection off the boundaries until lo <= x <= hi.
            while True:
                if x < lo:
                    x = (lo - x) + lo
                elif x > hi:
                    x = hi - (x - hi)
                else:
                    return x

        pTree = self.propTree
        if theProposal.name == 'nni':
            pTree.nni()             # this does setPreAndPostOrder()
        elif theProposal.name == 'spr':
            pTree.randomSpr()
            if not theProposal.doAbort and not pTree.preAndPostOrderAreValid:
                pTree.setPreAndPostOrder()
        elif theProposal.name == 'SR2008beta_uniform':
            # Slider proposal, symmetric, so both ratios stay 0.
            mt = pTree.beta + (random.random() - 0.5) * theProposal.tuning[self.tempNum]
            pTree.beta = reflectIntoRange(mt, 1.e-10, 1.e+10)
        elif theProposal.name == 'spaQ_uniform':
            originally = pTree.spaQ[0]
            # Slider proposal, symmetric, so logProposalRatio stays 0.
            mt = originally + (random.random() - 0.5) * theProposal.tuning[self.tempNum]
            mt = reflectIntoRange(mt, 1.e-10, 1.)
            pTree.spaQ[0] = mt
            if theProposal.spaQPriorType == 'flat':
                self.logPriorRatio = 0.0
            elif theProposal.spaQPriorType == 'exponential':
                self.logPriorRatio = theProposal.spaQExpPriorLambda * (originally - mt)
            else:
                raise P4Error("this should not happen! wxyzz")
        elif theProposal.name == 'polytomy':
            self.proposePolytomy(theProposal)
            if not pTree.preAndPostOrderAreValid:
                pTree.setPreAndPostOrder()
        else:
            gm.append('Unlisted proposal.name=%s  Fix me.' % theProposal.name)
            raise P4Error(gm)

        if theProposal.doAbort:
            return 0.0

        # Calculate the likelihood of the proposed tree.
        modelName = self.stMcmc.modelName
        if modelName.startswith('SR2008_rf'):
            stRFCalc = self.stMcmc.stRFCalc
            if stRFCalc == 'fastReducedRF':
                self.getTreeLogLike_fastReducedRF()
            elif stRFCalc == 'purePython1':
                self.getTreeLogLike_ppy1()
            elif stRFCalc == 'bitarray':
                self.refreshBitarrayPropTree()
                self.getTreeLogLike_bitarray()
            else:
                # Otherwise propTree.logLike would silently be a stale value.
                gm.append('Unknown stRFCalc %s' % stRFCalc)
                raise P4Error(gm)
        elif modelName == 'SPA':
            self.refreshBitarrayPropTree()
            if var.stmcmc_useFastSpa:
                self.propTree.logLike = self.stMcmc.fspa.calcLogLike(self.chNum)
            else:
                self.getTreeLogLike_spa_bitarray()
        elif modelName == 'QPA':
            self.getTreeLogLike_qpa_slow()
        else:
            gm.append('Unknown model %s' % modelName)
            raise P4Error(gm)

        logLikeRatio = self.propTree.logLike - self.curTree.logLike

        # Mcmcmc
        if self.stMcmc.nChains > 1:
            if self.stMcmc.swapVector:
                heatBeta = 1.0 / (1.0 + self.stMcmc.chainTemps[self.tempNum])
            else:
                heatBeta = 1.0 / (1.0 + self.stMcmc.chainTemp * self.tempNum)
            logLikeRatio *= heatBeta
            self.logPriorRatio *= heatBeta
        # Experimental heating hack
        if self.stMcmc.doHeatingHack:
            heatFactor = 1.0 / (1.0 + self.stMcmc.heatingHackTemperature)
            logLikeRatio *= heatFactor
            self.logPriorRatio *= heatFactor
        return logLikeRatio + self.logProposalRatio + self.logPriorRatio
    def gen(self, aProposal):
        """Make one proposal, then accept or reject it (Metropolis-Hastings).

        self.propose() supplies the log acceptance ratio.  Unless the
        proposal set aProposal.doAbort, the move is accepted with
        probability min(1, exp(ratio)); ratios below -100 count as 0.
        The proposal's counters for self.tempNum are updated.  Then the
        winning tree's state is copied onto the losing tree, so curTree and
        propTree agree afterwards.  An aborted proposal is counted and
        handled like a rejection.

        Raises P4Error for an unlisted proposal name.
        """
        gm = ['STChain.gen()']
        logRatio = self.propose(aProposal)

        accepted = False
        if not aProposal.doAbort:
            if logRatio >= 0.0:
                accepted = True
            else:
                # math.exp(-100.) is 3.7200759760208361e-44; treat it as 0.
                prob = 0.0 if logRatio < -100.0 else math.exp(logRatio)
                accepted = prob == 1.0 or random.random() < prob

        t = self.tempNum
        aProposal.nProposals[t] += 1
        aProposal.tnNSamples[t] += 1
        if accepted:
            aProposal.accepted = True
            aProposal.nAcceptances[t] += 1
            aProposal.tnNAccepts[t] += 1
            winner, loser = self.propTree, self.curTree
        else:
            winner, loser = self.curTree, self.propTree

        name = aProposal.name
        if name in ('nni', 'spr', 'polytomy'):
            loser.logLike = winner.logLike
            winner.copyToTree(loser)
        elif name == 'SR2008beta_uniform':
            loser.logLike = winner.logLike
            loser.beta = winner.beta
        elif name == 'spaQ_uniform':
            loser.logLike = winner.logLike
            loser.spaQ[0] = winner.spaQ[0]
        else:
            gm.append('Unlisted proposal.name = %s  Fix me.' % aProposal.name)
            raise P4Error(gm)
# Per-proposal fudge factors, keyed by proposal name, used for the
# proposal probabilities.
fudgeFactor = {
    'nni': 1.0,
    'spr': 1.0,
    'SR2008beta_uniform': 0.1,
    'spaQ_uniform': 0.1,
    'polytomy': 0.5,
}
class STMcmcTunings(object):
    """Default tunings for the tunable STMcmc proposals."""
    def __init__(self):
        # Keyed by proposal name.
        self.default = {'SR2008beta_uniform': 0.2, 'spaQ_uniform': 0.1}

class STMcmcProposalProbs(dict):
    """User-settable relative proposal probabilities.

    An instance of this class is made as STMcmc.prob, where you can
    do, for example,
        yourSTMcmc.prob.nni = 2.0
    These are relative proposal probs, that do not sum to 1.0, and
    affect the calculation of the final proposal probabilities (ie the
    kind that do sum to 1).  It is a relative setting, and the default
    is 1.0.  Setting it to 0 turns it off.  For small
    probabilities, setting it to 2.0 doubles it.  For bigger
    probabilities, setting it to 2.0 makes it somewhat bigger.
    Check the effect that it has by doing a
        yourSTMcmc.writeProposalIntendedProbs()
    which prints out the final calculated probabilities.

    Values are converted with float().  Anything below 1e-9, including
    negative values, is stored as 0.  Setting a name that is not one of
    the proposals set up in __init__ raises P4Error, as does a value
    that cannot be converted to float.
    """
    def __init__(self):
        # object.__setattr__ bypasses the validating __setattr__ below.
        object.__setattr__(self, 'nni', 1.0)
        object.__setattr__(self, 'spr', 1.0)
        object.__setattr__(self, 'SR2008beta_uniform', 1.0)
        object.__setattr__(self, 'spaQ_uniform', 1.0)
        object.__setattr__(self, 'polytomy', 0.0)
    def __setattr__(self, item, val):
        gm = ["\nSTMcmcProposalProbs(). (set %s to %s)" % (item, val)]
        if item not in self.__dict__:
            self.dump()
            gm.append("    Can't set '%s'-- no such proposal." % item)
            raise P4Error(gm)
        try:
            val = float(val)
        except (TypeError, ValueError, OverflowError):
            # Format val inside a 1-tuple, so that a tuple value cannot
            # break the message formatting itself.
            gm.append("Should be a float.  Got '%s'" % (val,))
            raise P4Error(gm)
        if val < 1e-9:
            val = 0
        object.__setattr__(self, item, val)
    def reprString(self):
        stuff = ["\nUser-settable relative proposal probabilities, from yourStMcmc.prob"]
        stuff.append("  To change it, do eg ")
        stuff.append("    yourSTMcmc.prob.spaQ_uniform = 0.0 # turns spaQ_uniform proposals off")
        stuff.append("  Current settings:")
        theKeys = list(self.__dict__.keys())
        theKeys.sort()
        for k in theKeys:
            stuff.append("        %20s: %s" % (k, getattr(self, k)))
        return '\n'.join(stuff)
    def dump(self):
        print(self.reprString())
    def __repr__(self):
        return self.reprString()
class STProposal(object):
    """A single kind of STMcmc proposal, with its bookkeeping.

    Holds the proposal's name, weight and per-chain tuning, plus
    per-chain proposal and acceptance counters.  The counters are lists
    with one entry per chain of theSTMcmc, indexed by tempNum.  The
    tn* attributes are used by tune(); its thresholds (tnAcc*) and
    multipliers (tnFactor*) start as None and are expected to be set
    before tune() is called.
    """
    def __init__(self, theSTMcmc=None):
        self.name = None
        self.stMcmc = theSTMcmc            # reference loop
        self.nChains = theSTMcmc.nChains
        self.pNum = -1
        self.weight = 1.0
        self.tuning = None
        self.tunings = {}
        n = self.nChains
        self.nProposals = [0] * n
        self.nAcceptances = [0] * n
        self.accepted = 0
        self.doAbort = False
        self.nAborts = [0] * n
        # Automatic tuning; tune() needs tnSampleSize samples per chain.
        self.tnSampleSize = 250
        self.tnNSamples = [0] * n
        self.tnNAccepts = [0] * n
        # Acceptance-rate thresholds, then tuning multipliers; all unset.
        for attrName in ('tnAccVeryHi', 'tnAccHi', 'tnAccLo', 'tnAccVeryLo',
                         'tnFactorVeryHi', 'tnFactorHi', 'tnFactorLo',
                         'tnFactorVeryLo', 'tnFactorZero'):
            setattr(self, attrName, None)
    def dump(self):
        """Print this proposal's name, pNum, weight and tuning on one line."""
        fields = (self.name, self.pNum, self.weight, self.tuning)
        print("proposal name=%-10s pNum=%2i, weight=%5.1f, tuning=%s" % fields)
    def tune(self, tempNum):
        """Rescale self.tuning[tempNum] from the recent acceptance rate.

        Needs at least tnSampleSize samples (tnSampleSize itself must be at
        least 100).  An acceptance rate above tnAccHi multiplies the tuning
        by tnFactorHi, or by tnFactorVeryHi if above tnAccVeryHi.  A rate
        below tnAccLo multiplies it by tnFactorLo, or by tnFactorVeryLo if
        below tnAccVeryLo.  The sample and accept counters for tempNum are
        then reset, and any change is logged via self.stMcmc.logger.
        """
        assert self.tnSampleSize >= 100.
        assert self.tnNSamples[tempNum] >= self.tnSampleSize
        acc = float(self.tnNAccepts[tempNum]) / self.tnNSamples[tempNum]   # float() for Py2
        adjust = False
        factor = None
        if acc > self.tnAccHi:
            adjust = True
            factor = self.tnFactorVeryHi if acc > self.tnAccVeryHi else self.tnFactorHi
        elif acc < self.tnAccLo:
            adjust = True
            factor = self.tnFactorVeryLo if acc < self.tnAccVeryLo else self.tnFactorLo
        if adjust:
            oldTn = self.tuning[tempNum]
            self.tuning[tempNum] *= factor
        self.tnNSamples[tempNum] = 0
        self.tnNAccepts[tempNum] = 0
        if adjust:
            message = ("%s tune  gen=%i tempNum=%i acceptance=%.3f " % (self.name, self.stMcmc.gen, tempNum, acc)
                       + "(target %.3f -- %.3f) " % (self.tnAccLo, self.tnAccHi)
                       + "Adjusting tuning from %g to %g" % (oldTn, self.tuning[tempNum]))
            self.stMcmc.logger.info(message)
class Proposals(object):
    """A collection of MCMC proposals and the weights used to pick among them.

    Each proposal object is expected to carry ``name``, ``pNum``, ``weight``
    and ``tuning`` attributes (these are the ones read by the methods below).

    Attributes:
        proposals: list of proposal objects.
        proposalsDict: dict for proposals; it is only initialised here, and is
            presumably filled in elsewhere.
        propWeights: per-proposal weights, copied from ``p.weight`` by
            calculateWeights().
        cumPropWeights: running (cumulative) sums of propWeights.
        totalPropWeights: sum of propWeights.
        intended: propWeights normalised to probabilities summing to 1.
    """
    def __init__(self):
        self.proposals = []
        self.proposalsDict = {}
        self.propWeights = []
        self.cumPropWeights = []
        self.totalPropWeights = 0.0
        self.intended = None

    def summary(self):
        """Print one line per proposal with its name, pNum, weight and tuning."""
        print("There are %i proposals" % len(self.proposals))
        for p in self.proposals:
            print("proposal name=%-10s pNum=%2s, weight=%s, tuning=%s" % (
                '%s,' % p.name, p.pNum, p.weight, p.tuning))

    def calculateWeights(self):
        """Compute weights, cumulative weights and intended probabilities.

        Must be called after the proposals list is final and before
        chooseProposal() or writeProposalIntendedProbs().

        Raises:
            P4Error: if there are no proposals, if the total weight is
                (essentially) zero, or if the normalised probabilities do not
                sum to 1.
        """
        gm = ["Proposals.calculateWeights()"]
        self.propWeights = [p.weight for p in self.proposals]
        # accumulate() adds left to right, exactly like a running-sum loop.
        self.cumPropWeights = list(itertools.accumulate(self.propWeights))
        self.totalPropWeights = sum(self.propWeights)
        # An empty proposal list gives a total of 0 and is reported here too,
        # rather than failing with an IndexError.
        if self.totalPropWeights < 1e-9:
            gm.append("No proposal weights?")
            raise P4Error(gm)
        self.intended = [w / self.totalPropWeights for w in self.propWeights]
        # Absolute difference; the old code took fabs() of a boolean and so
        # never detected a sum below 1.
        if math.fabs(sum(self.intended) - 1.0) > 1e-14:
            raise P4Error("bad sum of intended proposal probs. %s" % sum(self.intended))

    def chooseProposal(self, equiProbableProposals):
        """Return a proposal, chosen uniformly or in proportion to its weight.

        Args:
            equiProbableProposals: if true, ignore the weights and choose
                uniformly; otherwise use cumPropWeights from
                calculateWeights().
        """
        if equiProbableProposals:
            return random.choice(self.proposals)
        theRan = random.uniform(0.0, self.totalPropWeights)
        for proposal, cumWeight in zip(self.proposals, self.cumPropWeights):
            if theRan < cumWeight:
                return proposal
        # random.uniform() may return its upper bound; fall back to the last one.
        return self.proposals[len(self.cumPropWeights) - 1]

    def writeProposalIntendedProbs(self):
        """Tabulate the intended proposal probabilities.

        Requires calculateWeights() to have been called.  Only the tuning of
        the first temperature (``tuning[0]``) is shown.
        """
        print("\nIntended proposal probabilities (%)")
        print("There are %i proposals" % len(self.proposals))
        print("%2s %11s %30s %5s %12s" % ('', 'intended(%)', 'proposal', 'part', 'tuning'))
        for i, p in enumerate(self.proposals):
            print("%2i" % i, end=' ')
            print("   %6.2f    " % (100. * self.intended[i]), end=' ')
            print(" %27s" % p.name, end=' ')
            if p.pNum != -1:
                print(" %3i " % p.pNum, end=' ')
            else:
                print("   - ", end=' ')
            # 'is None' rather than '== None', which is ambiguous for arrays.
            if p.tuning is None:
                print(" %12s " % '    -   ', end=' ')
            else:
                tn = p.tuning[0]
                if tn < 0.1:
                    fmt = " %12.4g"
                elif tn < 1.0:
                    fmt = " %12.4f"
                elif tn < 10.0:
                    fmt = " %12.3f"
                elif tn < 1000.0:
                    fmt = " %12.1f"
                else:
                    fmt = " %12.2g "
                print(fmt % tn, end=' ')
            print()

    def writeTunings(self):
        """Print each proposal name with its tuning, or '-' if it has none."""
        print("Proposal tunings:")
        print("%20s %12s" % ("proposal name", "tuning"))
        for p in self.proposals:
            print("%20s" % p.name, end=' ')
            if p.tuning:
                print(p.tuning)
            else:
                print("    %4s    " % '-', end=' ')
            print()
class SwapTuner(object):
    """Continuous tuning for swap temperature

    The caller accumulates swap attempts and successes in
    ``swaps01_nAttempts`` / ``swaps01_nSwaps`` (the "01" presumably refers to
    chains 0 and 1 -- confirm with the caller).  Once at least ``sampleSize``
    attempts are recorded, tune() scales ``theMcmc.chainTemp`` up when swaps
    are accepted too often and down when too rarely, then resets the counters.
    """
    def __init__(self, sampleSize):
        assert sampleSize >= 100
        self.sampleSize = sampleSize
        self.swaps01_nAttempts = 0
        self.swaps01_nSwaps = 0
        # Acceptance-rate thresholds that bound the target band [Lo, Hi].
        self.tnAccVeryHi, self.tnAccHi = 0.18, 0.12
        self.tnAccLo, self.tnAccVeryLo = 0.04, 0.01
        # Multipliers applied to chainTemp; tnFactorZero is used when no swap
        # at all was accepted.
        self.tnFactorVeryHi, self.tnFactorHi = 1.4, 1.2
        self.tnFactorLo, self.tnFactorVeryLo = 0.9, 0.6
        self.tnFactorZero = 0.4

    def tune(self, theMcmc):
        """Adjust ``theMcmc.chainTemp`` from the recorded swap acceptance rate.

        Reads ``theMcmc.chainTemp``, ``theMcmc.gen`` and ``theMcmc.logger``.
        The swap counters are always reset; a change is logged at INFO level.
        """
        assert self.swaps01_nAttempts >= self.sampleSize
        rate = float(self.swaps01_nSwaps) / self.swaps01_nAttempts    # float() for Py2

        changed = True
        if rate > self.tnAccHi:
            direction = 'Increase'
            factor = self.tnFactorVeryHi if rate > self.tnAccVeryHi else self.tnFactorHi
        elif rate < self.tnAccLo:
            direction = 'Decrease'
            if rate == 0.0:   # no swaps at all
                factor = self.tnFactorZero
            elif rate < self.tnAccVeryLo:
                factor = self.tnFactorVeryLo
            else:
                factor = self.tnFactorLo
        else:
            changed = False

        if changed:
            previous = theMcmc.chainTemp
            theMcmc.chainTemp *= factor

        # Begin a fresh sampling window.
        self.swaps01_nAttempts = 0
        self.swaps01_nSwaps = 0

        if not changed:
            return
        pieces = [
            "%s tune  gen=%i acceptance=%.3f " % ('chainTemp', theMcmc.gen, rate),
            "(target %.3f -- %.3f) " % (self.tnAccLo, self.tnAccHi),
            "%s chainTemp from %g to %g" % (direction, previous, theMcmc.chainTemp),
        ]
        theMcmc.logger.info("".join(pieces))
class STSwapTunerV(object):
    """Continuous tuning for swap temperature, one gap per adjacent chain pair.

    Tunes ``mcmc.chainTempDiffs[i]``, the temperature gap between chains i and
    i+1, from the swap acceptance rate recorded for that pair.  After any
    tuning that is not blocked by a limit, ``mcmc.chainTemps`` is rebuilt as the
    running sum of the gaps, starting at 0.0.
    """
    def __init__(self, theMcmc):
        assert var.mcmc_swapTunerSampleSize >= 100
        self.mcmc = theMcmc
        self.nChains = theMcmc.nChains
        # Counters for adjacent pairs, indexed by the lower chain number of the
        # pair: an attempt between chains 0 and 1 goes into slot 0.
        self.nAttempts = [0 for _ in range(self.nChains)]
        self.nSwaps = [0 for _ in range(self.nChains)]
        # Acceptance-rate thresholds that bound the target band [Lo, Hi].
        self.tnAccVeryHi, self.tnAccHi = 0.30, 0.25
        self.tnAccLo, self.tnAccVeryLo = 0.10, 0.05
        # Multipliers applied to a gap; tnFactorZero is used when no swap at
        # all was accepted.
        self.tnFactorVeryHi, self.tnFactorHi = 1.4, 1.2
        self.tnFactorLo, self.tnFactorVeryLo = 0.9, 0.6
        self.tnFactorZero = 0.4
        # A gap already at/above tnLimitHi is not increased, and one at/below
        # tnLimitLo is not decreased.  The check happens before scaling, so a
        # single step can still carry a gap past a limit.
        self.tnLimitHi = 10.0
        self.tnLimitLo = 0.2

    def tune(self, theTempNum):
        """Tune the gap between chains ``theTempNum`` and ``theTempNum + 1``.

        The counters for this pair are always reset.  When the gap was blocked
        by a limit ("no change"), nothing else happens.  Otherwise chainTemps
        is rebuilt, and when the gap actually changed, two INFO messages are
        logged.
        """
        assert self.nAttempts[theTempNum] >= var.mcmc_swapTunerSampleSize
        rate = float(self.nSwaps[theTempNum]) / self.nAttempts[theTempNum]    # float() for Py2
        gaps = self.mcmc.chainTempDiffs
        oldTn = gaps[theTempNum]

        changed = False
        direction = None
        if rate > self.tnAccHi:
            if oldTn >= self.tnLimitHi:
                direction = "no change"
            else:
                factor = self.tnFactorVeryHi if rate > self.tnAccVeryHi else self.tnFactorHi
                changed = True
                direction = 'Increase'
        elif rate < self.tnAccLo:
            if oldTn <= self.tnLimitLo:
                direction = "no change"
            else:
                if rate == 0.0:   # no swaps at all
                    factor = self.tnFactorZero
                elif rate < self.tnAccVeryLo:
                    factor = self.tnFactorVeryLo
                else:
                    factor = self.tnFactorLo
                changed = True
                direction = 'Decrease'
        if changed:
            gaps[theTempNum] *= factor

        # Begin a fresh sampling window for this pair.
        self.nAttempts[theTempNum] = 0
        self.nSwaps[theTempNum] = 0

        if direction == "no change":
            return

        if changed:
            pieces = [
                "%s tune  gen=%i tempNum=%i acceptance=%.3f " % ('chainTemp', self.mcmc.gen, theTempNum, rate),
                "(target %.3f -- %.3f) " % (self.tnAccLo, self.tnAccHi),
                "%s chainTempDiff from %g to %g" % (direction, oldTn, gaps[theTempNum]),
            ]
            self.mcmc.logger.info("".join(pieces))

        # Rebuild the absolute temperatures as running sums of the gaps.
        temps = self.mcmc.chainTemps = [0.0]
        for gapNum in range(self.mcmc.nChains - 1):
            temps.append(gaps[gapNum] + temps[-1])

        if changed:
            header = "new chainTemps gen=%i " % (self.mcmc.gen)
            self.mcmc.logger.info(header + "".join("%10.2f" % cT for cT in temps))
class BigTSplitStuff(object):
    """Plain holder for split data on internal nodes of STMcmc.tree (bigT).

    Per the original note, it is intended only for the bitarray calculations.
    All fields start as None and are filled in by callers.
    """
    def __init__(self):
        self.spl = self.spl2 = self.theSpl = None
        self.maskedSplitWithFirstTaxOne = None
        self.onesCount = self.bytes = None

    def dump(self):
        """Print spl, spl2, the masked split and onesCount on one line.

        theSpl and bytes are not shown.
        """
        shown = (self.spl, self.spl2, self.maskedSplitWithFirstTaxOne, self.onesCount)
        print("ss: spl=%s, spl2=%s, masked=%s, onesCount=%s" % shown)
# ---------------------------------------------------------------------------
class STMcmc(object):
    """An MCMC for making supertrees from a set of input trees.
    This week, it implements the Steel and Rodrigo 2008 model, with the
    alpha calculation using the approximation in Bryant and Steel 2009.
    **Arguments**
    inTrees
        A list of p4 tree objects.  You could just use ``var.trees``.
    modelName
        The SR2008 models implemented here are based on the Steel and
        Rodrigo 2008 description of a likelihood model, "Maximum
        likelihood supertrees" Syst. Biol. 57(2):243--250, 2008.  At
        the moment, they are all SR2008_rf, meaning that they use
        Robinson-Foulds distances.
        SR2008_rf_ia 
            Here 'ia' means 'ignore alpha'.  The alpha values are not
            calculated at all, as they are presumed (erroneously, but
            not too badly) to cancel out.
        SR2008_rf_aZ
            This uses the approximation for Z_T = alpha^{-1} as described
            in Equation 30 in the Bryant and Steel paper "Computing the
            distribution of a tree metric" in IEEE/ACM Transactions on
            computational biology and bioinformatics, VOL. 6, 2009.
        SR2008_rf_aZ_fb
            This is as SR2008_rf_aZ above, but additionally it allows
            beta to be a free parameter, and it is sampled.  Samples
            are written to mcmc_prams* files.
    beta
        This only applies to SR2008.  The beta is the weight as
        given in Steel and Rodrigo 2008. By default it is 1.0.
    stRFCalc 
        There are three ways to calculate the RF distances and
        likelihood, for these SR2008_rf models above --- all giving
        the same answer.
        1.  purePython1.  Slow.
        2.  bitarray, using the bitarray module.  About twice as fast
            as purePython1
        3.  fastReducedRF, written in C++ using boost and ublas.
            About 10 times faster than purePython1, but perhaps a bit
            of a bother to get going.  It needs the fastReducedRF
            module, included in the p4 source code.
        It is under control of the argument stRFCalc, which can be one
        of 'purePython1', 'bitarray', and 'fastReducedRF'.  By default
        it is purePython1, so you may want to at least install
        bitarray.
    runNum
        You may want to do more than one 'run' in the same directory,
        to facilitate convergence testing.  The first runNum would be
        0, and samples, likelihoods, and checkPoints are written to
        files with that number.
    sampleInterval
        Interval at which the chain is sampled, including writing a tree,
        and the logLike.  Plan to get perhaps 1000 samples; so if you are
        planning to make a run of 10000 generations then you might set
        sampleInterval=10.
    checkPointInterval
        Interval at which checkpoints are made.  If set to None (the
        default) it means don't make checkpoints.  My taste is to aim to
        make perhaps 2 to 4 per run.  So if you are planning to start out
        with a run of 10000 generations, you could set
        checkPointInterval=5000, which will give you 2 checkpoints.  See
        more about checkpointing below.
    To prepare for a run, instantiate an Mcmc object, for example::
        m = STMcmc(treeList, modelName='SR2008_rf_aZ_fb', stRFCalc='fastReducedRF', sampleInterval=10)
    To start it running, do this::
        # Tell it the number of generations to do
        m.run(10000)
    As it runs, it saves trees and likelihoods at sampleInterval
    intervals (actually whenever the current generation number is
    evenly divisible by the sampleInterval).
    **CheckPoints**
    Whenever the current generation number is evenly divisible by the
    checkPointInterval it will write a checkPoint file.  A checkPoint
    file is the whole MCMC, pickled.  Using a checkPoint, you can
    re-start an STMcmc from the point you left off.  Or, in the event
    of a crash, you can restart from the latest checkPoint.  But the
    most useful thing about them is that you can query checkPoints to
    get information about how the chain has been running, and about
    convergence diagnostics.
    In order to restart the MCMC from the end of a previous run:: 
        # read the last checkPoint file
        m = func.unPickleSTMcmc(0)  # runNum 0
        m.run(20000)
    Its that easy if your previous run finished properly.  However, if
    your previous run has crashed and you want to restart it from a
    checkPoint, then you will need to repair the sample output files
    to remove samples that were taken after the last checkPoint, but
    before the crash.  Fix the trees, likelihoods, prams, and sims.
    (You probably do not need to beware of confusing gen (eg 9999) and
    gen+1 (eg 10000) issues.)  When you remove trees from the tree
    files be sure to leave the 'end;' at the end-- p4 needs it, and
    will deal with it.
    The checkPoints can help with convergence testing.  To help with
    that, you can use the STMcmcCheckPointReader class.  It will print
    out a table of average standard deviations of split supports
    between 2 runs, or between 2 checkPoints from the same run.  It
    will print out tables of proposal acceptances to show whether they
    change over the course of the MCMC.
    **Making a consensus tree**
    See :class:`TreePartitions`.
    """
    def __init__(self, inTrees, bigT=None, modelName='SR2008_rf_aZ',
                 beta=1.0, spaQ=0.5, stRFCalc='purePython1',
                 nChains=1, runNum=0, sampleInterval=100,
                 checkPointInterval=None, useSplitSupport=False, verbose=True,
                 checkForOutputFiles=True, swapTuner=250):
        import p4.func  # This should not be needed, but it is.  Why?
        #print(p4.func)
        gm = ['STMcmc.__init__()']
        assert inTrees
        for t in inTrees:
            assert isinstance(t, Tree)
        if bigT:
            assert isinstance(bigT, Tree)
            assert bigT.taxNames
            bigT.stripBrLens()
            for n in bigT.iterInternalsNoRoot():
                n.name = None
        goodModelNames = ['SR2008_rf_ia', 'SR2008_rf_aZ', 'SR2008_rf_aZ_fb',
                          'SPA', 'QPA']
        if modelName not in goodModelNames:
            gm.append("Arg modelName '%s' is not recognized. " % modelName)
            gm.append("Good modelNames are %s" % goodModelNames)
            raise P4Error(gm)
        self.modelName = modelName
        self.tree = None
        self.stRFCalc = None
        if modelName.startswith("SR2008"):
            try:
                fBeta = float(beta)
            except ValueError:
                gm.append("Arg beta (%s) should be a float" % beta)
                raise P4Error(gm)
            self.beta = fBeta
            for t in inTrees:
                if t.isFullyBifurcating():
                    pass
                else:
                    gm.append("The SR2008 model wants trees that are fully bifurcating.")
                    raise P4Error(gm)
            goodSTRFCalcNames = ['purePython1', 'bitarray', 'fastReducedRF']
            if stRFCalc not in goodSTRFCalcNames:
                gm.append("Arg stRFCalc '%s' is not recognized. " % modelName)
                gm.append("Good stRFCalc names are %s" % goodSTRFCalcNames)
                raise P4Error(gm)
            self.stRFCalc = stRFCalc
        try:
            nChains = int(nChains)
        except (ValueError, TypeError):
            gm.append("nChains should be an int, 1 or more.  Got %s" % nChains)
            raise P4Error(gm)
        if nChains < 1:
            gm.append("nChains should be an int, 1 or more.  Got %s" % nChains)
            raise P4Error(gm)
        self.nChains = nChains
        self.chains = []
        self.gen = -1
        self.startMinusOne = -1
        self.chainTemp = 1.0
        self.constraints = None
        self.simulate = None
        # spaQ is a property.  Whenever it is set, it is propagated to all the chains.
        self._spaQ = None
        if modelName in ['SPA', 'QPA']:
            try:
                self._spaQ = float(spaQ)
            except ValueError:
                gm.append("Arg spaQ (%s) should be a float" % spaQ)
                raise P4Error(gm)
            self.spaQ = self._spaQ
        try:
            runNum = int(runNum)
        except (ValueError, TypeError):
            gm.append("runNum should be an int, 0 or more.  Got %s" % runNum)
            raise P4Error(gm)
        if runNum < 0:
            gm.append("runNum should be an int, 0 or more.  Got %s" % runNum)
            raise P4Error(gm)
        self.runNum = runNum
        self._setLogger()
        if checkForOutputFiles:
            # Check that we are not going to over-write good stuff
            ff = os.listdir(os.getcwd())
            hasPickle = False
            for fName in ff:
                if fName.startswith("mcmc_checkPoint_%i." % self.runNum):
                    hasPickle = True
                    break
            if hasPickle:
                gm.append("runNum is set to %i" % self.runNum)
                gm.append("There is at least one mcmc_checkPoint_%i.xxx file in this directory." % self.runNum)
                gm.append("This is a new STMcmc, and I am refusing to over-write exisiting files.")
                gm.append("Maybe you want to re-start from the latest mcmc_checkPoint_%i file?" % self.runNum)
                gm.append("Otherwise, get rid of the existing mcmc_xxx_%i.xxx files and start again." % self.runNum)
                raise P4Error(gm)
            if var.strictRunNumberChecking:
                # We want to start runs with number 0, so if runNum is more than
                # that, check that there are other runs.
                if self.runNum > 0:
                    for runNum2 in range(self.runNum):
                        hasTrees = False
                        for fName in ff:
                            if fName.startswith("mcmc_trees_%i" % runNum2):
                                hasTrees = True
                                break
                        if not hasTrees:
                            gm.append("runNum is set to %i" % self.runNum)
                            gm.append("runNums should go from zero up.")
                            gm.append("There are no mcmc_trees_%i.nex files to show that run %i has been done." % (
                                runNum2, runNum2))
                            gm.append("Set the runNum to that, first.")
                            gm.append("Or else turn var.strictRunNumberChecking off to prevent checking.")
                            raise P4Error(gm)
        self.sampleInterval = sampleInterval
        self.checkPointInterval = checkPointInterval
        self.props = Proposals()
        self.tunableProps = """SR2008beta_uniform spaQ_uniform""".split()
        # maybeTunableButNotNow  polytomy
        self.treePartitions = None
        self.likesFileName = "mcmc_likes_%i" % runNum
        self.treeFileName = "mcmc_trees_%i.nex" % runNum
        #self.simFileName = "mcmc_sims_%i" % runNum
        self.pramsFileName = "mcmc_prams_%i" % runNum
        self.writePrams = False
        if self.modelName in ['SR2008_rf_aZ_fb', "SPA", "QPA"]:
            self.writePrams = True
        self.lastTimeCheck = None
        if self.nChains > 1:
            self.swapMatrix = []
            for i in range(self.nChains):
                self.swapMatrix.append([0] * self.nChains)
            # if self.swapVector:
            #     self.swapTuner = STSwapTunerV(self)
            # else:
            #     if swapTuner:             # a kwarg
            #         myST = int(swapTuner)
            #         if myST >= 100:
            #             self.swapTuner = SwapTuner(myST)
            #         else:
            #             gm.append("The swapTuner kwarg, the sample size, should be at least 100.  Got %i." % myST)
            #             raise P4Error(gm)
            #     else:
            #         self.swapTuner = None
        else:
            self.swapMatrix = None
        self.swapTuner = None
        self._tunings = STMcmcTunings()
        self.polytomyUseResolutionClassPrior = False
        self.polytomyPriorLogBigC = 0.0
        self.prob = STMcmcProposalProbs()
        if self.modelName in ['SPA', 'QPA']:
            self.prob.polytomy = 1.0
            self.prob.spr = 0.0
        # Zap internal node names
        # for n in aTree.root.iterInternals():
        #     if n.name:
        #         n.name = None
        if not bigT:
            allNames = set()
            for t in inTrees:
                t.unsorted_taxNames = [n.name for n in t.iterLeavesNoRoot()]
                # Get the union of a set and other stuff using set.update(stuff).
                allNames.update(t.unsorted_taxNames)
            self.taxNames = list(allNames)
            # not needed, but nice for debugging
            self.taxNames.sort()
        else:
            for t in inTrees:
                t.unsorted_taxNames = [n.name for n in t.iterLeavesNoRoot()]
            self.taxNames = bigT.taxNames
        # print self.taxNames
        self.nTax = len(self.taxNames)
        if self.modelName in ['SPA'] or self.stRFCalc == 'bitarray':
            # print "self.taxNames = ", self.taxNames
            for t in inTrees:
                # print "-" * 50
                # t.draw()
                sorted_taxNames = []
                t.baTaxBits = []
                for tNum in range(self.nTax):
                    tN = self.taxNames[tNum]
                    if tN in t.unsorted_taxNames:
                        sorted_taxNames.append(tN)
                        t.baTaxBits.append(True)
                    else:
                        t.baTaxBits.append(False)
                t.taxNames = sorted_taxNames
                t.baTaxBits = bitarray.bitarray(t.baTaxBits)
                t.firstTax = t.baTaxBits.index(1)
                # print "intree baTaxBits is %s" % t.baTaxBits
                # print "intree firstTax is %i" % t.firstTax
                # Can't use Tree.makeSplitKeys(), unfortunately.  So
                # make split keys here.  STMcmc.tBits is only used for
                # the leaves, here and in
                # STChain.setupBitarrayCalcs(), and there only once,
                # during STChain.__init__().  So probably does not
                # need to be an instance attribute.  Maybe delete?
                self.tBits = [False] * self.nTax
                for n in t.iterPostOrder():
                    if n == t.root:
                        break
                    if n.isLeaf:
                        spot = self.taxNames.index(n.name)
                        self.tBits[spot] = True
                        n.stSplitKey = bitarray.bitarray(self.tBits)
                        self.tBits[spot] = False
                    else:
                        n.stSplitKey = n.leftChild.stSplitKey.copy()
                        p = n.leftChild.sibling
                        while p:
                            n.stSplitKey |= p.stSplitKey    # "or", in-place
                            p = p.sibling
                        # print "setting node %i stSplitKey to %s" %
                        # (n.nodeNum, n.stSplitKey)
                if self.stRFCalc == 'bitarray':
                    t.splSet = set()
                    for n in t.iterInternalsNoRoot():
                        # make sure splitKey[firstTax] is a '1'
                        if not n.stSplitKey[t.firstTax]:
                            n.stSplitKey.invert()
                            n.stSplitKey &= t.baTaxBits     # 'and', in-place
                            # print "inverting and and-ing node %i stSplitKey
                            # to %s" % (n.nodeNum, n.stSplitKey)
                        # bytes so that I can use it as a set element
                        t.splSet.add(n.stSplitKey.tobytes())
                if self.modelName in ['SPA']:
                    t.internals = []
                    for n in t.iterInternalsNoRoot():
                        # make sure splitKey[firstTax] is a '1'
                        if not n.stSplitKey[t.firstTax]:
                            n.stSplitKey.invert()
                            n.stSplitKey &= t.baTaxBits     # 'and', in-place
                            # print "inverting and and-ing node %i stSplitKey
                            # to %s" % (n.nodeNum, n.stSplitKey)
                        # bytes so that I can use it as a set element
                        n.stSplitKeyBytes = n.stSplitKey.tobytes()
                        t.internals.append(n)
        self.fspa = None
        if self.modelName == 'SPA' and var.stmcmc_useFastSpa:
            import p4.fastspa as fastspa
            self.fspa = fastspa.FastSpa(useSplitSupport)
            for tNum, t in enumerate(inTrees):
                self.fspa.setInTr(tNum, t.nTax, self.nTax, t.baTaxBits.to01(), t.firstTax)
                for n in t.internals:
                    if n.br and hasattr(n.br, "support"):
                        support = n.br.support
                    else:
                        support = -1.0
                    self.fspa.setInTrNo(tNum, n.stSplitKey.to01(), support)
            #self.fspa.summarizeInTrs()
        if self.modelName in ['QPA']:
            for t in inTrees:
                sorted_taxNames = []
                t.taxBits = []
                for tNum in range(self.nTax):
                    tN = self.taxNames[tNum]
                    if tN in t.unsorted_taxNames:
                        sorted_taxNames.append(tN)
                        t.taxBits.append(1 << tNum)
                    else:
                        t.taxBits.append(0)
                t.taxNames = sorted_taxNames
                # print "intree taxBits is %s" % t.taxBits
                # Can't use Tree.makeSplitKeys(), unfortunately.  So
                # make split keys here.  STMcmc.tBits is only used for
                # the leaves, here and in
                # STChain.setupBitarrayCalcs(), and there only once,
                # during STChain.__init__().  So probably does not
                # need to be an instance attribute.  Maybe delete?
                #self.tBits = [False] * self.nTax
                for n in t.iterPostOrder():
                    if n == t.root:
                        break
                    if n.isLeaf:
                        spot = self.taxNames.index(n.name)
                        #self.tBits[spot] = True
                        n.stSplitKey = 1 << spot
                        #self.tBits[spot] = False
                    else:
                        n.stSplitKey = n.leftChild.stSplitKey
                        p = n.leftChild.sibling
                        while p:
                            n.stSplitKey |= p.stSplitKey    # "or", in-place
                            p = p.sibling
                        # print "setting node %i stSplitKey to %s" % (n.nodeNum, n.stSplitKey)
                # t.splSet = set()
                # for n in t.iterInternalsNoRoot():
                #     if not n.stSplitKey[t.firstTax]:   # make sure splitKey[firstTax] is a '1'
                #         n.stSplitKey.invert()
                #         n.stSplitKey &= t.baTaxBits     # 'and', in-place
                #         #print "inverting and and-ing node %i stSplitKey to %s" % (n.nodeNum, n.stSplitKey)
                # t.splSet.add(n.stSplitKey.tobytes()) # bytes so that I can
                # use it as a set element
                t.skk = [n.stSplitKey for n in t.iterInternalsNoRoot()]
                t.qSet = set()
                for sk in t.skk:
                    ups = [txBit for txBit in t.taxBits if (sk & txBit)]
                    downs = [txBit for txBit in t.taxBits if not (sk & txBit)]
                    for down in itertools.combinations(downs, 2):
                        if down[0] > down[1]:
                            down = (down[1], down[0])
                        for up in itertools.combinations(ups, 2):
                            if up[0] > up[1]:
                                up = (up[1], up[0])
                            if down[0] < up[0]:
                                t.qSet.add(down + up)
                            else:
                                t.qSet.add(up + down)
                # print t.qSet
                t.nQuartets = len(t.qSet)
        self.trees = inTrees
        if bigT:
            self.tree = bigT
        else:
            self.tree = p4.func.randomTree(taxNames=self.taxNames, name='stTree', randomBrLens=False)
        if self.stRFCalc in ['purePython1', 'fastReducedRF']:
            for t in inTrees:
                sorted_taxNames = []
                t.taxBits = 0
                for tNum in range(self.nTax):
                    tN = self.taxNames[tNum]
                    if tN in t.unsorted_taxNames:
                        sorted_taxNames.append(tN)
                        adder = 1 << tNum
                        t.taxBits += adder
                t.taxNames = sorted_taxNames
                t.allOnes = 2 ** (t.nTax) - 1
                t.makeSplitKeys()
                t.skSet = set([n.br.splitKey for n in t.iterInternalsNoRoot()])
        if self.stRFCalc in ['purePython1', 'fastReducedRF']:
            self.tree.makeSplitKeys()
            self.Frrf = None
            if self.stRFCalc == 'fastReducedRF':
                try:
                    import p4.fastReducedRF
                    self.Frrf = p4.fastReducedRF.Frrf
                    # not explicitly used--but makes converters available
                    import pyublas
                except ImportError:
                    gm.append("var.stRFCalc is set to 'fastReducedRF', but I could not import")
                    gm.append("at least one of fastReducedRF or pyublas.")
                    gm.append("Make sure they are installed.")
                    raise P4Error(gm)
        if self.modelName in ['QPA']:
            t = self.tree
            t.taxBits = [1 << i for i in range(t.nTax)]
            for n in t.iterPostOrder():
                if n == t.root:
                    break
                if n.isLeaf:
                    spot = self.taxNames.index(n.name)
                    n.stSplitKey = 1 << spot
                else:
                    n.stSplitKey = n.leftChild.stSplitKey
                    p = n.leftChild.sibling
                    while p:
                        n.stSplitKey |= p.stSplitKey    # "or", in-place
                        p = p.sibling
            t.skk = [n.stSplitKey for n in t.iterInternalsNoRoot()]
            t.qSet = set()
            for sk in t.skk:
                ups = [txBit for txBit in t.taxBits if (sk & txBit)]
                downs = [txBit for txBit in t.taxBits if not (sk & txBit)]
                for down in itertools.combinations(downs, 2):
                    assert down[0] < down[1]   # probably not needed
                    for up in itertools.combinations(ups, 2):
                        assert up[0] < up[1]  # probably not needed
                        if down[0] < up[0]:
                            t.qSet.add(down + up)
                        else:
                            t.qSet.add(up + down)
            # print t.qSet
            t.nQuartets = len(t.qSet)
        self.useSplitSupport = False
        if useSplitSupport:
            if self.modelName.startswith("SR2008"):
                gm.append(
                    "Arg useSplitSupport is turned on, but it is not implemented with SR2008")
                raise P4Error(gm)
            assert useSplitSupport in [True, 'percent']
            self.useSplitSupport = True
            hasSplitInfo = False
            for it in self.trees:
                for n in it.iterInternalsNoRoot():
                    if not hasattr(n.br, 'support'):
                        n.br.support = None
                    else:
                        assert n.br.support == None
                    if n.name:
                        flName = float(n.name)
                        hasSplitInfo = True
                        if useSplitSupport == 'percent':
                            flName *= 0.01
                        if flName < 0.0 or flName > 1.0:
                            gm.append("Input tree %s" %
                                      it.writeNewick(toString=True))
                            gm.append(
                                "Got support value %s, outside of range 0 to 1" % n.name)
                            if flName > 1.0 and useSplitSupport == True:
                                gm.append(
                                    "Maybe it is percent support?  If so, set useSplitSupport to 'percent' rather than True")
                            raise P4Error(gm)
                        n.br.support = flName
                        #n.br.logSupport = math.log(n.br.support)
                    else:
                        #n.br.logSupport = None
                        pass
            if not hasSplitInfo:
                gm.append(
                    "Arg useSplitSupport is turned on, but none of the trees seem to have split info.")
                raise P4Error(gm)
        splash = p4.func.splash2(verbose=False)
        for aLine in splash:
            self.logger.info(aLine)
        self.swapVector = True
        if self.nChains > 1:
            self.swapTuner = STSwapTunerV(self)
        # Hidden experimental hacking
        self.doHeatingHack = False
        self.heatingHackTemperature = 5.0
        #self.heatingHackProposalNames = ['nni', 'spr']
        if verbose:
            self.loggerPrinter.info("Initializing STMcmc")
            self.loggerPrinter.info("%-16s: %s" % ('modelName', modelName))
            if self.modelName.startswith("SR2008"):
                self.loggerPrinter.info("%-16s: %s" % ('stRFCalc', self.stRFCalc))
            if self.modelName in ["SPA", "QPA"]:
                self.loggerPrinter.info("%-16s: %s" % ('useSplitSupport', self.useSplitSupport))
            self.loggerPrinter.info("%-16s: %s" % ('inTrees', len(self.trees)))
            self.loggerPrinter.info("%-16s: %s" % ('nTax', self.nTax))
            if self.nChains == 1:
                self.loggerPrinter.info("%-16s: %s" % ('mcmcmc', "off: 1 chain"))
            elif self.nChains > 1:
                self.loggerPrinter.info("%-16s: %s" % ('mcmcmc', "on -- %i chains" % self.nChains))
                if self.swapVector:
                    self.loggerPrinter.info("%-16s: %s" % ('swapVector', "on"))
                    self.loggerPrinter.info("%-16s: %s" % ('swapTuner', "on"))
    def _del_nothing(self):
        """Deleter for properties that must not be deleted; always raises P4Error."""
        message = ["Don't/Can't delete this property."]
        raise P4Error(message)
    def _get_spaQ(self):
        """Getter for the ``spaQ`` property: the value held in ``self._spaQ``."""
        value = self._spaQ
        return value
    def _set_spaQ(self, newVal):
        """Setter for the ``spaQ`` property.

        Converts ``newVal`` with ``float()`` and stores it in ``self._spaQ``.
        If there are chains, it also copies the value onto the proposed and
        current tree of every chain so they stay in sync.

        Args:
            newVal: anything ``float()`` accepts (a number or numeric string).

        Raises:
            P4Error: if ``newVal`` cannot be converted to a float.
        """
        try:
            newVal = float(newVal)
        except (TypeError, ValueError, OverflowError) as err:
            # Only genuine conversion failures are reported as a bad value;
            # a bare ``except:`` would also swallow KeyboardInterrupt.
            gm = ['This property should be set to a float.']
            raise P4Error(gm) from err
        self._spaQ = newVal
        if self.chains:
            for ch in self.chains:
                ch.propTree.spaQ = newVal
                ch.curTree.spaQ = newVal
    # Read through _get_spaQ, validated and propagated to the chains by
    # _set_spaQ; deletion is refused by _del_nothing.
    spaQ = property(fget=_get_spaQ, fset=_set_spaQ, fdel=_del_nothing)
    """(property) The current spaQ"""
    def _setLogger(self):
        """Make two loggers; one that writes to a file and to stderr, and one that writes only to a file.

        The root logger is configured with ``logging.basicConfig`` to append
        INFO-and-above records to ``mcmc_log_<runNum>``.  Per the logging
        docs, ``basicConfig`` does nothing if the root logger already has
        handlers, so only the first configuration in a process takes effect.

        ``self.loggerPrinter`` (logger ``'withPrint'``) also gets a console
        handler that writes plain messages to stderr.  ``self.logger``
        (logger ``'logFileOnly'``) gets no handler here, so it only reaches
        the file through the root logger.
        """
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s %(message)s',
                            datefmt='[%Y-%m-%d %H:%M]',
                            filename="mcmc_log_%i" % self.runNum,
                            filemode='a')
        # Named loggers keep the two separate.
        self.loggerPrinter = logging.getLogger('withPrint')
        # Loggers are process-wide singletons.  Without this guard, every call
        # (e.g. from another STMcmc in the same session) would add another
        # console handler and each message would be printed several times.
        consoleName = 'STMcmc_console'
        if not any(h.get_name() == consoleName for h in self.loggerPrinter.handlers):
            # A handler that writes INFO messages or higher to sys.stderr.
            console = logging.StreamHandler()
            console.set_name(consoleName)
            console.setLevel(logging.INFO)
            # A simpler format for console use.
            console.setFormatter(logging.Formatter('%(message)s'))
            self.loggerPrinter.addHandler(console)
        # This logger only logs to the file, not to stderr.
        self.logger = logging.getLogger("logFileOnly")
    def _makeProposals(self):
        """Create the STMcmc proposals and register them with ``self.props``.

        A proposal is made only when its probability in ``self.prob`` is
        non-zero:

        * 'nni' and 'spr' for any model;
        * 'SR2008beta_uniform' only when modelName is 'SR2008_rf_aZ_fb';
        * 'spaQ_uniform' and 'polytomy' only when modelName is 'SPA' or 'QPA'.

        Each weight is ``self.prob.<name> * fudgeFactor[<name>]``.  The two
        continuously tuned proposals get a starting tuning per chain from
        ``self._tunings.default`` and a shared set of tuning targets.  The
        proposals are then indexed by name in ``self.props.proposalsDict``,
        and ``self.props.calculateWeights()`` is called.

        Raises:
            P4Error: if no proposal was made.
        """
        gm = ['STMcmc._makeProposals()']
        prob = self.prob
        made = self.props.proposals

        def startProposal(name, probVal):
            # A new STProposal with its name and fudge-factored weight.
            newProp = STProposal(self)
            newProp.name = name
            newProp.weight = probVal * fudgeFactor[name]
            return newProp

        def initialTuning(name):
            # The default tuning for this proposal, one copy per chain.
            return [self._tunings.default[name]] * self.nChains

        def setTuningTargets(tunedProp):
            # Acceptance-rate bands (tnAcc*) and the matching tuning
            # multipliers (tnFactor*) shared by the tuned proposals.
            for attrName, attrVal in (('tnAccVeryHi', 0.7),
                                      ('tnAccHi', 0.6),
                                      ('tnAccLo', 0.05),
                                      ('tnAccVeryLo', 0.03),
                                      ('tnFactorVeryHi', 1.6),
                                      ('tnFactorHi', 1.2),
                                      ('tnFactorLo', 0.8),
                                      ('tnFactorVeryLo', 0.7)):
                setattr(tunedProp, attrName, attrVal)

        # Topology proposals, available for every model.
        for topoName in ('nni', 'spr'):
            topoProb = getattr(prob, topoName)
            if topoProb:
                made.append(startProposal(topoName, topoProb))

        if self.modelName in ['SR2008_rf_aZ_fb'] and prob.SR2008beta_uniform:
            betaProp = startProposal('SR2008beta_uniform', prob.SR2008beta_uniform)
            betaProp.tuning = initialTuning(betaProp.name)
            setTuningTargets(betaProp)
            made.append(betaProp)

        if self.modelName in ['SPA', 'QPA']:
            if prob.spaQ_uniform:
                qProp = startProposal('spaQ_uniform', prob.spaQ_uniform)
                qProp.tuning = initialTuning(qProp.name)
                qProp.spaQPriorType = 'flat'
                qProp.spaQExpPriorLambda = 100.
                setTuningTargets(qProp)
                made.append(qProp)
            if prob.polytomy:
                polyProp = startProposal('polytomy', prob.polytomy)
                polyProp.polytomyUseResolutionClassPrior = self.polytomyUseResolutionClassPrior
                polyProp.polytomyPriorLogBigC = self.polytomyPriorLogBigC
                made.append(polyProp)

        if not made:
            gm.append("No proposals?")
            raise P4Error(gm)
        for madeProp in made:
            self.props.proposalsDict[madeProp.name] = madeProp
        self.props.calculateWeights()
    def _refreshProposalProbsAndTunings(self):
        """Adjust proposals after a restart.

        The user may have changed the probabilities in ``self.prob`` between
        calls to run().  So every existing proposal's weight is recomputed
        with the same formula used when it was made in _makeProposals():
        ``self.prob.<name> * fudgeFactor[<name>]``.

        The weights and their running sums are stored on self as
        ``propWeights``, ``cumPropWeights`` and ``totalPropWeights``.
        ``self.props.calculateWeights()`` is then called, as at the end of
        _makeProposals(), presumably so the proposal chooser uses the
        refreshed weights.

        Raises:
            P4Error: if the proposal weights sum to less than 1e-9.
        """
        gm = ['STMcmc._refreshProposalProbsAndTunings()']
        for p in self.props.proposals:
            # Every proposal made in _makeProposals() is named after both its
            # self.prob attribute and its fudgeFactor key.
            p.weight = getattr(self.prob, p.name) * fudgeFactor[p.name]
        self.propWeights = [p.weight for p in self.props.proposals]
        self.cumPropWeights = list(itertools.accumulate(self.propWeights))
        self.totalPropWeights = sum(self.propWeights)
        if self.totalPropWeights < 1e-9:
            gm.append("No proposal weights?")
            raise P4Error(gm)
        self.props.calculateWeights()
    def writeProposalAcceptances(self):
        """Pretty-print the proposal acceptances.

        Covers gens ``self.startMinusOne + 1`` to ``self.gen`` and prints:

        * for each proposal, the cold-chain (index 0) number of proposals,
          acceptance percentage and tuning;
        * with more than one chain, 'nni' and 'spr' acceptances by temperature;
        * the number of 'polytomy' aborts, if that proposal counts them;
        * with more than one chain, acceptances and tunings of every proposal
          by temperature.

        An acceptance percentage is shown as '-' when nothing was proposed.
        If there are no tallies in memory, only a notice is printed.
        """
        if (self.gen - self.startMinusOne) <= 0:
            print("\nSTMcmc.writeProposalAcceptances()  There is no info in memory. ")
            print(" Maybe it was just emptied after writing to a checkpoint?  ")
            print("If so, read the checkPoint and get the proposalAcceptances from there.")
            return

        def formatAcceptance(p, tempNum):
            # Acceptance-percent column, or '-' if nothing was proposed.
            if p.nProposals[tempNum]:
                return "       %5.1f " % (
                    100.0 * float(p.nAcceptances[tempNum]) / float(p.nProposals[tempNum]))
            return "           - "

        def formatTuning(p, tempNum):
            # Tuning column; fewer decimals as the value gets bigger.
            if p.tuning is None:
                return "      -"
            val = p.tuning[tempNum]
            if val < 2.0:
                return "  %8.4f" % val
            if val < 20.0:
                return "  %8.3f" % val
            if val < 200.0:
                return "  %8.1f" % val
            return "  %8.3g" % val

        spacer = ' ' * 8
        print("\nProposal acceptances, run %i, for %i gens, from gens %i to %i, inclusive." % (
            self.runNum, (self.gen - self.startMinusOne), self.startMinusOne + 1, self.gen))
        print("%s %20s %10s %13s%8s" % (spacer, 'proposal', 'nProposals', 'acceptance(%)', 'tuning'))
        for p in self.props.proposals:
            print("%s" % spacer, end=' ')
            print("%20s" % p.name, end=' ')
            print("%10i" % p.nProposals[0], end=' ')
            print(formatAcceptance(p, 0), end=' ')
            print(formatTuning(p, 0), end=' ')
            print()

        # Topology changes by temperature.
        if self.nChains > 1:
            for propName in ['nni', 'spr']:
                p = self.props.proposalsDict.get(propName)
                if p:
                    print("'%s' proposal-- topology changes by temperature" % (propName))
                    print("%s tempNum   nProps nAccepts percent" % spacer)
                    for tNum in range(self.nChains):
                        print("%s" % spacer, end=' ')
                        print("%4i " % tNum, end=' ')
                        print("%9i" % p.nProposals[tNum], end=' ')
                        print("%8i" % p.nAcceptances[tNum], end=' ')
                        # A temperature may have had no nni/spr proposals at
                        # all (e.g. nni is skipped on a star tree).
                        if p.nProposals[tNum]:
                            print("  %5.1f" % (
                                100.0 * float(p.nAcceptances[tNum]) / float(p.nProposals[tNum])))
                        else:
                            print("      -")

        # Aborts, for proposals that count them.
        for pN in ['polytomy']:
            p = self.props.proposalsDict.get(pN)
            if p and hasattr(p, 'nAborts'):
                print("The %s proposal had %5i aborts." % (p.name, p.nAborts[0]))

        if self.nChains > 1:
            print("\n\nAcceptances and tunings by temperature")
            print("%s %30s %5s %5s %10s %13s%10s" % (
                spacer, 'proposal', 'part', 'tempNum', 'nProposals', 'acceptance(%)', 'tuning'))
            for p in self.props.proposals:
                for tempNum in range(self.nChains):
                    print("%s" % spacer, end=' ')
                    print("%30s" % p.name, end=' ')
                    if p.pNum != -1:
                        print(" %3i " % p.pNum, end=' ')
                    else:
                        print("   - ", end=' ')
                    print(" %3i " % tempNum, end=' ')
                    print("%10i" % p.nProposals[tempNum], end=' ')
                    print(formatAcceptance(p, tempNum), end=' ')
                    print(formatTuning(p, tempNum), end=' ')
                    print()
    def writeSwapMatrix(self):
        """Print the chain-swap tallies as an nChains x nChains table.

        Above the diagonal (row < column) is the number of swaps proposed
        between the two chains, read from ``self.swapMatrix[row][col]``.
        Below the diagonal is the percentage accepted, computed as
        ``swapMatrix[row][col]`` over ``swapMatrix[col][row]``, or '-' when
        none were proposed.  The diagonal is always '-'.
        """
        nCh = self.nChains
        nGens = self.gen - self.startMinusOne
        print("\nChain swapping, for %i gens, from gens %i to %i, inclusive." % (
            nGens, self.startMinusOne + 1, self.gen))
        print("    Upper triangle is the number of swaps proposed between two chains.")
        print("    Lower triangle is the percent swaps accepted.")

        def cell(row, col):
            # One 7-character entry of the table.
            if row == col:
                return "      -"
            if row < col:
                return "%7i" % self.swapMatrix[row][col]
            nProposed = self.swapMatrix[col][row]
            if nProposed == 0:
                return "      -"
            return "  %5.1f" % (100.0 * float(self.swapMatrix[row][col]) / float(nProposed))

        indent = " " * 10 + " "
        print(indent + "".join("%7i " % col for col in range(nCh)))
        print(indent + "   ---- " * nCh)
        for row in range(nCh):
            cells = "".join(cell(row, col) + " " for col in range(nCh))
            print(" " * 8 + "%2i " % row + cells)
    def _makeChainsAndProposals(self):
        """Make the chains and the proposals, each only if not already made.

        One STChain per chain number is created when ``self.chains`` is
        empty.  Proposals are made when ``self.props.proposals`` is empty.
        In that case, if there is a 'polytomy' proposal and
        ``self.polytomyUseResolutionClassPrior`` is set, the logs of T_{n,m}
        are precomputed onto it as ``logBigT``.
        """
        if not self.chains:
            self.chains = []
            # chNum is used by fastspa; it is also the starting tempNum
            # (tempNum changes in swaps).
            for chNum in range(self.nChains):
                self.chains.append(STChain(self, chNum))
        if self.props.proposals:
            return
        self._makeProposals()
        # The resolution class prior in the polytomy move needs log(T_{n,m})
        # for m from 0 to nTax-2 inclusive; entry 0 stays 0.0.
        polyProp = self.props.proposalsDict.get('polytomy')
        if not polyProp or not self.polytomyUseResolutionClassPrior:
            return
        nTax = self.tree.nTax
        nTrees = p4.func.nUnrootedTreesWithMultifurcations(nTax)
        logBigT = [0.0] * (nTax - 1)
        polyProp.logBigT = logBigT
        for m in range(1, nTax - 1):
            logBigT[m] = math.log(nTrees[m])
    def _setOutputTreeFile(self):
        """Setup the (output) tree file for the STMcmc.

        (Over)writes ``self.treeFileName`` with a NEXUS taxa block.  It then
        starts a trees block holding a translate table that numbers the taxa
        from 1, in ``self.tree.taxNames`` order.  No closing 'end;' is
        written here.  Names are quoted as needed by
        ``p4.func.nexusFixNameIfQuotesAreNeeded``.

        Also sets ``self.translationHash`` (taxon name -> translate number).
        """
        taxNames = self.tree.taxNames
        nTax = self.tree.nTax
        fixName = p4.func.nexusFixNameIfQuotesAreNeeded
        # Taxon name -> 1-based number used in the translate table.
        self.translationHash = {tName: i for i, tName in enumerate(taxNames, start=1)}
        # 'with' guarantees the file is closed even if a write fails.
        with open(self.treeFileName, 'w') as treeFile:
            treeFile.write('#nexus\n\n')
            treeFile.write('begin taxa;\n')
            treeFile.write('  dimensions ntax=%s;\n' % nTax)
            treeFile.write('  taxlabels')
            for tN in taxNames:
                treeFile.write(' %s' % fixName(tN))
            treeFile.write(';\nend;\n\n')
            treeFile.write('begin trees;\n')
            treeFile.write('  translate\n')
            for i in range(nTax - 1):
                treeFile.write('    %3i %s,\n' % (i + 1, fixName(taxNames[i])))
            # The last translate entry has no trailing comma.
            treeFile.write('    %3i %s\n' % (nTax, fixName(taxNames[-1])))
            treeFile.write('  ;\n')
            treeFile.write('  [Tree numbers are gen+1]\n')
    def run(self, nGensToDo, verbose=True, equiProbableProposals=False, writeSamples=True):
        """Start the STMcmc running."""
        gm = ['STMcmc.run()']
        # Hidden experimental hack
        if self.doHeatingHack:
            print("Heating hack is turned on.")
            assert self.nChains == 1, "MCMCMC does not work with the heating hack"
            print("Heating hack temperature is %.2f" % self.heatingHackTemperature)
            #print("Heating hack affects proposals %s" % self.heatingHackProposalNames)
        # Keep track of the first gen of this call to run(), maybe restart
        firstGen = self.gen + 1
        if self.checkPointInterval:
            # We want a couple of things:
            #  1.  The last gen should be on checkPointInterval.  For
            #      example, if the checkPointInterval is 200, then doing
            #      100 or 300 generations will not be allowed cuz the
            #      chain would continue past the checkPoint-- bad.  Or if
            #      you re-start after 500 gens and change to a
            #      checkPointInterval of 200, then you won't be allowed to
            #      do 500 gens.
            # if ((self.gen + 1) + nGensToDo) % self.checkPointInterval == 0:
            if nGensToDo % self.checkPointInterval == 0:
                pass
            else:
                gm.append(
                    "With the current settings, the last generation won't be on a checkPointInterval.")
                gm.append("self.gen+1=%i, nGensToDo=%i, checkPointInterval=%i" % ((self.gen + 1),
                                                                                  nGensToDo, self.checkPointInterval))
                raise P4Error(gm)
            #  2.  We also want the checkPointInterval to be evenly
            #      divisible by the sampleInterval.
            if self.checkPointInterval % self.sampleInterval == 0:
                pass
            else:
                gm.append(
                    "The checkPointInterval (%i) should be evenly divisible" % self.checkPointInterval)
                gm.append("by the sampleInterval (%i)." % self.sampleInterval)
                raise P4Error(gm)
        if self.props.proposals:
            # Its either a re-start, or it has been thru autoTune().
            # I can tell the difference by self.gen, which is -1 after
            # autoTune()
            if self.gen == -1:
                self._makeChainsAndProposals()
                self._setOutputTreeFile()
                # if self.simulate:
                #    self.writeSimFileHeader(self.tree)
            # The probs and tunings may have been changed by the user.
            self._refreshProposalProbsAndTunings()
            # This stuff below should be the same as is done after pickling,
            # see below.
            self.startMinusOne = self.gen
            # Start the tree partitions over.
            self.treePartitions = None
            # Zero the proposal counts
            for p in self.props.proposals:
                p.nProposals = [0] * self.nChains
                p.nAcceptances = [0] * self.nChains
                #p.nTopologyChangeAttempts = [0] * self.nChains
                #p.nTopologyChanges = [0] * self.nChains
            # Zero the swap matrix
            if self.nChains > 1:
                self.swapMatrix = []
                for i in range(self.nChains):
                    self.swapMatrix.append([0] * self.nChains)
        else:
            self._makeChainsAndProposals()
            self._setOutputTreeFile()
            # if self.simulate:
            #    self.writeSimFileHeader(self.tree)
            # The swap vector is just the diagonal of the swap matrix
            if self.swapVector and self.nChains > 1:
                # These are differences in temperatures between adjacent chains.  The last one is not used.
                self.chainTempDiffs = [self.chainTemp] * self.nChains 
                # These are cumulative, summed over the diffs.  This needs to be done whenever the diffs change
                self.chainTemps = [0.0]
                for dNum in range(self.nChains - 1):
                    self.chainTemps.append(self.chainTempDiffs[dNum] + self.chainTemps[-1])
        if verbose:
            self.props.writeProposalIntendedProbs()
            sys.stdout.flush()
        coldChainNum = 0            
        # If polytomy is turned on, then it is possible to get a star
        # tree, in which case local will not work.  So if we have both
        # polytomy and local proposals, we should also have brLen.
        # if "polytomy" in self.proposalsHash and 'local' in self.proposalsHash:
        #     if 'brLen' not in self.proposalsHash:
        #         gm.append("If you have polytomy and local proposals, you should have a brLen proposal as well.")
        #         gm.append("It can have a low proposal probability, but it needs to be there.")
        #         gm.append("Turn it on by eg yourMcmc.prob.brLen = 0.001")
        #         raise P4Error(gm)
        if self.gen > -1:
            # it is a re-start, so we need to back over the "end;" in the tree
            # files.
            f2 = open(self.treeFileName, 'r+b')
            pos = -1
            while 1:
                f2.seek(pos, 2)
                c = f2.read(1)
                if c == b';':
                    break
                pos -= 1
            # print "pos now %i" % pos
            pos -= 3  # end;
            f2.seek(pos, 2)
            c = f2.read(4)
            # print "got c = '%s'" % c
            if c != b"end;":
                gm.append("Stmcmc.run().  Failed to find and remove the 'end;' at the end of the tree file.")
                raise P4Error(gm)
            else:
                f2.seek(pos, 2)
                f2.truncate()
            f2.close()
            self.logger.info("Re-starting the ST MCMC run %i from gen=%i" % (self.runNum, self.gen))
            if verbose:
                print()
                print("Re-starting the ST MCMC run %i from gen=%i" % (self.runNum, self.gen))
                if not writeSamples:
                    print("Arg 'writeSamples' is off" )
                print("Set to do %i more generations." % nGensToDo)
                # if self.writePrams:
                #    if self.chains[0].curTree.model.nFreePrams == 0:
                #        print "There are no free prams in the model, so I am turning writePrams off."
                #        self.writePrams = False
                sys.stdout.flush()
            self.startMinusOne = self.gen
        else:
            self.logger.info("Starting the ST MCMC %s run %i" % ((self.constraints and "(with constraints)" or ""), self.runNum))
            self.logger.info("nChains %i" % (self.nChains))
            ret = self.props.proposalsDict.get('polytomy')
            if ret:
                message = "Doing polytomy proposal, with polytomyUseResolutionClassPrior=%s" % ret.polytomyUseResolutionClassPrior
                self.loggerPrinter.info(message)
                message = "polytomy: polytomyPriorLogBigC=%f" % ret.polytomyPriorLogBigC
                self.loggerPrinter.info(message)
            if verbose:
                if self.nChains > 1:
                    print("Using Metropolis-coupled MCMC, with %i chains.  Temperature %f." % (self.nChains, self.chainTemp))
                else:
                    print("Not using Metropolis-coupled MCMC.")
                self.loggerPrinter.info("Starting the ST MCMC %s run %i" % ((self.constraints and "(with constraints)" or ""), self.runNum))
                self.loggerPrinter.info("Set to do %i generations." % nGensToDo)
                if self.writePrams:
                    # if self.chains[0].curTree.model.nFreePrams == 0:
                    #     print "There are no free prams in the model, so I am turning writePrams off."
                    #     self.writePrams = False
                    # else:
                    pramsFile = open(self.pramsFileName, 'a')
                    if self.modelName.startswith("SR2008"):
                        pramsFile.write("    genPlus1     beta\n")
                    elif self.modelName.startswith("SPA"):
                        pramsFile.write("    genPlus1     spaQ\n")
                    elif self.modelName.startswith("QPA"):
                        pramsFile.write("    genPlus1     spaQ\n")
                    pramsFile.close()
                sys.stdout.flush()
        if not writeSamples:
            self.logger.info("STMcmc.run() arg 'writeSamples' is off, so samples are not being written")
        if equiProbableProposals:
            self.logger.info("STMcmc.run() arg 'equiProbableProposals' is turned on")
        if verbose:
            if writeSamples:
                print("Sampling every %i." % self.sampleInterval)
            else:
                print("Arg 'writeSamples' is off, so samples are not being written")
            if equiProbableProposals:
                print("Arg 'equiProbableProposals' is turned on")
            if self.checkPointInterval:
                print("CheckPoints written every %i." % self.checkPointInterval)
            if nGensToDo <= 20000:
                print("One dot is 100 generations.")
            else:
                print("One dot is 1000 generations.")
            sys.stdout.flush()
        self.treePartitions = None
        realTimeStart = time.time()
        self.lastTimeCheck = time.time()
        abortableProposals = ['nni', 'spr', 'polytomy']
        ##################################################
        ############### Main loop ########################
        ##################################################
        self.swapInterval = 2
        for gNum in range(nGensToDo):
            self.gen += 1
            # Do an initial time estimate based on 100 gens
            if nGensToDo > 100 and self.gen - firstGen == 100:
                diff_secs = time.time() - realTimeStart
                total_secs = (float(nGensToDo) / float(100)) * float(diff_secs)
                deltaTime = datetime.timedelta(seconds=int(round(total_secs)))
                print("Estimated completion time: %s days, %s" % (
                    deltaTime.days, time.strftime("%H:%M:%S", time.gmtime(deltaTime.seconds))))
            # Above is a list of proposals where it is possible to abort.
            # When a gen(aProposal) is made, below, aProposal.doAbort
            # might be set, in which case we want to skip it for this
            # gen.  But we want to start each 'for chNum' loop with
            # doAborts all turned off.
            for chNum in range(self.nChains):
                failure = True
                nAttempts = 0
                while failure:
                    # Get the next proposal
                    gotIt = False
                    safety = 0
                    while not gotIt:
                        # equiProbableProposals is True or False.  Usually False.
                        aProposal = self.props.chooseProposal(equiProbableProposals)
                        if aProposal:
                            gotIt = True
                        if aProposal.name == 'nni':
                            # Can't do nni on a star tree.
                            if self.chains[chNum].curTree.nInternalNodes == 1:
                                #aProposal = self.props.proposalsDict['polytomy']
                                gotIt = False
                        if aProposal.doAbort:
                            gotIt = False
                        safety += 1
                        if safety > 1000:
                            gm.append(
                                "Could not find a proposal after %i attempts." % safety)
                            gm.append("Possibly a programming error.")
                            gm.append(
                                "Or possibly it is just a pathologically frustrating Mcmc.")
                            raise P4Error(gm)
                    if 0:
                        print("==== gNum=%i, chNum=%i, aProposal=%s" % (
                            gNum, chNum, aProposal.name), end=' ')
                        sys.stdout.flush()
                        # print gNum,
                    # success returns None
                    failure = self.chains[chNum].gen(aProposal)
                    if failure:
                        myWarn = "STMcmc.run() main loop.  Proposal %s generated a 'failure'.  Why?" % aProposal.name
                        self.logger.warning(myWarn)
                    if 0:
                        if failure:
                            print("    failure")
                        else:
                            print()
                    nAttempts += 1
                    if nAttempts > 1000:
                        gm.append("Was not able to do a successful generation after %i attempts." % nAttempts)
                        raise P4Error(gm)
                    # Continuous tuning.  We have a tuning, and propose/accept
                    # tallies for each temperature, kept separately.  Note that
                    # since chNum does not equal tempNum, the most recently
                    # incremented values (and the ones we want to tune now) will
                    # be aProposal.tnNSamples[tempNum] and
                    # aProposal.tnNAccepts[tempNum], where tempNum will most
                    # likely not be chNum.  So we get the tempNum from this
                    # chNum, and tune it.
                    if aProposal.name in self.tunableProps:
                        tempNum = self.chains[chNum].tempNum
                        if aProposal.tnNSamples[tempNum] >= aProposal.tnSampleSize:
                            aProposal.tune(tempNum)
                # print "   Mcmc.run(). finished a gen on chain %i" % (chNum)
                for prNm in abortableProposals:
                    ret = self.props.proposalsDict.get(prNm)
                    if ret:
                        ret.doAbort = False
            # Do swap, if there is more than 1 chain.
            if (self.gen + 1) % self.swapInterval == 0:
                if self.nChains == 1:
                    coldChain = 0
                else:
                    if self.swapVector:
                        rTempNum1 = random.randrange(self.nChains - 1)
                        rTempNum2 = rTempNum1 + 1
                        chain1 = None
                        chain2 = None
                        for ch in self.chains:
                            if ch.tempNum == rTempNum1:
                                chain1 = ch
                            elif ch.tempNum == rTempNum2:
                                chain2 = ch
                        assert chain1 and chain2
                        # Use the upper triangle of swapMatrix for nAttempts
                        self.swapMatrix[chain1.tempNum][chain2.tempNum] += 1
                        lnR = (1.0 / (1.0 + (self.chainTemps[chain1.tempNum]))
                                ) * chain2.curTree.logLike
                        lnR += (1.0 / (1.0 + (self.chainTemps[chain2.tempNum]))
                                ) * chain1.curTree.logLike
                        lnR -= (1.0 / (1.0 + (self.chainTemps[chain1.tempNum]))
                                ) * chain1.curTree.logLike
                        lnR -= (1.0 / (1.0 + (self.chainTemps[chain2.tempNum]))
                                ) * chain2.curTree.logLike
                        # # An alternative calculation
                        # heatBeta1 = 1.0 / (1.0 + self.chainTemps[chain1.tempNum])
                        # heatBeta2 = 1.0 / (1.0 + self.chainTemps[chain2.tempNum])
                        # likeRatio12 = (chain2.curTree.logLike - chain1.curTree.logLike) * heatBeta1
                        # likeRatio21 = (chain1.curTree.logLike - chain2.curTree.logLike) * heatBeta2
                        # lnR2 = likeRatio12 + likeRatio21
                        # rDiff = math.fabs(lnR - lnR2)
                        # if rDiff > 1e-12:
                        #     print("bad swap rDiff %f (%g)   lnR=%f, lnR2=%f" % (rDiff, rDiff, lnR, lnR2))
                        # lnR = lnR2
                        if lnR < -100.0:
                            r = 0.0
                        elif lnR >= 0.0:
                            r = 1.0
                        else:
                            r = math.exp(lnR)
                        acceptSwap = 0
                        if random.random() < r:
                            acceptSwap = 1
                        # self.logger.info("swap proposed gen=%i between tempNum1=%i chNum1=%i temp1=%f and tempNum2=%i chNum2=%i temp2=%f acceptSwap=%s" % (
                        #     self.gen, rTempNum1, chain1.chNum, self.chainTemps[chain1.tempNum], 
                        #     rTempNum2, chain2.chNum, self.chainTemps[chain2.tempNum], acceptSwap))
                        # for continuous temperature tuning with self.swapTuner
                        if self.swapTuner:
                            # Index the nAttempts and nSwaps with the lower of the two tempNum's, which would be chain1.tempNum
                            self.swapTuner.nAttempts[chain1.tempNum] += 1
                            if acceptSwap:
                                self.swapTuner.nSwaps[chain1.tempNum] += 1
                            if self.swapTuner.nAttempts[chain1.tempNum] >= var.mcmc_swapTunerSampleSize:
                                self.swapTuner.tune(chain1.tempNum)
                                # tune() zeros nAttempts and nSwaps counters
                        if acceptSwap:
                            # Use the lower triangle of swapMatrix to keep track of
                            # nAccepted's
                            assert chain1.tempNum < chain2.tempNum
                            self.swapMatrix[chain2.tempNum][chain1.tempNum] += 1
                            # Do the swap
                            chain1.tempNum, chain2.tempNum = chain2.tempNum, chain1.tempNum
                    else:     # swap matrix
                        # Chain swapping stuff was lifted from MrBayes.  Thanks again.
                        chain1, chain2 = random.sample(self.chains, 2)
                        # Use the upper triangle of swapMatrix for nProposed's
                        if chain1.tempNum < chain2.tempNum:
                            self.swapMatrix[chain1.tempNum][chain2.tempNum] += 1
                            thisCh1Temp = chain1.tempNum
                            thisCh2Temp = chain2.tempNum
                        else:
                            self.swapMatrix[chain2.tempNum][chain1.tempNum] += 1
                            thisCh1Temp = chain2.tempNum
                            thisCh2Temp = chain1.tempNum
                        lnR = (1.0 / (1.0 + (self.chainTemp * chain1.tempNum))
                                ) * chain2.curTree.logLike
                        lnR += (1.0 / (1.0 + (self.chainTemp * chain2.tempNum))
                                ) * chain1.curTree.logLike
                        lnR -= (1.0 / (1.0 + (self.chainTemp * chain1.tempNum))
                                ) * chain1.curTree.logLike
                        lnR -= (1.0 / (1.0 + (self.chainTemp * chain2.tempNum))
                                ) * chain2.curTree.logLike
                        if lnR < -100.0:
                            r = 0.0
                        elif lnR >= 0.0:
                            r = 1.0
                        else:
                            r = math.exp(lnR)
                        acceptSwap = 0
                        if random.random() < r:
                            acceptSwap = 1
                        # for continuous temperature tuning with self.swapTuner
                        if self.swapTuner and thisCh1Temp == 0 and thisCh2Temp == 1:
                            self.swapTuner.swaps01_nAttempts += 1
                            if acceptSwap:
                                self.swapTuner.swaps01_nSwaps += 1
                            if self.swapTuner.swaps01_nAttempts >= self.swapTuner.sampleSize:
                                self.swapTuner.tune(self)
                                # tune() zeros nAttempts and nSwaps counters
                        if acceptSwap:
                            # Use the lower triangle of swapMatrix to keep track of
                            # nAccepted's
                            if chain1.tempNum < chain2.tempNum:
                                self.swapMatrix[chain2.tempNum][chain1.tempNum] += 1
                            else:
                                self.swapMatrix[chain1.tempNum][chain2.tempNum] += 1
                            # Do the swap
                            chain1.tempNum, chain2.tempNum = chain2.tempNum, chain1.tempNum
                    # Find the cold chain, the one where tempNum is 0
                    coldChainNum = -1
                    for i in range(len(self.chains)):
                        if self.chains[i].tempNum == 0:
                            coldChainNum = i
                            break
                    if coldChainNum == -1:
                        gm.append("Unable to find which chain is the cold chain.  Bad.")
                        raise P4Error(gm)
            # If it is a writeInterval, write stuff
            if (self.gen + 1) % self.sampleInterval == 0:
                if writeSamples:
                    likesFile = open(self.likesFileName, 'a')
                    likesFile.write(
                        '%11i %f\n' % (self.gen + 1, self.chains[coldChainNum].curTree.logLike))
                    likesFile.close()
                    treeFile = open(self.treeFileName, 'a')
                    treeFile.write("  tree t_%i = [&U] " % (self.gen + 1))
                    self.chains[coldChainNum].curTree.writeNewick(treeFile,
                                                                  withTranslation=1,
                                                                  translationHash=self.translationHash,
                                                                  doMcmcCommandComments=False)
                    treeFile.close()
                if writeSamples and self.writePrams:
                    pramsFile = open(self.pramsFileName, 'a')
                    #pramsFile.write("%12i " % (self.gen + 1))
                    pramsFile.write("%12i" % (self.gen + 1))
                    if self.modelName.startswith("SR2008"):
                        pramsFile.write(
                            "  %f\n" % self.chains[coldChainNum].curTree.beta)
                    elif self.modelName in ["SPA", "QPA"]:
                        pramsFile.write(
                            "  %f\n" % self.chains[coldChainNum].curTree.spaQ)
                    pramsFile.close()
                # Do a simulation
                if self.simulate:
                    # print "about to simulate..."
                    self.doSimulate(self.chains[coldChainNum].curTree)
                    # print "...finished simulate."
                # Do other stuff.
                if hasattr(self, 'hook'):
                    self.hook(self.chains[coldChainNum].curTree)
                if 0 and self.constraints:
                    print("Mcmc x1c")
                    print(self.chains[0].verifyIdentityOfTwoTreesInChain())
                    print("b checking curTree ..")
                    self.chains[0].curTree.checkSplitKeys()
                    print("b checking propTree ...")
                    self.chains[0].propTree.checkSplitKeys()
                    print("Mcmc xxx")
                # Add curTree to treePartitions
                if self.treePartitions:
                    self.treePartitions._getSplitsFromTree(
                        self.chains[coldChainNum].curTree)
                else:
                    self.treePartitions = TreePartitions(
                        self.chains[coldChainNum].curTree)
                # After _getSplitsFromTree, need to follow, at some point,
                # with _finishSplits().  Do that when it is pickled, or at the
                # end of the run.
                # Checking and debugging constraints
                if 0 and self.constraints:
                    print("Mcmc x1d")
                    print(self.chains[coldChainNum].verifyIdentityOfTwoTreesInChain())
                    print("c checking curTree ...")
                    self.chains[coldChainNum].curTree.checkSplitKeys()
                    print("c checking propTree ...")
                    self.chains[coldChainNum].propTree.checkSplitKeys()
                    # print "c checking that all constraints are present"
                    #theSplits = [n.br.splitKey for n in self.chains[0].curTree.iterNodesNoRoot()]
                    # for sk in self.constraints.constraints:
                    #    if sk not in theSplits:
                    #        gm.append("split %i is not present in the curTree." % sk)
                    #        raise P4Error(gm)
                    print("Mcmc zzz")
                # Check that the curTree has all the constraints
                if self.constraints:
                    splitsInCurTree = [
                        n.br.splitKey for n in self.chains[coldChainNum].curTree.iterInternalsNoRoot()]
                    for sk in self.constraints.constraints:
                        if sk not in splitsInCurTree:
                            gm.append("Programming error.")
                            gm.append(
                                "The current tree (the last tree sampled) does not contain constraint")
                            gm.append(
                                "%s" % p4.func.getSplitStringFromKey(sk, self.tree.nTax))
                            raise P4Error(gm)
                # If it is a checkPointInterval, pickle
                if self.checkPointInterval and (self.gen + 1) % self.checkPointInterval == 0:
                    self.checkPoint()
                    # The stuff below needs to be done in a re-start as well.
                    # See above "if self.proposals:"
                    self.startMinusOne = self.gen
                    # Start the tree partitions over.
                    self.treePartitions = None
                    # Zero the proposal counts
                    for p in self.props.proposals:
                        p.nProposals = [0] * self.nChains
                        p.nAcceptances = [0] * self.nChains
                        #p.nTopologyChangeAttempts = [0] * self.nChains
                        #p.nTopologyChanges = [0] * self.nChains
                        p.nAborts = [0] * self.nChains
                    # Zero the swap matrix
                    if self.nChains > 1:
                        self.swapMatrix = []
                        for i in range(self.nChains):
                            self.swapMatrix.append([0] * self.nChains)
            # Reassuring pips ...
            # We want to skip the first gen of every call to run()
            if firstGen != self.gen:
                if nGensToDo <= 20000:
                    if (self.gen - firstGen) % 1000 == 0:
                        if verbose:
                            deltaTime = self._doTimeCheck(
                                nGensToDo, firstGen, 1000)
                            if deltaTime.days:
                                timeString = "%s days, %s" % (
                                    deltaTime.days, time.strftime("%H:%M:%S", time.gmtime(deltaTime.seconds)))
                            else:
                                timeString = time.strftime(
                                    "%H:%M:%S", time.gmtime(deltaTime.seconds))
                            print("%10i - %s" % (self.gen, timeString))
                        else:
                            sys.stdout.write(".")
                            sys.stdout.flush()
                    elif (self.gen - firstGen) % 100 == 0:
                        sys.stdout.write(".")
                        sys.stdout.flush()
                else:
                    if (self.gen - firstGen) % 50000 == 0:
                        if verbose:
                            deltaTime = self._doTimeCheck(
                                nGensToDo, firstGen, 50000)
                            if deltaTime.days:
                                timeString = "%s days, %s" % (
                                    deltaTime.days, time.strftime("%H:%M:%S", time.gmtime(deltaTime.seconds)))
                            else:
                                timeString = time.strftime(
                                    "%H:%M:%S", time.gmtime(deltaTime.seconds))
                            print("%10i - %s" % (self.gen, timeString))
                        else:
                            sys.stdout.write(".")
                            sys.stdout.flush()
                    elif (self.gen - firstGen) % 1000 == 0:
                        sys.stdout.write(".")
                        sys.stdout.flush()
        # Gens finished.  Clean up.
        print()
        if verbose:
            print("Finished %s generations." % nGensToDo)
        treeFile = open(self.treeFileName, 'a')
        treeFile.write('end;\n\n')
        treeFile.close() 
[docs]
    def _doTimeCheck(self, nGensToDo, firstGen, genInterval):
        """Time check 
        firstGen is the first generation of this call to Mcmc.run() else
        timing fails on restart"""
        nowTime = time.time()
        diff_secs = nowTime - self.lastTimeCheck
        total_secs = (float(nGensToDo - (self.gen - firstGen)) /
                      float(genInterval)) * float(diff_secs)
        deltaTime = datetime.timedelta(seconds=int(round(total_secs)))
        self.lastTimeCheck = nowTime
        return deltaTime 
[docs]
    def checkPoint(self):
        """Pickle a deep copy of this object to a checkpoint file.

        The file is written to the current directory, and is named
        'mcmc_checkPoint_<runNum>.<gen + 1>'.

        Some attributes do not deep-copy or pickle, so they are detached
        before the copy is made:
            - the per-chain frrf and bigTr, when stRFCalc is 'fastReducedRF'
            - logger and loggerPrinter
            - fspa, for "SPA" models when var.stmcmc_useFastSpa is set
        They are put back on this live object in a finally clause, so a
        failure does not leave it broken.  In the pickled copy they are
        None.

        If the copy has a treePartitions, its _finishSplits() is called
        before pickling, because splits are accumulated during the run
        with _getSplitsFromTree() without being finished.
        """
        # Maybe we should not save the inTrees? -- would make it more
        # lightweight.
        useFrrf = self.stRFCalc == 'fastReducedRF'
        useFastSpa = self.modelName.startswith("SPA") and var.stmcmc_useFastSpa

        # Remember everything that is about to be detached, before
        # changing anything, so that the finally clause below can always
        # put it back correctly.
        savedLogger = self.logger
        savedLoggerPrinter = self.loggerPrinter
        savedFspa = self.fspa if useFastSpa else None
        if useFrrf:
            detachChains = [self.chains[chNum] for chNum in range(self.nChains)]
        else:
            detachChains = []
        savedFrrfs = [ch.frrf for ch in detachChains]
        savedBigTrs = [ch.bigTr for ch in detachChains]
        try:
            # the Frrf object does not pickle
            for ch in detachChains:
                ch.frrf = None
                ch.bigTr = None
            # The logger does not pickle
            self.logger = None
            self.loggerPrinter = None
            # The FastSpa stuff does not pickle
            if useFastSpa:
                self.fspa = None
            # _io.TextIOWrapper objects (as returned by open(fileName)) cannot
            # be copied ("serialized"), even if they are closed.  That is not
            # handled here, as it is not really used, but it would need the
            # same save-and-restore treatment if that changes.
            theCopy = copy.deepcopy(self)
        finally:
            # The copy is independent of self, so everything can be put
            # back now, before the copy is pickled.
            for ch, frrf, bigTr in zip(detachChains, savedFrrfs, savedBigTrs):
                ch.frrf = frrf
                ch.bigTr = bigTr
            self.logger = savedLogger
            self.loggerPrinter = savedLoggerPrinter
            if useFastSpa:
                self.fspa = savedFspa

        # Setting theCopy.treePartitions = None instead would remove what
        # can be the biggest part of the pickle.
        if theCopy.treePartitions is not None:
            theCopy.treePartitions._finishSplits()

        fName = "mcmc_checkPoint_%i.%i" % (self.runNum, self.gen + 1)
        with open(fName, 'wb') as f:
            pickle.dump(theCopy, f, pickle.HIGHEST_PROTOCOL)
[docs]
    def writeProposalProbs(self):
        """(Another) Pretty-print the proposal probabilities.

        Each proposal gets one row with:
            - its index
            - how many times it was proposed
            - the percent of all proposals that this makes up
            - its acceptance percent
            - its tuning, if it has one
            - its name
        Only element 0 of each proposal's nProposals, nAcceptances and
        tuning lists is shown.  The counts are zeroed at checkPoints, so
        they may cover only the gens since the last checkPoint.

        See also STMcmc.writeProposalAcceptances().

        Raises:
            P4Error if the proposed percentages do not add up to 100.
        """
        proposals = self.props.proposals
        if not proposals:
            print("STMcmc.writeProposalProbs().  No proposals (yet?).")
            return
        nAttained = [p.nProposals[0] for p in proposals]
        nAccepted = [p.nAcceptances[0] for p in proposals]
        sumAttained = float(sum(nAttained))  # should be zero or nGen
        if not sumAttained:
            print("STMcmc.writeProposalProbs().  No proposals have been made.")
            print("Possibly, due to it being a checkPoint interval, nProposals have all been set to zero.")
            return
        probAttained = [100.0 * float(n) / sumAttained for n in nAttained]
        # The percentages should add up to 100 in either direction.  The
        # tolerance allows for float rounding when there are many
        # proposals.
        probSum = sum(probAttained)
        if math.fabs(probSum - 100.0) > 1e-9:
            raise P4Error(
                "bad sum of attained proposal probs. %s" % probSum)
        print("\nProposal probabilities (%)")
        print("For %i gens, from gens %i to %i, inclusive." % (
            (self.gen - self.startMinusOne), self.startMinusOne + 1, self.gen))
        print("%2s %11s %11s  %11s %10s %23s" % ('', 'nProposals', 'proposed(%)',
                                                 'accepted(%)', 'tuning', 'proposal'))
        for i, p in enumerate(proposals):
            print("%2i" % i, end=' ')
            print("   %7i " % p.nProposals[0], end=' ')
            print("   %5.1f    " % probAttained[i], end=' ')
            if nAttained[i]:
                print("   %5.1f   " % (100.0 * float(nAccepted[i]) / float(nAttained[i])), end=' ')
            else:
                print("       -   ", end=' ')
            # Pick the number format according to the size of the tuning.
            if p.tuning is None:
                print("      -   ", end=' ')
            elif p.tuning[0] < 2.0:
                print("  %8.4f" % p.tuning[0], end=' ')
            elif p.tuning[0] < 20.0:
                print("  %8.3f" % p.tuning[0], end=' ')
            elif p.tuning[0] < 200.0:
                print("  %8.1f" % p.tuning[0], end=' ')
            else:
                print("  %8.3g" % p.tuning[0], end=' ')
            print("%23s " % p.name, end=' ')
            print()
 
    # def writeProposalIntendedProbs(self):
    #     """Tabulate the intended proposal probabilities.
    #     """
    #     nProposals = len(self.proposals)
    #     if not nProposals:
    #         print("STMcmc.writeProposalIntendedProbs().  No proposals (yet?).")
    #         return
    #     intended = self.propWeights[:]
    #     for i in range(len(intended)):
    #         intended[i] /= self.totalPropWeights
    #     if math.fabs(sum(intended) - 1.0 > 1e-14):
    #         raise P4Error(
    #             "bad sum of intended proposal probs. %s" % sum(intended))
    #     spacer = ' ' * 4
    #     print("\nIntended proposal probabilities (%)")
    #     # print "There are %i proposals" % len(self.proposals)
    #     print("%2s %11s %23s %5s %5s" % ('', 'intended(%)', 'proposal', 'part', 'num'))
    #     for i in range(len(self.proposals)):
    #         print("%2i" % i, end=' ')
    #         p = self.proposals[i]
    #         print("   %6.2f    " % (100. * intended[i]), end=' ')
    #         print(" %20s" % p.name, end=' ')
    #         if p.pNum != -1:
    #             print(" %3i " % p.pNum, end=' ')
    #         else:
    #             print("   - ", end=' ')
    #         if p.mtNum != -1:
    #             print(" %3i " % p.mtNum, end=' ')
    #         else:
    #             print("   - ", end=' ')
    #         print()
class STMcmcCheckPointReader(object):
    """Read in and display mcmc_checkPoint files.
    Three options--
    To read in a specific checkpoint file, specify the file name by
    fName=whatever
    To read in the most recent (by os.path.getmtime()) checkpoint
    file, say last=True
    If you specify neither of the above, it will read in all the
    checkPoint files that it finds.
    Where it looks is determined by theGlob, which by default is '*',
    ie everything in the current directory.  If you want to look
    somewhere else, you can specify eg
        theGlob='SomeWhereElse/*' 
    or, if it is unambiguous, just
        theGlob='S*/*' 
    So you might say
        cpr = STMcmcCheckPointReader(theGlob='*_0.*')
    to get all the checkpoints from the first run, run 0.  Then, you
    can tell the cpr object to do various things.  Eg
        cpr.writeProposalAcceptances()
    But perhaps the most powerful thing about it is that it allows
    easy access to the checkpointed Mcmc objects, in the list mm.  Eg
    to get the first one, ask for
        m = cpr.mm[0]
    and m is an STMcmc object, complete with all its records of
    proposals and acceptances and so on.  And the TreePartitions
    object.  
    Checkpoint files are read with pickle, which can execute arbitrary
    code while loading, so only read checkpoint files that you trust.
    (Sorry!  -- Lazy documentation.  See the source code for more that it can do.)
    """
    def __init__(self, fName=None, theGlob='*', last=False, verbose=True):
        self.mm = []
        if fName:
            # get the file by name
            self.mm.append(self._loadCheckPoint(fName))
        else:
            fList = [fn for fn in glob.glob(theGlob) if
                     os.path.basename(fn).startswith("mcmc_checkPoint")]
            if not fList:
                raise P4Error("No checkpoints found in this directory.")
            if last:
                # Only the most recently modified one.  max() keeps the
                # first of any tied files, in glob order.
                mostRecentFileName = max(fList, key=os.path.getmtime)
                self.mm.append(self._loadCheckPoint(mostRecentFileName))
            else:
                # get all the files, then sort on the gen and runNum attributes
                self.mm = [self._loadCheckPoint(fn) for fn in fList]
                self.mm = p4.func.sortListOfObjectsOn2Attributes(
                    self.mm, "gen", 'runNum')
        if verbose:
            self.dump()
    @staticmethod
    def _loadCheckPoint(fName):
        """Unpickle and return the object in the file fName.

        The file is always closed, even if unpickling fails.  Note that
        pickle can execute arbitrary code, so fName should be trusted.
        """
        with open(fName, 'rb') as f:
            return pickle.load(f)
    def dump(self):
        """Print a table of the index, runNum and gen+1 of each checkpoint in self.mm."""
        print("STMcmcCheckPoints (%i checkPoints read)" % len(self.mm))
        print("%12s %12s %12s %12s" % (" ", "index", "run", "gen+1"))
        print("%12s %12s %12s %12s" % (" ", "-----", "---", "-----"))
        for i, m in enumerate(self.mm):
            print("%12s %12s %12s %12s" % (" ", i, m.runNum, m.gen + 1))
    def compareSplits(self, mNum1, mNum2, verbose=True, minimumProportion=0.1):
        """Do the TreePartitions.compareSplits() method between two checkpoints 
        The comparison is done in both directions.  If the two asdoss
        values do not agree, or if only one direction found any splits,
        that is printed.
        Args:
            mNum1, mNum2 (int): indices to STMcmc checkpoints in self
            verbose (bool): print a summary table and the results
            minimumProportion (float): passed to TreePartitions.compareSplits()
        Returns:
            a tuple of (asdoss, maximum difference, mean difference) in
            split supports, from comparing mNum1 with mNum2.  If there
            were no splits to compare, all three are None.
        """
        # Should we be only looking at splits within the 95% ci of the topologies?
        m1 = self.mm[mNum1]
        m2 = self.mm[mNum2]
        tp1 = m1.treePartitions
        tp2 = m2.treePartitions
        if verbose:
            print("\nSTMcmcCheckPointReader.compareSplits(%i,%i)" % (mNum1, mNum2))
            print("%12s %12s %12s %12s %12s" % ("mNum", "runNum", "start", "gen+1", "nTrees"))
            for i in range(5):
                print("   ---------", end=' ')
            print()
            for mNum in [mNum1, mNum2]:
                m = self.mm[mNum]
                print(" %10i " % mNum, end=' ')
                print(" %10i " % m.runNum, end=' ')
                print(" %10i " % (m.startMinusOne + 1), end=' ')
                print(" %10i " % (m.gen + 1), end=' ')
                print(" %10i " % m.treePartitions.nTrees)
        ret12 = self.compareSplitsBetweenTwoTreePartitions(
            tp1, tp2, minimumProportion, verbose=verbose)
        ret21 = self.compareSplitsBetweenTwoTreePartitions(
            tp2, tp1, minimumProportion, verbose=verbose)
        # compareSplitsBetweenTwoTreePartitions() returns None, not a
        # tuple, when there are no splits to compare.
        if ret12 is None:
            asdos, maxDiff, meanDiff = None, None, None
        else:
            asdos, maxDiff, meanDiff = ret12
        asdos2 = None if ret21 is None else ret21[0]
        if asdos is None or asdos2 is None:
            # Differs if exactly one direction found splits.
            reciprocalDiffers = (asdos is None) != (asdos2 is None)
        else:
            reciprocalDiffers = math.fabs(asdos - asdos2) > 0.000001
        if reciprocalDiffers:
            print("Reciprocal assdos differs:  %s  %s" % (asdos, asdos2))
        if asdos is None and verbose:
            print("No splits > %s" % minimumProportion)
        return asdos, maxDiff, meanDiff
    @staticmethod
    def compareSplitsBetweenTwoTreePartitions(tp1, tp2, minimumProportion, verbose=False):
        """Returns a tuple of asdoss, maximum of the differences and mean of the differences
        This calls the method TreePartitions.compareSplits(), and digests the
        results returned from that.
        Args:
            tp1, tp2 (TreePartition): TreePartition objects
            minimumProportion (float): passed to TreePartitions.compareSplits()
        
        Returns:
            (asdoss, maxOfDiffs, meanOfDiffs), or None (not a tuple) if
            TreePartitions.compareSplits() returned no splits.
        """
        ret = tp1.compareSplits(tp2, minimumProportion=minimumProportion)
        # ret is a list of 3-item lists:
        #  1. The split key
        #  2. The split string
        #  3. A list of the 2 supports
        if not ret:
            return None
        sumOfStdDevs = 0.0
        nSplits = len(ret)
        diffs = []
        for i in ret:
            stdDev = math.sqrt(p4.func.variance(i[2]))
            sumOfStdDevs += stdDev
            diffs.append(math.fabs(i[2][0] - i[2][1]))
        asdoss = sumOfStdDevs / nSplits
        maxOfDiffs = max(diffs)
        meanOfDiffs = sum(diffs) / nSplits
        if verbose:
            print("     nSplits=%i, average of std devs of split supports %.4f " % (nSplits, asdoss))
            print("     max of differences %f, mean of differences %f" % (maxOfDiffs, meanOfDiffs))
        return (asdoss, maxOfDiffs, meanOfDiffs)
    def compareSplitsAll(self, precision=3, linewidth=120):
        """Do the compareSplits() method between all pairs
        Output is verbose.  Shows 
        - average standard deviation of split frequencies (or supports), like MrBayes
        - maximum difference between split supports from each pair of checkpoints, like PhyloBayes
        Pairs with no splits to compare are shown as 0.0.  With fewer
        than 2 checkpoints there are no pairs, and only a message is
        printed.
        Args:
            precision, linewidth (int): numpy print options used only
                while printing the matrices.
        Returns:
            None
        """
        nM = len(self.mm)
        if nM < 2:
            print("STMcmcCheckPointReader.compareSplitsAll().  Need at least 2 checkPoints, got %i." % nM)
            return
        nItems = (nM * (nM - 1)) // 2
        asdosses = np.zeros((nM, nM), dtype=np.float64)
        vect = np.zeros(nItems, dtype=np.float64)
        maxDiffs = np.zeros((nM, nM), dtype=np.float64)
        vCounter = 0
        for mNum1 in range(1, nM):
            for mNum2 in range(mNum1):
                thisAsdoss, thisMaxDiff, _ = self.compareSplits(mNum1, mNum2, verbose=False)
                # No splits to compare: show as zero, not None, which
                # does not belong in a float array.
                if thisAsdoss is None:
                    thisAsdoss = 0.0
                if thisMaxDiff is None:
                    thisMaxDiff = 0.0
                asdosses[mNum1][mNum2] = thisAsdoss
                asdosses[mNum2][mNum1] = thisAsdoss
                vect[vCounter] = thisAsdoss
                vCounter += 1
                maxDiffs[mNum1][mNum2] = thisMaxDiff
                maxDiffs[mNum2][mNum1] = thisMaxDiff
        # np.printoptions restores the previous print options on exit,
        # even if printing fails.
        with np.printoptions(precision=precision, linewidth=linewidth):
            print("Pairwise asdoss values ---")
            print(asdosses)
            print()
            print("For the %i values in one triangle," % nItems)
            print("max =  ", vect.max())
            print("min =  ", vect.min())
            print("mean = ", vect.mean())
            print("var =  ", vect.var())
            print()
            print("Pairwise maximum differences in split supports between the two runs ---")
            print(maxDiffs)
    def writeProposalAcceptances(self):
        """Call writeProposalAcceptances() on each checkpoint in self.mm."""
        for m in self.mm:
            m.writeProposalAcceptances()
    def writeSwapMatrices(self):
        """Call writeSwapMatrix() on each checkpoint in self.mm that has more than 1 chain."""
        for m in self.mm:
            if m.nChains > 1:
                m.writeSwapMatrix()
    def writeProposalProbs(self):
        """Call writeProposalProbs() on each checkpoint in self.mm."""
        for m in self.mm:
            m.writeProposalProbs()
class QpaML(object):
    """Uses STMcmc to do likelihood calcs and Q optimization.

    A single-chain STMcmc with modelName 'QPA' is built once.  A supertree
    installed with setSuperTree() can then have its QPA log likelihood
    evaluated for different spaQ values (calcP), and spaQ optimized with
    scipy's Nelder-Mead minimizer (optimizeQ).
    """

    def __init__(self, inTrees, bigT):
        """Build the one-chain STMcmc used for the likelihood calculations.

        Args:
            inTrees: non-empty sequence of input trees.  Each is duplicated,
                so the caller's trees are left untouched.
            bigT: a tree with taxNames set.  A duplicate is passed to STMcmc.

        Raises:
            P4Error: if inTrees is empty, bigT is missing, or bigT has no
                taxNames.
        """
        # Explicit checks rather than asserts, so that validation of the
        # caller's arguments is not stripped when running under python -O.
        if not inTrees:
            raise P4Error('QpaML: inTrees is empty')
        ttDupes = [t.dupe() for t in inTrees]

        if not bigT:
            raise P4Error('QpaML: bigT is required')
        bigTDupe = bigT.dupe()
        if not bigTDupe.taxNames:
            raise P4Error('The bigT needs taxNames')
        stm = STMcmc(ttDupes, bigT=bigTDupe, modelName='QPA',
                     beta=1.0, spaQ=0.5, stRFCalc='purePython1',
                     nChains=1, runNum=0, sampleInterval=100,
                     checkPointInterval=None, useSplitSupport=False,
                     verbose=False, checkForOutputFiles=False)
        self.ch = STChain(stm, 0)

    def setSuperTree(self, st):
        """Install a copy of st as the chain's propTree, with QPA quartets.

        The copy is given:
          - taxNames, taken from the current propTree if st has none;
          - taxBits, one single-bit mask per taxon (bit i for taxon i);
          - stSplitKey on every non-root node (OR of the leaf bits below);
          - skk, the split keys of the internal non-root nodes;
          - qSet, the set of quartets implied by those splits, each stored
            as a canonical 4-tuple of taxon bits;
          - nQuartets = len(qSet).

        Raises:
            P4Error: if st has taxNames that differ from the chain's.
        """
        assert self.ch.propTree.taxNames
        st = st.dupe()
        if st.taxNames:
            if st.taxNames != self.ch.propTree.taxNames:
                raise P4Error(
                    'QpaML.setSuperTree: st.taxNames differ from the taxNames of the chain')
        else:
            st.taxNames = self.ch.propTree.taxNames
        st.setPreAndPostOrder()
        st.taxBits = [1 << i for i in range(st.nTax)]

        # Post-order visits children before parents, so each internal node
        # can OR together its children's keys.  The root is visited last
        # and gets no key.
        for n in st.iterPostOrder():
            if n == st.root:
                break
            if n.isLeaf:
                n.stSplitKey = 1 << st.taxNames.index(n.name)
            else:
                n.stSplitKey = n.leftChild.stSplitKey
                p = n.leftChild.sibling
                while p:
                    n.stSplitKey |= p.stSplitKey    # "or", in-place
                    p = p.sibling
        st.skk = [n.stSplitKey for n in st.iterInternalsNoRoot()]

        st.qSet = set()
        for sk in st.skk:
            ups = [txBit for txBit in st.taxBits if (sk & txBit)]
            downs = [txBit for txBit in st.taxBits if not (sk & txBit)]
            # taxBits is increasing and combinations() preserves input order,
            # so every pair below is already sorted.  Putting the pair with
            # the smaller first bit first gives one canonical tuple per
            # quartet, so a quartet implied by several splits is stored once.
            for down in itertools.combinations(downs, 2):
                for up in itertools.combinations(ups, 2):
                    if down[0] < up[0]:
                        st.qSet.add(down + up)
                    else:
                        st.qSet.add(up + down)
        st.nQuartets = len(st.qSet)
        self.ch.propTree = st

    def calcP(self, Q):
        """Return the negative QPA log likelihood of propTree at spaQ = Q.

        This is the objective minimized by optimizeQ.  A Q outside the open
        interval (0, 1) returns a large penalty (10000000.) instead.  A NaN Q
        also returns the penalty, because the chained comparison is False
        for NaN.
        """
        if not (0. < Q < 1.):
            return 10000000.
        self.ch.propTree.spaQ = Q
        self.ch.getTreeLogLike_qpa_slow()
        return -self.ch.propTree.logLike

    def optimizeQ(self, x0=0.3):
        """Maximize the likelihood over spaQ, starting from x0.

        Uses scipy.optimize.minimize with method='Nelder-Mead' on calcP.

        Returns:
            (res.x, res.fun): the optimizing Q, in the array form scipy
            returns, and the minimized negative log likelihood.
        """
        res = minimize(self.calcP, x0, method='Nelder-Mead')
        return (res.x, res.fun)
class SpaML(object):
    """Using STMcmc.

    A single-chain STMcmc with modelName 'SPA' is built once.  A supertree
    installed with setSuperTree() can then have its SPA log likelihood
    evaluated (via the bitarray code path) for different spaQ values
    (calcP), and spaQ optimized with scipy's Nelder-Mead minimizer
    (optimizeQ).
    """

    def __init__(self, inTrees, bigT):
        """Build the one-chain STMcmc used for the likelihood calculations.

        Args:
            inTrees: non-empty sequence of input trees.  Each is duplicated,
                so the caller's trees are left untouched.
            bigT: a tree with taxNames set.  A duplicate is passed to STMcmc.

        Raises:
            P4Error: if inTrees is empty, bigT is missing, or bigT has no
                taxNames.
        """
        # Explicit checks rather than asserts, so that validation of the
        # caller's arguments is not stripped when running under python -O.
        if not inTrees:
            raise P4Error('SpaML: inTrees is empty')
        ttDupes = [t.dupe() for t in inTrees]

        if not bigT:
            raise P4Error('SpaML: bigT is required')
        bigTDupe = bigT.dupe()
        if not bigTDupe.taxNames:
            raise P4Error('The bigT needs taxNames')
        stm = STMcmc(ttDupes, bigT=bigTDupe, modelName='SPA',
                     beta=1.0, spaQ=0.5, stRFCalc='purePython1',
                     nChains=1, runNum=0, sampleInterval=100,
                     checkPointInterval=None, useSplitSupport=False,
                     verbose=False, checkForOutputFiles=False)
        self.ch = STChain(stm, 0)

    def setSuperTree(self, st):
        """Install a copy of st as the chain's propTree.

        The copy takes its taxNames from the current propTree if it has
        none.  After it is installed, the chain's bitarray calculations are
        set up by calling setupBitarrayCalcs().

        Raises:
            P4Error: if st has taxNames that differ from the chain's.
        """
        assert self.ch.propTree.taxNames
        st = st.dupe()
        if st.taxNames:
            if st.taxNames != self.ch.propTree.taxNames:
                raise P4Error(
                    'SpaML.setSuperTree: st.taxNames differ from the taxNames of the chain')
        else:
            st.taxNames = self.ch.propTree.taxNames
        st.setPreAndPostOrder()
        self.ch.propTree = st
        self.ch.setupBitarrayCalcs()

    def calcP(self, Q):
        """Return the negative SPA log likelihood of propTree at spaQ = Q.

        This is the objective minimized by optimizeQ.  A Q outside the open
        interval (0, 1) returns a large penalty (10000000.) instead.  A NaN Q
        also returns the penalty, because the chained comparison is False
        for NaN.  A non-ndarray Q is wrapped as a one-element numpy array
        before being stored in propTree.spaQ.
        """
        if not (0. < Q < 1.):
            return 10000000.
        if not isinstance(Q, np.ndarray):
            Q = np.array([Q])
        self.ch.propTree.spaQ = Q
        self.ch.getTreeLogLike_spa_bitarray()
        return -self.ch.propTree.logLike

    def optimizeQ(self, x0=0.3):
        """Maximize the likelihood over spaQ, starting from x0.

        Uses scipy.optimize.minimize with method='Nelder-Mead' on calcP.

        Returns:
            (res.x, res.fun): the optimizing Q, in the array form scipy
            returns, and the minimized negative log likelihood.
        """
        res = minimize(self.calcP, x0, method='Nelder-Mead')
        return (res.x, res.fun)
class SR2008ML(object):
    """Using STMcmc, with SR2008_rf_aZ_fb.

    A single-chain STMcmc with modelName 'SR2008_rf_aZ_fb' is built once.  A
    supertree installed with setSuperTree() can then have its log likelihood
    evaluated (pure-python calculation) for different beta values (calcP),
    and beta optimized with scipy's Nelder-Mead minimizer (optimizeBeta).
    """

    def __init__(self, inTrees, bigT):
        """Build the one-chain STMcmc used for the likelihood calculations.

        Args:
            inTrees: non-empty sequence of input trees.  Each is duplicated,
                so the caller's trees are left untouched.
            bigT: a tree with taxNames set.  A duplicate is passed to STMcmc.

        Raises:
            P4Error: if inTrees is empty, bigT is missing, or bigT has no
                taxNames.
        """
        # Explicit checks rather than asserts, so that validation of the
        # caller's arguments is not stripped when running under python -O.
        if not inTrees:
            raise P4Error('SR2008ML: inTrees is empty')
        ttDupes = [t.dupe() for t in inTrees]

        if not bigT:
            raise P4Error('SR2008ML: bigT is required')
        bigTDupe = bigT.dupe()
        if not bigTDupe.taxNames:
            raise P4Error('SR2008ML: The bigT needs taxNames')
        stm = STMcmc(ttDupes, bigT=bigTDupe, modelName='SR2008_rf_aZ_fb',
                     beta=1.0, spaQ=0.5,
                     stRFCalc='purePython1',
                     nChains=1, runNum=0, sampleInterval=100,
                     checkPointInterval=None, useSplitSupport=False,
                     verbose=False, checkForOutputFiles=False)
        self.ch = STChain(stm, 0)

    def setSuperTree(self, st):
        """Install a copy of st as the chain's propTree.

        The copy takes its taxNames from the current propTree if it has
        none.

        Raises:
            P4Error: if st has taxNames that differ from the chain's.
        """
        assert self.ch.propTree.taxNames
        st = st.dupe()
        if st.taxNames:
            if st.taxNames != self.ch.propTree.taxNames:
                raise P4Error(
                    'SR2008ML.setSuperTree: st.taxNames differ from the taxNames of the chain')
        else:
            st.taxNames = self.ch.propTree.taxNames
        st.setPreAndPostOrder()
        self.ch.propTree = st

    def calcP(self, beta):
        """Return the negative log likelihood of propTree at the given beta.

        This is the objective minimized by optimizeBeta.  A beta outside the
        open interval (1e-10, 1e+10) returns a large penalty (10000000.)
        instead.  A NaN beta also returns the penalty, because the chained
        comparison is False for NaN.
        """
        myMIN = 1.e-10
        myMAX = 1.e+10
        if not (myMIN < beta < myMAX):
            return 10000000.
        self.ch.propTree.beta = beta
        self.ch.getTreeLogLike_ppy1()  # pure python
        return -self.ch.propTree.logLike

    def optimizeBeta(self, x0=1.0):
        """Maximize the likelihood over beta, starting from x0.

        Uses scipy.optimize.minimize with method='Nelder-Mead' on calcP.

        Returns:
            (res.x, res.fun): the optimizing beta, in the array form scipy
            returns, and the minimized negative log likelihood.
        """
        res = minimize(self.calcP, x0, method='Nelder-Mead')
        return (res.x, res.fun)