# This is STMcmc, for super tree mcmc.
# Started 18 March 2011, first commit 22 March 2011.
import p4.pf as pf
import p4.func
from p4.var import var
import math
import random
import string
import sys
import time
import copy
import os
import pickle
import glob
import numpy as np
from p4.p4exceptions import P4Error
from p4.treepartitions import TreePartitions
from p4.constraints import Constraints
from p4.tree import Tree
import datetime
import itertools
from scipy.optimize import minimize
import logging
import bitarray
def choose(n, k):
    """
    A fast way to calculate binomial coefficients
    by Andrew Dalke (contrib).

    Return the binomial coefficient C(n, k) as an int, or 0 when k is
    outside the range 0..n (this includes negative k).
    """
    if 0 <= k <= n:
        ntok = 1
        ktok = 1
        # Python 3 ``range`` (``xrange`` does not exist in Python 3).
        for t in range(1, min(k, n - k) + 1):
            ntok *= n
            ktok *= t
            n -= 1
        return ntok // ktok
    else:
        return 0
# def nSplits(n):
# mySum = 0
# for k in range(2, n-1):
# mySum += choose(n-1, k)
# return mySum
def bForN(n):
    """Return log(B(n)), the log of the product of (2k - 5) for k = 4..n.

    For n <= 3 the product is empty and 0.0 is returned.  Working in log
    space keeps the value finite; the difference from the non-log version
    (in log(result)) is about 2.5e-10 at most for n up to 10000.
    """
    total = 0.0
    if n <= 3:
        return total
    # The factors 2k - 5 for k = 4..n are the odd numbers 3, 5, ..., 2n - 5,
    # summed in the same ascending order.
    for factor in range(3, 2 * n - 4, 2):
        total += math.log(factor)
    return total
def BS2009_Eqn30_ZTApprox(n, beta, cT):
    """Return the log of the approximate Z_T (BS2009 Eqn 30, per the name).

    n is the number of taxa in an input tree, beta the distance scaling
    and cT the number of cherries (callers pass nCherries).  This log
    version differs from the non-log version (in log(result)) by at most
    6.82e-13 for n up to 150, over beta 0.001 -- 1000 and cT 2 -- n/2.

    Raises ZeroDivisionError when cT is 0 (lambda is then 0).
    """
    myLambda = cT / (2.0 * n)
    threshold = 0.5 * math.log((n - 3.) / myLambda)
    epsilon = math.exp(-2. * beta)
    bigANEpsilon = (1
                    + ((2. * n) - 3.) * epsilon
                    + 2. * ((n * n) - (4. * n) - 6.) * epsilon * epsilon)
    termA = math.log(bigANEpsilon + 6 * cT * epsilon * epsilon)
    # For large beta only termA matters.
    if not beta < threshold:
        return termA
    termB = -(2. * beta) * (n - 3.) + (myLambda * (math.exp(2. * beta) - 1.))
    termB += bForN(n)
    # The larger of the two terms, preferring termB on a tie.
    return termA if termA > termB else termB
def popcountA(k, nBits):
    """Return how many of the lowest nBits bits of k are set.

    The scan stops as soon as the probe bit exceeds k, so a negative k
    gives 0.
    """
    ones = 0
    probe = 1
    for _ in range(nBits):
        if probe > k:
            break
        if probe & k:
            ones += 1
        probe <<= 1
    return ones
def bitReduce(bk, txBits, lLen, sLen, allOnes):
    """Project split key bk onto the taxa selected by txBits.

    The lowest lLen bit positions are scanned.  For each position set in
    txBits, the corresponding bit of bk is copied into the next position of
    a compacted (short) key.  If the compacted key has its lowest bit set,
    it is xor-ed with allOnes and the count is taken as sLen minus the
    count.

    Returns (reduced key, number of set bits in it).
    Example: bitReduce(6, 30, 5, 4, 15) == (12, 2).
    """
    reduced = 0
    nOnes = 0
    outPos = 0
    for inPos in range(lLen):
        mask = 1 << inPos
        if not mask & txBits:
            continue
        if mask & bk:
            reduced += 1 << outPos
            nOnes += 1
        outPos += 1
    # Normalize: keys are kept with the lowest bit clear.
    if reduced & 1:
        reduced ^= allOnes
        nOnes = sLen - nOnes
    return reduced, nOnes
# Disabled manual check of bitReduce(); change "if 0" to "if 1" to run it
# when this module is imported.  The comment lines at the end record the
# output that was expected.
if 0: # test bitReduce
    sk = 6 # always at least 2 bits, even
    txBits = 30
    lLen = 5
    sLen = 4
    allOnes = 15
    print(" sk: %3i %s" % (sk, p4.func.getSplitStringFromKey(sk, lLen)))
    print("taxBits: %3i %s" % (txBits, p4.func.getSplitStringFromKey(txBits, lLen)))
    rsk, popcount = bitReduce(sk, txBits, lLen, sLen, allOnes)
    print(" rsk: %3i %s" % (rsk, p4.func.getSplitStringFromKey(rsk, sLen)))
    print(" popcount %i" % popcount)
    # Expected output:
    # sk: 6 .**..
    # taxBits: 30 .****
    # rsk: 12 ..**
    # popcount 2
def maskedSymmetricDifference(skk, skSet, taxBits, longLen, shortLen, allOnes):
    """Return (reduced RF distance, nCherries) of a supertree vs. one input tree.

    skk: integer split keys of the supertree, over longLen taxa.
    skSet: split keys of the input tree, over its shortLen taxa.
    taxBits: longLen-bit mask of the supertree taxa present in the input tree.
    allOnes: passed to bitReduce() for normalization (presumably
        (1 << shortLen) - 1 -- confirm against callers).

    Each supertree split is reduced with bitReduce().  Reduced splits
    with fewer than two taxa on either side are dropped, and duplicates
    are merged.  The distance is the size of the symmetric difference
    between the remaining reduced splits and skSet.

    nCherries counts remaining reduced splits with exactly two taxa on a
    side.  A split is counted twice when both sides have two taxa.
    """
    reducedKeys = set()
    for fullKey in skk:
        reducedKey, ones = bitReduce(fullKey, taxBits, longLen, shortLen, allOnes)
        if 1 < ones < shortLen - 1:
            reducedKeys.add(reducedKey)
    distance = len(reducedKeys.symmetric_difference(skSet))
    cherries = 0
    for reducedKey in reducedKeys:
        ones = popcountA(reducedKey, shortLen)
        # Deliberately two independent tests; both can hold at once.
        if ones == 2:
            cherries += 1
        if ones == shortLen - 2:
            cherries += 1
    return distance, cherries
def slowQuartetDistance(st, inputTree):
    """Return the 'scqdist' topology distance from a pruned supertree to inputTree.

    Works on st.dupe() rather than st.  Every leaf of the duplicate whose
    name is not in inputTree.taxNames is removed, so both trees have the
    same taxa.  The result is then the duplicate's
    topologyDistance(inputTree, metric='scqdist').
    """
    pruned = st.dupe()
    extraLeaves = [leaf for leaf in pruned.iterLeavesNoRoot()
                   if leaf.name not in inputTree.taxNames]
    for leaf in extraLeaves:
        pruned.removeNode(leaf)
    return pruned.topologyDistance(inputTree, metric='scqdist')
class STChain(object):
def __init__(self, aSTMcmc, chNum):
    """Set up one chain of a supertree MCMC run.

    Copies the starting supertree aSTMcmc.tree into self.curTree and
    self.propTree (randomized for chains other than 0, see below).  Then
    computes the starting log likelihood on propTree for the model named
    by aSTMcmc.modelName ('SR2008_rf*', 'SPA*' or 'QPA*') and copies it
    to curTree.logLike.

    Args:
        aSTMcmc: the owning STMcmc object.  Read here: tree, modelName,
            stRFCalc, beta, spaQ and, with fast SPA, fspa.
        chNum: chain number.  Chain 0 always starts from aSTMcmc.tree.
            Other chains start from a randomized topology with branch
            lengths stripped, unless var.mcmc_sameBigTToStartOnAllChains
            is set.

    Raises:
        P4Error: for an unknown modelName, or when the fast SPA
            likelihood differs from the bitarray one by more than 1e-13.
    """
    gm = ['STChain.__init__()']
    self.stMcmc = aSTMcmc
    self.chNum = chNum # Does not change. Used with fastspa only
    self.tempNum = chNum # 'temp'erature, not 'temp'orary; changes when swapped.
    if chNum == 0:
        self.curTree = aSTMcmc.tree.dupe()
        self.propTree = aSTMcmc.tree.dupe()
    else: # heated chains are randomized, unless var.mcmc_sameBigTToStartOnAllChains is set.
        if var.mcmc_sameBigTToStartOnAllChains: # False by default
            self.curTree = aSTMcmc.tree.dupe()
            self.propTree = aSTMcmc.tree.dupe()
        else:
            # Random topology without branch lengths.  propTree is a
            # separate copy of the same random tree.
            rTree = aSTMcmc.tree.dupe()
            rTree.randomizeTopology(randomBrLens = False)
            rTree.stripBrLens()
            rTree.setPreAndPostOrder()
            self.curTree = rTree
            self.propTree = rTree.dupe()
    self.logProposalRatio = 0.0
    self.logPriorRatio = 0.0
    self.frrf = None
    #self.nInTreeSplits = 0
    # Each model branch computes the starting likelihood on propTree,
    # then copies it to curTree.
    if self.stMcmc.modelName.startswith('SR2008_rf'):
        self.curTree.beta = self.stMcmc.beta
        self.propTree.beta = self.stMcmc.beta
        # stRFCalc selects the Robinson-Foulds implementation.
        # NOTE(review): any other stRFCalc value computes no likelihood
        # here, so propTree.logLike is whatever it already was -- confirm
        # stRFCalc is validated elsewhere.
        if self.stMcmc.stRFCalc == 'purePython1':
            self.getTreeLogLike_ppy1()
        elif self.stMcmc.stRFCalc == 'fastReducedRF':
            self.startFrrf()
            self.getTreeLogLike_fastReducedRF()
        elif self.stMcmc.stRFCalc == 'bitarray':
            self.setupBitarrayCalcs()
            self.getTreeLogLike_bitarray()
        self.curTree.logLike = self.propTree.logLike
    elif self.stMcmc.modelName.startswith('SPA'):
        # spaQ is held as a 1-element numpy array for the SPA model.
        self.curTree.spaQ = np.array([self.stMcmc.spaQ])
        self.propTree.spaQ = np.array([self.stMcmc.spaQ])
        # for t in self.stMcmc.trees:
        # self.nInTreeSplits += len(t.splSet)
        # print "Got nInTreeSplits %s" % self.nInTreeSplits
        self.setupBitarrayCalcs()
        self.getTreeLogLike_spa_bitarray()
        if var.stmcmc_useFastSpa:
            # Cross-check the fast SPA calculator against the bitarray
            # likelihood just computed.
            #print("Here E. bitarray propTree.logLike is %f" % self.propTree.logLike)
            fspaLike = self.stMcmc.fspa.calcLogLike(self.chNum)
            diff = math.fabs(self.propTree.logLike - fspaLike)
            #print("Got fspaLike %f, diff %g" % (fspaLike, diff))
            if diff > 1e-13:
                gm.append("bad fastspa likelihood calc, %f vs %f, diff %f" % (self.propTree.logLike, fspaLike, diff))
                raise P4Error(gm)
        self.curTree.logLike = self.propTree.logLike
    elif self.stMcmc.modelName.startswith('QPA'):
        # For QPA, spaQ is a plain scalar.
        self.curTree.spaQ = self.stMcmc.spaQ
        self.propTree.spaQ = self.stMcmc.spaQ
        # 3 possible quartet topologies for each set of 4 taxa.
        self.nPossibleQuartets = choose(self.stMcmc.tree.nTax, 4) * 3
        self.getTreeLogLike_qpa_slow()
        self.curTree.logLike = self.propTree.logLike
    else:
        gm.append('Unknown modelName %s' % self.stMcmc.modelName)
        raise P4Error(gm)
    if 0:
        print("STChain init()")
        self.curTree.draw()
        print("logLike is %f" % self.curTree.logLike)
def getTreeLogLike_qpa_slow(self):
    """Compute self.propTree.logLike under the QPA (quartet) model, in pure Python.

    The supertree's quartets come from its internal splits.  For each
    split, every pair of taxa on one side is combined with every pair on
    the other side.  Each quartet is stored as a 4-tuple of tax bits,
    with each pair ordered and the pair with the smaller first member
    placed first.

    Every quartet of every input tree (it.qSet) then scores log(q) if the
    supertree has it and log(r) otherwise, where
        q = spaQ / nQuartets
        r = (1 - spaQ) / (nPossibleQuartets - nQuartets)
    and r = 1 / nPossibleQuartets when the supertree has no quartets.

    Leaf stSplitKey values must already be set.  Side effects on
    self.propTree: stSplitKey on internal non-root nodes, skk, qSet,
    nQuartets and logLike.

    Raises:
        P4Error: if propTree.spaQ is not in (0, 1].
    """
    gm = ["STChain.getTreeLogLike_qpa_slow()"]
    st = self.propTree
    if st.spaQ > 1. or st.spaQ <= 0.0:
        gm.append("bad propTree.spaQ value %f" % st.spaQ)
        raise P4Error(gm)

    # Internal split keys, bottom-up: the OR of the children's keys.
    for node in st.iterInternalsPostOrder():
        if node == st.root:
            break
        key = node.leftChild.stSplitKey
        kid = node.leftChild.sibling
        while kid:
            key |= kid.stSplitKey
            kid = kid.sibling
        node.stSplitKey = key
    st.skk = [node.stSplitKey for node in st.iterInternalsNoRoot()]

    def orderedPairs(bits):
        # Every 2-combination of bits, smaller member first.
        return [pr if not pr[0] > pr[1] else (pr[1], pr[0])
                for pr in itertools.combinations(bits, 2)]

    st.qSet = set()
    for sk in st.skk:
        onSide = [tb for tb in st.taxBits if (sk & tb)]
        offSide = [tb for tb in st.taxBits if not (sk & tb)]
        onPairs = orderedPairs(onSide)
        for offPair in orderedPairs(offSide):
            for onPair in onPairs:
                if offPair[0] < onPair[0]:
                    st.qSet.add(offPair + onPair)
                else:
                    st.qSet.add(onPair + offPair)
    st.nQuartets = len(st.qSet)

    if not st.nQuartets:
        r = 1. / self.nPossibleQuartets
    else:
        q = st.spaQ / st.nQuartets
        r = (1. - st.spaQ) / (self.nPossibleQuartets - st.nQuartets)
        logq = math.log(q)
    logr = math.log(r)

    st.logLike = 0.0
    for it in self.stMcmc.trees:
        for quartet in it.qSet:
            # logq is only needed (and only defined) when st.qSet is non-empty.
            st.logLike += logq if quartet in st.qSet else logr
def getTreeLogLike_spa_bitarray(self):
    """Compute self.propTree.logLike under the SPA (split) model using bitarrays.

    For each input tree `it` in self.stMcmc.trees, each internal split of
    the supertree is first oriented so that it.firstTax is on the '1'
    side, then masked to it's taxa (it.baTaxBits).  Masked splits with at
    least two taxa on each side are de-duplicated; their count is S_st.

    S = 2**(nTax - 1) - (nTax + 1) is the number of possible splits for
    an it-sized tree.  With spaQ = propTree.spaQ[0]:
        q = spaQ / S_st
        r = (1 - spaQ) / (S - S_st), or 1 / S when S_st == 0.

    Scoring of each internal split of `it`:
    * found among the masked supertree splits: log(q), or
      log(r + support * (q - r)) when self.stMcmc.useSplitSupport is set
      and the split has a support value;
    * otherwise: log(r).

    n.ss on the supertree's internal nodes is set by setupBitarrayCalcs()
    or refreshBitarrayPropTree().  Side effects: sets theSpl,
    maskedSplitWithTheFirstTaxOne, onesCount and bytes on those n.ss
    objects, and self.propTree.logLike.

    Raises:
        P4Error: if spaQ is not in (0, 1], or (slowCheck only) if the
            slow and bitarray calculations disagree.
    """
    gm = ["STChain.getTreeLogLike_spa_bitarray"]
    if self.propTree.spaQ[0] > 1. or self.propTree.spaQ[0] <= 0.0:
        # spaQ is a 1-element numpy array; format the element, not the
        # array (formatting an ndim>0 array with %f is deprecated by numpy).
        gm.append("bad propTree.spaQ value %f" % self.propTree.spaQ[0])
        raise P4Error(gm)
    # Debug switch: recompute the likelihood the slow way and compare.
    slowCheck = False
    if slowCheck:
        print("\n", "-" * 30)
        print("Super tree: ", end=' ')
        self.propTree.write()
        slowCheckLogLike = 0.0
        for it in self.stMcmc.trees:
            it.makeSplitKeys()
            it.inbb = [n.br for n in it.iterInternalsNoRoot()]
    self.propTree.logLike = 0.0
    for it in self.stMcmc.trees:
        if 0:
            print("-" * 50)
            it.draw()
            print("baTaxBits %s" % it.baTaxBits)
            print("firstTax at %i" % it.firstTax)
        if 0:
            print(" input tree: ", end=' ')
            it.write()
        if slowCheck:
            stDupe = self.propTree.dupe()
            toRemove = []
            for n in stDupe.iterLeavesNoRoot():
                if n.name not in it.taxNames:
                    toRemove.append(n)
            for n in toRemove:
                stDupe.removeNode(n)
            stDupe.taxNames = it.taxNames
            stDupe.makeSplitKeys(makeNodeForSplitKeyDict=True)
        # No need to consider (masked) splits with less than two
        # 1s or more than it.nTax - 2 1s.
        upperGood = it.nTax - 2
        relevantStSplits = []
        for n in self.propTree.iterInternalsNoRoot():
            # Choose spl or its complement spl2 so that it.firstTax is a 1.
            if n.ss.spl[it.firstTax]:
                n.ss.theSpl = n.ss.spl
            else:
                n.ss.theSpl = n.ss.spl2
            n.ss.maskedSplitWithTheFirstTaxOne = n.ss.theSpl & it.baTaxBits
            n.ss.onesCount = n.ss.maskedSplitWithTheFirstTaxOne.count()
            if 0:
                print("bigT node %i" % n.nodeNum)
                print(" theSpl is %s" % n.ss.theSpl)
                print(" maskedSplitWithTheFirstTaxOne %s" % n.ss.maskedSplitWithTheFirstTaxOne)
                print(" onesCount %i" % n.ss.onesCount)
                if n.ss.onesCount >= 2 and n.ss.onesCount <= upperGood:
                    print(" -> relevant")
                else:
                    print(" -> not relevant")
            if n.ss.onesCount >= 2 and n.ss.onesCount <= upperGood:
                relevantStSplits.append(n.ss)
        # Different supertree splits can mask to the same split; key by bytes.
        nonRedundantStSplitDict = {}
        for ss in relevantStSplits:
            ss.bytes = ss.maskedSplitWithTheFirstTaxOne.tobytes()
            nonRedundantStSplitDict[ss.bytes] = ss
        if 0:
            for ss in relevantStSplits:
                ss.dump()
            print("There are %i relevant splits in the st for this it." % len(relevantStSplits))
            # The dict keys are bytes; the split objects are the values.
            for ss in nonRedundantStSplitDict.values():
                ss.dump()
            print("There are %i non-redundant splits in the st for this it." % len(nonRedundantStSplitDict))
        if 0:
            if len(relevantStSplits) != len(nonRedundantStSplitDict):
                print("Gen %12i: Got %i relevantStSplits; %i in nonRedundantStSplitDict" % (
                    self.stMcmc.gen, len(relevantStSplits), len(nonRedundantStSplitDict)))
        # S_st is the number of splits in the reduced supertree
        S_st = len(nonRedundantStSplitDict)
        if slowCheck:
            slowCheckS_st = len([n for n in stDupe.iterInternalsNoRoot()])
            assert S_st == slowCheckS_st
        # S is the number of possible splits in an it-sized tree
        S = 2 ** (it.nTax - 1) - (it.nTax + 1)
        if S_st:
            q = self.propTree.spaQ[0] / S_st
            R = 1. - self.propTree.spaQ[0]
            r = R / (S - S_st)
            logq = math.log(q)
        else:
            R = 1.
            r = R / S
        logr = math.log(r)
        for n in it.internals:
            ret = nonRedundantStSplitDict.get(n.stSplitKeyBytes)
            if ret:
                # ret being found implies S_st > 0, so q is defined.
                if self.stMcmc.useSplitSupport and n.br.support is not None:
                    thisSplitLike = math.log(r + (n.br.support * (q - r)))
                    self.propTree.logLike += thisSplitLike
                else:
                    self.propTree.logLike += logq
            else:
                # If we are here when S_st is zero, then q is undefined.
                # So fall into the else clause
                if 0:
                    # This has the dodgy assumption that 1-support is in the supertree.
                    # Might be zero support for a split in the supertree.
                    if self.stMcmc.useSplitSupport and S_st and n.br.support is not None:
                        self.propTree.logLike += math.log(
                            r + ((1. - n.br.support) * (q - r)))
                    else:
                        self.propTree.logLike += logr
                else:
                    # This does not make the assumption above. Safer.
                    self.propTree.logLike += logr
        if slowCheck:
            for inb in it.inbb:
                splitString = p4.func.getSplitStringFromKey(
                    inb.splitKey, it.nTax)
                print(" %s " % splitString, end=' ')
                ret = stDupe.nodeForSplitKeyDict.get(inb.splitKey)
                # Here we need to check that S_st is not zero. If it is,
                # then q is undefined.
                if self.stMcmc.useSplitSupport and inb.support is not None and S_st:
                    if ret:
                        thisLogQc = math.log(r + (inb.support * (q - r)))
                        print("qc %.3f" % thisLogQc)
                        slowCheckLogLike += thisLogQc
                    else:
                        thisLogRc = math.log(
                            r + ((1. - inb.support) * (q - r)))
                        print("rc %.3f" % thisLogRc)
                        slowCheckLogLike += thisLogRc
                else:
                    if ret:
                        print("q %.3f" % q)
                        slowCheckLogLike += logq
                    else:
                        print("r %.3f" % r)
                        slowCheckLogLike += logr
    if slowCheck:
        myDiff = self.propTree.logLike - slowCheckLogLike
        if math.fabs(myDiff) > 1.e-12:
            gm.append("Bad like calc. slowCheck %f, bitarray %f, diff %g" % (
                slowCheckLogLike, self.propTree.logLike, myDiff))
            raise P4Error(gm)
def setupBitarrayCalcs(self):
    """Attach bitarray split keys to self.propTree (bigT).

    Every non-root node gets n.stSplitKey.  A leaf's key is a copy of
    self.stMcmc.tBits with only its taxon's bit (index in
    self.stMcmc.taxNames) switched on; this assumes tBits is all False on
    entry -- confirm.  An internal node's key is the OR of its children's
    keys.  These keys are temporary; the lasting copy is n.ss, a
    BigTSplitStuff holding the key (spl) and its complement (spl2).  The
    root gets an empty BigTSplitStuff for use after rearrangements.

    With an SPA model and var.stmcmc_useFastSpa, the tree and its splits
    are also loaded into self.stMcmc.fspa.
    """
    bigT = self.propTree
    mcmc = self.stMcmc
    for node in bigT.iterPostOrder():
        if node == bigT.root:
            break
        if not node.isLeaf:
            merged = node.leftChild.stSplitKey.copy()
            child = node.leftChild.sibling
            while child:
                merged |= child.stSplitKey
                child = child.sibling
            node.stSplitKey = merged
            continue
        # Leaf: flip its bit on in the shared scratch bitarray, copy, flip back.
        taxIdx = mcmc.taxNames.index(node.name)
        mcmc.tBits[taxIdx] = True
        node.stSplitKey = bitarray.bitarray(mcmc.tBits)
        mcmc.tBits[taxIdx] = False

    for node in bigT.iterInternalsNoRoot():
        stuff = BigTSplitStuff()
        stuff.spl = node.stSplitKey
        complement = stuff.spl.copy()
        complement.invert()
        stuff.spl2 = complement
        node.ss = stuff
    # Empty for now; used after supertree rearrangements.
    bigT.root.ss = BigTSplitStuff()

    if not (mcmc.modelName.startswith('SPA') and var.stmcmc_useFastSpa):
        return
    mcmc.fspa.setBigT(len(bigT.nodes), bigT.nTax, bigT.postOrder, bigT.spaQ)
    for nodeNum in bigT.postOrder:
        # -10000 marks the end of the used part of postOrder.
        if nodeNum == -10000:
            break
        node = bigT.nodes[nodeNum]
        if node == bigT.root or node.isLeaf:
            spl01, spl2_01 = '0', '0'
        else:
            spl01 = node.ss.spl.to01()
            spl2_01 = node.ss.spl2.to01()
        mcmc.fspa.setBigTNoSpl(self.chNum, nodeNum, spl01, spl2_01)
def refreshBitarrayPropTree(self):
    """Recompute bitarray split keys on self.propTree after a topology change.

    Leaf keys (n.stSplitKey) are kept as they are.  Each internal non-root
    node's key is rebuilt as the OR of its children's keys.  The node's
    existing n.ss object then gets the new key (spl) and its complement
    (spl2).  With an SPA model and var.stmcmc_useFastSpa, the splits are
    also re-sent to self.stMcmc.fspa.
    """
    bigT = self.propTree
    for node in bigT.iterPostOrder():
        if node == bigT.root:
            break
        if node.isLeaf:
            continue
        merged = node.leftChild.stSplitKey.copy()
        child = node.leftChild.sibling
        while child:
            merged |= child.stSplitKey
            child = child.sibling
        node.stSplitKey = merged

    for node in bigT.iterInternalsNoRoot():
        node.ss.spl = node.stSplitKey
        complement = node.ss.spl.copy()
        complement.invert()
        node.ss.spl2 = complement

    if not (self.stMcmc.modelName.startswith('SPA') and var.stmcmc_useFastSpa):
        return
    for nodeNum in bigT.postOrder:
        # -10000 marks the end of the used part of postOrder.
        if nodeNum == -10000:
            break
        node = bigT.nodes[nodeNum]
        if node == bigT.root or node.isLeaf:
            spl01, spl2_01 = '0', '0'
        else:
            spl01 = node.ss.spl.to01()
            spl2_01 = node.ss.spl2.to01()
        self.stMcmc.fspa.setBigTNoSpl(self.chNum, nodeNum, spl01, spl2_01)
def startFrrf(self):
    """Create the fast reduced-RF calculator, used when stRFCalc == 'fastReducedRF'.

    Builds self.frrf = self.stMcmc.Frrf(nTaxa) and self.bigTr (its copy of
    self.propTree).  Then appends every input tree in self.stMcmc.trees.
    For each tree, parent, leftChild and sibling links (by nodeNum) are
    mirrored, and leaves get their index in self.stMcmc.taxNames.
    Finally asks frrf to prepare its tax bits and internal split bits.
    """
    mcmc = self.stMcmc

    def mirrorTopology(srcTree, target):
        # Copy srcTree's node links, and leaf taxon numbers, into target.
        for node in srcTree.nodes:
            if node.parent:
                target.setParent(node.nodeNum, node.parent.nodeNum)
            if node.leftChild:
                target.setLeftChild(node.nodeNum, node.leftChild.nodeNum)
            else:
                target.setNodeTaxNum(node.nodeNum, mcmc.taxNames.index(node.name))
            if node.sibling:
                target.setSibling(node.nodeNum, node.sibling.nodeNum)

    self.frrf = mcmc.Frrf(len(mcmc.taxNames))
    self.bigTr = self.frrf.setBigT(
        len(self.propTree.nodes), self.propTree.nTax, self.propTree.postOrder)
    mirrorTopology(self.propTree, self.bigTr)
    for inTree in mcmc.trees:
        frrfInTree = self.frrf.appendInTree(len(inTree.nodes), inTree.nTax, inTree.postOrder)
        mirrorTopology(inTree, frrfInTree)
    self.frrf.setInTreeTaxBits()
    self.frrf.setInTreeInternalBits()
    self.frrf.maybeFlipInTreeBits()
    self.frrf.setBigTInternalBits()
def getTreeLogLike_ppy1(self):
    """Compute self.propTree.logLike under an SR2008_rf model in pure Python.

    Makes integer split keys on the supertree (stored as propTree.skk).
    For each input tree, gets the reduced RF distance and cherry count
    from maskedSymmetricDifference(), then subtracts:
    * 'SR2008_rf_ia': beta * distance
    * 'SR2008_rf_aZ*': BS2009_Eqn30_ZTApprox(...) and then beta * distance

    Raises:
        P4Error: when the model is unknown.  The check happens per input
            tree, so nothing is raised when there are no input trees.
    """
    bigT = self.propTree
    bigT.makeSplitKeys()
    bigT.skk = [node.br.splitKey for node in bigT.iterInternalsNoRoot()]
    bigT.logLike = 0.0
    for inTree in self.stMcmc.trees:
        modelName = self.stMcmc.modelName
        if not modelName.startswith('SR2008_rf'):
            raise P4Error(
                "STChain.getTreeLogLike_ppy1() unknown model '%s'" % modelName)
        distance, nCherries = maskedSymmetricDifference(
            bigT.skk, inTree.skSet, inTree.taxBits,
            self.stMcmc.nTax, inTree.nTax, inTree.allOnes)
        assert distance != None
        betaDistance = bigT.beta * distance
        if modelName == 'SR2008_rf_ia':
            bigT.logLike -= betaDistance
        elif modelName.startswith('SR2008_rf_aZ'):
            bigT.logLike -= BS2009_Eqn30_ZTApprox(inTree.nTax, bigT.beta, nCherries)
            bigT.logLike -= betaDistance
        else:
            raise P4Error(['STChain.getTreeLogLike_pp1',
                           "Unknown modelName %s" % modelName])
def getTreeLogLike_fastReducedRF(self):
    """Compute self.propTree.logLike with the fast reduced-RF calculator.

    Copies the current parent/leftChild/sibling links of self.propTree
    into self.bigTr (created by startFrrf()), then asks self.frrf for the
    likelihood:
    * 'SR2008_rf_ia': -(symmetric difference) * beta
    * 'SR2008_rf_aZ*': self.frrf.getLogLike(beta)

    Raises:
        P4Error: for any other modelName (consistent with the other
            likelihood methods), or, with slowCheck enabled, when the
            pure-Python likelihood disagrees.
    """
    gm = ['STChain.getTreeLogLike_fastReducedRF()']
    slowCheck = False
    if slowCheck:
        self.getTreeLogLike_ppy1()
        savedLogLike = self.propTree.logLike
    self.frrf.wipeBigTPointers()
    for n in self.propTree.nodes:
        if n.parent:
            self.bigTr.setParent(n.nodeNum, n.parent.nodeNum)
        if n.leftChild:
            self.bigTr.setLeftChild(n.nodeNum, n.leftChild.nodeNum)
        # setNodeTaxNum() is not repeated here; startFrrf() set the leaf
        # taxon numbers (presumably kept by wipeBigTPointers() -- confirm).
        if n.sibling:
            self.bigTr.setSibling(n.nodeNum, n.sibling.nodeNum)
    self.frrf.setBigTInternalBits()
    if self.stMcmc.modelName == 'SR2008_rf_ia':
        sd = self.frrf.getSymmDiff()
        self.propTree.logLike = -sd * self.propTree.beta
    elif self.stMcmc.modelName.startswith('SR2008_rf_aZ'):
        self.propTree.logLike = self.frrf.getLogLike(self.propTree.beta)
    else:
        # Previously fell through and left a stale logLike.
        gm.append("Unknown modelName %s" % self.stMcmc.modelName)
        raise P4Error(gm)
    if slowCheck:
        if self.propTree.logLike != savedLogLike:
            gm.append("Slow likelihood %f" % savedLogLike)
            gm.append("Fast likelihood %f" % self.propTree.logLike)
            raise P4Error(gm)
def getTreeLogLike_bitarray(self):
    """Compute self.propTree.logLike under an SR2008_rf model using bitarrays.

    For each input tree t, each internal split of the supertree is
    oriented so that t.firstTax is on the '1' side, then masked to t's
    taxa (t.baTaxBits).  Masked splits with at least two taxa on each
    side are compared with t.splSet.  The size of the symmetric
    difference is the reduced Robinson-Foulds distance RF.

    * 'SR2008_rf_ia': logLike -= beta * RF
    * 'SR2008_rf_aZ*': logLike -= BS2009_Eqn30_ZTApprox(t.nTax, beta,
      nCherries), then -= beta * RF.  nCherries counts masked splits with
      exactly two taxa on a side, counting a split twice if both sides
      qualify.

    n.ss on the supertree's internal nodes is set by setupBitarrayCalcs()
    or refreshBitarrayPropTree().

    Raises:
        P4Error: for an unknown modelName, or, with slowCheck enabled,
            when the purePython1 calculation disagrees.
    """
    # Defined up front so the unknown-model branch raises P4Error, not NameError.
    gm = ['STChain.getTreeLogLike_bitarray()']
    self.propTree.logLike = 0.0
    slowCheck = False
    if slowCheck:
        self.propTree.makeSplitKeys()
        self.propTree.skk = [
            n.br.splitKey for n in self.propTree.iterInternalsNoRoot()]
    for t in self.stMcmc.trees:
        if 0:
            print("-" * 50)
            t.draw()
            print("baTaxBits %s" % t.baTaxBits)
            print("firstTax at %i" % t.firstTax)
        # splitStuff objects with onesCount >= 2 and <= t.nTax - 2
        usables = []
        # No need to consider (masked) splits with less than two
        # 1s or more than nTax - 2 1s. The nTax depends on the
        # input tree.
        upperGood = t.nTax - 2
        for n in self.propTree.iterInternalsNoRoot():
            # Choose spl or its complement spl2 so that t.firstTax is a 1.
            if n.ss.spl[t.firstTax]:
                n.ss.theSpl = n.ss.spl
            else:
                n.ss.theSpl = n.ss.spl2
            n.ss.maskedSplitWithTheFirstTaxOne = n.ss.theSpl & t.baTaxBits
            n.ss.onesCount = n.ss.maskedSplitWithTheFirstTaxOne.count()
            if 0:
                print("bigT node %i" % n.nodeNum)
                print(" theSpl is %s" % n.ss.theSpl)
                print(" maskedSplitWithTheFirstTaxOne %s" % n.ss.maskedSplitWithTheFirstTaxOne)
                print(" onesCount %i" % n.ss.onesCount)
                if n.ss.onesCount >= 2 and n.ss.onesCount <= upperGood:
                    print(" -> used")
                else:
                    print(" -> not used")
            if n.ss.onesCount >= 2 and n.ss.onesCount <= upperGood:
                usables.append(n.ss)
        # Different supertree splits can mask to the same split; key by bytes.
        usablesDict = {}
        for usable in usables:
            usable.bytes = usable.maskedSplitWithTheFirstTaxOne.tobytes()
            usablesDict[usable.bytes] = usable
        splSet = set(usablesDict)  # bytes, for RF calculation
        thisBaRF = len(splSet.symmetric_difference(t.splSet))
        if slowCheck: # with purePython1
            thisPPyRF, thisPPyNCherries = maskedSymmetricDifference(
                self.propTree.skk, t.skSet, t.taxBits, self.stMcmc.nTax, t.nTax, t.allOnes)
            if thisBaRF != thisPPyRF:
                raise P4Error("bitarray and purePython1 RF calcs differ.")
        beta_distance = self.propTree.beta * thisBaRF
        if self.stMcmc.modelName == 'SR2008_rf_ia':
            self.propTree.logLike -= beta_distance
        elif self.stMcmc.modelName.startswith('SR2008_rf_aZ'):
            nCherries = 0
            for ba in splSet:
                theSS = usablesDict[ba]
                # Not elif: both sides can be cherries (e.g. 4 taxa).
                if theSS.onesCount == 2:
                    nCherries += 1
                if theSS.onesCount == upperGood:
                    nCherries += 1
            if slowCheck:
                if nCherries != thisPPyNCherries:
                    raise P4Error(
                        "bitarray and purePython1 nCherries calcs differ.")
            log_approxZT = BS2009_Eqn30_ZTApprox(
                t.nTax, self.propTree.beta, nCherries)
            self.propTree.logLike -= log_approxZT
            self.propTree.logLike -= beta_distance
        else:
            gm.append("Unknown model %s" % self.stMcmc.modelName)
            raise P4Error(gm)
def proposePolytomy(self, theProposal):
    """Propose adding or deleting one internal edge of self.propTree.

    * Star tree (one internal node): always add an edge.
    * Fully resolved tree (nTax - 2 internal nodes): delete an edge.  If
      there are no candidate nodes, set theProposal.doAbort and copy
      propTree's _nInternalNodes to curTree.
    * Otherwise: add or delete with probability 0.5 each, adding when a
      delete has no candidates.

    theProposal.doAbort is reset to False at the start.
    """
    theProposal.doAbort = False
    pTree = self.propTree
    if pTree.nInternalNodes == 1:
        self.proposeAddEdge(theProposal)
        return
    fullyResolved = pTree.nInternalNodes == pTree.nTax - 2
    if not fullyResolved and random.random() < 0.5:
        self.proposeAddEdge(theProposal)
        return
    candidateNodes = self._getCandidateNodesForDeleteEdge()
    if candidateNodes:
        self.proposeDeleteEdge(theProposal, candidateNodes)
    elif fullyResolved:
        # Nothing can be added and nothing may be deleted.
        theProposal.doAbort = True
        self.curTree._nInternalNodes = pTree._nInternalNodes
    else:
        self.proposeAddEdge(theProposal)
def proposeAddEdge(self, theProposal):
    """Propose adding one internal edge to self.propTree by splitting a polytomy.

    A polytomy is a non-root node with more than 2 children, or the root
    with more than 3.  One polytomy is chosen with random.choice.  A
    group size is drawn using the nChooseK weights, and that many of the
    polytomy's children are moved under a new node.  The new node is the
    first node in pTree.nodes that has neither a parent nor a leftChild.

    Sets self.logProposalRatio (log Hastings ratio) and
    self.logPriorRatio.  Modifies self.propTree in place.
    """
    dbug = False
    pTree = self.propTree
    if 0:
        print("proposeAddEdge(), starting with this tree ...")
        pTree.draw()
        print("k There are %i internal nodes." % pTree.nInternalNodes)
        print("root is node %i" % pTree.root.nodeNum)
    allPolytomies = []
    for n in pTree.iterInternalsNoRoot():
        if n.getNChildren() > 2:
            allPolytomies.append(n)
    if pTree.root.getNChildren() > 3:
        allPolytomies.append(pTree.root)
    theChosenPolytomy = random.choice(allPolytomies)
    # We want to choose one of the possible ways to add a node. See
    # Lewis et al page 246, left top. "The number of distinct ways of
    # dividing k edges into two groups, making sure that at least 3
    # edges are attached to each node afterwards, is 2^{k-1} - k - 1".
    # For non-root polytomies (with 3 or more children), it is
    # straightforward, but for root polytomies (ie with 4 or more
    # children) it is different. For root polytomies, one randomly
    # chosen child is treated as taking the role of the parent, which
    # makes them equivalent to non-root polytomies.  So a 4-child root
    # is considered to have a parent-like node and 3 children.
    if theChosenPolytomy != pTree.root:
        nChildren = theChosenPolytomy.getNChildren()
        k = nChildren + 1
        childrenNodeNums = pTree.getChildrenNums(theChosenPolytomy)
    else:
        # Its the root. So we say that a random child takes the role
        # of the "parent", for purposes of these calculations.
        nChildren = theChosenPolytomy.getNChildren() - 1 # n - 1 children
        k = nChildren + 1
        # Yes, all children.
        childrenNodeNums = pTree.getChildrenNums(theChosenPolytomy)
    # Integer arithmetic: random.randrange() below rejects floats on
    # Python >= 3.12 (math.pow returned a float here).
    nPossibleWays = 2 ** (k - 1) - k - 1
    if dbug:
        print("These nodes are polytomies: %s" % [n.nodeNum for n in allPolytomies])
        print("We randomly choose to do node %i" % theChosenPolytomy.nodeNum)
        print("It has %i children, so k=%i, so there are %i possible ways to add a node." % (
            nChildren, k, nPossibleWays))
    # Choose one of the nPossibleWays uniformly.  E.g. with nChildren=5
    # (k=6) there are 25 ways: a new group of 2, 3 or 4 children, with
    # [10, 10, 5] ways each (p4.func.nChooseK).  The cumulative sums
    # [10, 20, 25] pick the group size for a random number in 0..24.
    nChooseKs = [p4.func.nChooseK(nChildren, i) for i in range(2, nChildren)]
    cumSum = list(itertools.accumulate(nChooseKs))
    ran = random.randrange(nPossibleWays)
    for i in range(len(cumSum)):
        if ran < cumSum[i]:
            break
    nInNewGroup = i + 2
    # Of the children of theChosenPolytomy, put nInNewGroup of them,
    # chosen at random, under a new node.
    newChildrenNodeNums = random.sample(childrenNodeNums, nInNewGroup)
    if dbug:
        print("The nChooseKs are %s" % nChooseKs)
        print("The cumSum is %s" % cumSum)
        print("Since there are nPossibleWays=%i, we choose a random number from 0-%i" % (
            nPossibleWays, nPossibleWays - 1))
        print("->We chose a random number: %i" % ran)
        print("So we choose the group at index %i, which means nInNewGroup=%i" % (i, nInNewGroup))
        print("So we make a new node with newChildrenNodeNums %s" % newChildrenNodeNums)
    # Add a node between theChosenPolytomy and the first in the list of
    # newChildrenNodeNums. The new node is the first in pTree.nodes whose
    # parent and leftChild are both None.
    firstNode = pTree.nodes[newChildrenNodeNums[0]]
    for newNode in pTree.nodes:
        if not newNode.parent and not newNode.leftChild:
            break
    # Add the newNode between theChosenPolytomy and firstNode
    newNode.parent = theChosenPolytomy
    newNode.leftChild = firstNode
    firstNode.parent = newNode
    if theChosenPolytomy.leftChild == firstNode:
        theChosenPolytomy.leftChild = newNode
    else:
        oldCh = theChosenPolytomy.leftChild
        while oldCh.sibling != firstNode:
            oldCh = oldCh.sibling
        oldCh.sibling = newNode
    if firstNode.sibling:
        newNode.sibling = firstNode.sibling
        firstNode.sibling = None
    pTree.setPreAndPostOrder()
    pTree._nInternalNodes += 1
    if 0:
        pTree.draw()
    # Move the remaining chosen children under the new node.
    for nodeNum in newChildrenNodeNums[1:]:
        n = pTree.pruneSubTreeWithoutParent(nodeNum)
        pTree.reconnectSubTreeWithoutParent(n, newNode)
    if dbug:
        pTree.setPreAndPostOrder()
        pTree.draw()
    # Now the Hastings ratio. First calculate gamma_B. If the
    # current tree is a star tree (nInternalNodes == 1) and the
    # proposed tree is not fully resolved (ie is less than
    # len(self.propTree.nodes) - 2), then gamma_B is 0.5.
    if (self.curTree.nInternalNodes == 1) and (pTree.nInternalNodes < (len(pTree.nodes) - 2)):
        gamma_B = 0.5
    # If the proposed tree is fully resolved and the current tree is not
    # the star tree
    elif (pTree.nInternalNodes == (len(pTree.nodes) - 2)) and (self.curTree.nInternalNodes > 1):
        gamma_B = 2.0
    else:
        gamma_B = 1.0
    # n_e is number of internal edges present before the Add-edge move.
    # That would be self.curTree.nInternalNodes - 1
    n_e = float(self.curTree.nInternalNodes - 1)
    # n_p is the number of polytomies present before the move,
    # len(allPolytomies)
    n_p = float(len(allPolytomies))
    hastingsRatio = (gamma_B * n_p * float(nPossibleWays)) / (1.0 + n_e)
    if dbug:
        print("The new node is given a random branch length of %f" % newNode.br.len)
        print("For the Hastings ratio ...")
        print("gamma_B is %.1f" % gamma_B)
        print("n_e is %.0f" % n_e)
        print("k is (still) %i, and (2^{k-1} - k - 1) = nPossibleWays is still %i" % (k, nPossibleWays))
        print("n_p = %.0f is the number of polytomies present before the move." % n_p)
        print("So the hastings ratio is %f" % hastingsRatio)
    self.logProposalRatio = math.log(hastingsRatio)
    if 0:
        priorRatio = theProposal.brLenPriorLambda * \
            math.exp(- theProposal.brLenPriorLambda * newNode.br.len)
        if dbug:
            print("The theProposal.brLenPriorLambda is %f" % theProposal.brLenPriorLambda)
            print("So the prior ratio is %f" % priorRatio)
        self.logPriorRatio = math.log(priorRatio)
        # The Jacobian
        jacobian = 1.0 / (theProposal.brLenPriorLambda *
                          math.exp(- theProposal.brLenPriorLambda * newNode.br.len))
        self.logJacobian = math.log(jacobian)
        print("logPriorRatio = %f, logJacobian = %f" % (self.logPriorRatio, self.logJacobian))
    # As explained in Lewis et al., the branch-length prior ratio and the
    # Jacobian cancel out, so their logs can be taken as zero.
    self.logPriorRatio = 0.0
    if theProposal.polytomyUseResolutionClassPrior:
        # We are gaining a node. So the prior ratio is T_{n,m + 1} /
        # (T_{n,m} * C) . We have the logs, and the result is the
        # log.
        if 0:
            print("-" * 30)
            print('curTree.nInternalNodes', self.curTree.nInternalNodes)
            print('pTree.nInternalNodes', pTree.nInternalNodes)
            print('logBigT[curTree.nInternalNodes]', theProposal.logBigT[self.curTree.nInternalNodes])
            print('C ', theProposal.polytomyPriorLogBigC)
            print('logBigT[pTree.nInternalNodes]', theProposal.logBigT[pTree.nInternalNodes])
            print("-" * 30)
        self.logPriorRatio = (theProposal.logBigT[self.curTree.nInternalNodes] -
                              (theProposal.polytomyPriorLogBigC +
                               theProposal.logBigT[pTree.nInternalNodes]))
    else:
        if theProposal.polytomyPriorLogBigC:
            self.logPriorRatio = -theProposal.polytomyPriorLogBigC
        else:
            self.logPriorRatio = 0.0
def _getCandidateNodesForDeleteEdge(self):
pTree = self.propTree
nodesWithInternalEdges = [n for n in pTree.iterInternalsNoRoot()]
# Remove any that might violate constraints.
# if self.mcmc.constraints:
# nodesToRemove = []
# for n in nodesWithInternalEdges:
# if n.br.splitKey in self.mcmc.constraints.constraints:
# nodesToRemove.append(n)
# for n in nodesToRemove:
# nodesWithInternalEdges.remove(n)
return nodesWithInternalEdges
def proposeDeleteEdge(self, theProposal, candidateNodes):
    """Propose a new topology by deleting one internal edge of propTree.

    One node is drawn uniformly at random from candidateNodes and collapsed
    into its parent.  Its children are spliced, in order, into the parent's
    child list where the node used to be, and the node is wiped.  This
    creates or enlarges a polytomy at the parent.

    Sets self.logProposalRatio to the log Hastings ratio of the move.  Sets
    self.logPriorRatio to the log prior ratio for the change in the number
    of internal nodes.  Branch-length prior and Jacobian terms are not
    computed; they are taken to cancel (Lewis et al.).

    Raises P4Error if candidateNodes is empty.
    """
    tree = self.propTree
    if not candidateNodes:
        raise P4Error(
            "proposeDeleteEdge() could not find a good node to attempt to delete.")
    victim = random.choice(candidateNodes)

    # ---- Tree surgery.  Record all neighbours before touching any pointer.
    newParent = victim.parent
    lastChild = victim.rightmostChild()
    leftSib = victim.leftSibling()
    # The victim's first child takes the victim's place in the sibling chain.
    if not leftSib:
        newParent.leftChild = victim.leftChild
    else:
        leftSib.sibling = victim.leftChild
    for child in victim.iterChildren():
        child.parent = newParent
    # The victim's last child inherits the victim's right sibling.
    lastChild.sibling = victim.sibling
    victim.wipe()
    tree.setPreAndPostOrder()
    tree._nInternalNodes -= 1

    # ---- Hastings ratio.
    curInternal = self.curTree.nInternalNodes
    fullyResolvedCount = len(tree.nodes) - 2
    propIsStar = tree.nInternalNodes == 1
    # gamma_D corrects for the add/delete choice being forced at the
    # extremes.  It is 0.5 when the current tree counts as fully resolved
    # and the proposed tree is not the star tree, 2 when the proposed tree
    # is the star tree and the current tree is not fully resolved, and 1
    # otherwise.
    # NOTE(review): "fully resolved" is tested as
    # nInternalNodes == len(nodes) - 2.  If the nodes list also holds the
    # leaves, the usual test would be nInternalNodes == nTax - 2.  Confirm
    # this matches the test used when choosing between add and delete moves.
    if curInternal == fullyResolvedCount and not propIsStar:
        gamma_D = 0.5
    elif curInternal < fullyResolvedCount and propIsStar:
        gamma_D = 2.
    else:
        gamma_D = 1.

    # Number of internal edges before the move (nInternalNodes - 1).
    nEdgesBefore = float(curInternal - 1)
    # Number of polytomies after the move.  Non-root internal nodes with
    # more than 2 children count, and so does the root if it has more than 3.
    nPolytomies = sum(1 for nd in tree.iterInternalsNoRoot()
                      if nd.getNChildren() > 2)
    if tree.root.getNChildren() > 3:
        nPolytomies += 1
    nPolytomies = float(nPolytomies)
    # Edges emanating from the polytomy created or enlarged by the move,
    # including the edge to its own parent when it has one.
    kStar = newParent.getNChildren() + (1 if newParent.parent else 0)

    hastingsRatio = (gamma_D * nEdgesBefore) / \
        (nPolytomies * (2 ** (kStar - 1) - kStar - 1))
    self.logProposalRatio = math.log(hastingsRatio)

    # ---- Prior ratio.  Only the topology prior contributes.
    if theProposal.polytomyUseResolutionClassPrior:
        # Losing a node, so the ratio is (T_{n,m} * C) / T_{n,m-1}, in logs.
        logBigT = theProposal.logBigT
        self.logPriorRatio = ((logBigT[curInternal] +
                               theProposal.polytomyPriorLogBigC) -
                              logBigT[tree.nInternalNodes])
    else:
        self.logPriorRatio = theProposal.polytomyPriorLogBigC or 0.0
def propose(self, theProposal):
    """Apply theProposal to self.propTree and return the log acceptance ratio.

    The proposal named theProposal.name changes either the topology of
    self.propTree ('nni', 'spr', 'polytomy') or one of its parameters
    ('SR2008beta_uniform', 'spaQ_uniform').  It also sets
    self.logProposalRatio and self.logPriorRatio.  Unless the proposal
    aborted, the log likelihood of self.propTree is then recalculated
    under self.stMcmc.modelName.

    Returns:
        0.0 if theProposal.doAbort is set.  Otherwise returns
        logLikeRatio + logProposalRatio + logPriorRatio.  The likelihood
        and prior terms are first multiplied by the chain heat factor when
        there is more than one chain, and by the heating-hack factor when
        self.stMcmc.doHeatingHack is set.

    Raises:
        P4Error for an unlisted proposal name, an unknown model name or
        stRFCalc, or a failed fast-SPA cross-check (when enabled).
    """
    gm = ['STChain.propose()']

    if theProposal.name == 'nni':
        self.propTree.nni()  # this does setPreAndPostOrder()
        # nni and spr compute no proposal or prior ratio; both are treated
        # as zero.  They must be reset here.  Otherwise values left over
        # from the previous proposal (eg polytomy or spaQ_uniform), possibly
        # already scaled by the heating below, leak into this ratio.
        self.logProposalRatio = 0.0
        self.logPriorRatio = 0.0
    elif theProposal.name == 'spr':
        self.propTree.randomSpr()
        if not theProposal.doAbort:
            if not self.propTree.preAndPostOrderAreValid:
                self.propTree.setPreAndPostOrder()
        self.logProposalRatio = 0.0
        self.logPriorRatio = 0.0
    elif theProposal.name == 'SR2008beta_uniform':
        mt = self.propTree.beta
        # Slider proposal
        mt += (random.random() - 0.5) * theProposal.tuning[self.tempNum]
        # Linear reflect back into [myMIN, myMAX]
        myMIN = 1.e-10
        myMAX = 1.e+10
        while True:
            if mt < myMIN:
                mt = (myMIN - mt) + myMIN
            elif mt > myMAX:
                mt = myMAX - (mt - myMAX)
            else:
                break
        self.propTree.beta = mt
        self.logProposalRatio = 0.0
        self.logPriorRatio = 0.0
    elif theProposal.name == 'spaQ_uniform':
        mt = self.propTree.spaQ[0]
        originally = mt
        # Slider proposal
        mt += (random.random() - 0.5) * theProposal.tuning[self.tempNum]
        # Linear reflect back into [myMIN, myMAX]
        myMIN = 1.e-10
        myMAX = 1.
        while True:
            if mt < myMIN:
                mt = (myMIN - mt) + myMIN
            elif mt > myMAX:
                mt = myMAX - (mt - myMAX)
            else:
                break
        self.propTree.spaQ[0] = mt
        self.logProposalRatio = 0.0
        if theProposal.spaQPriorType == 'flat':
            self.logPriorRatio = 0.0
        elif theProposal.spaQPriorType == 'exponential':
            # log(f(new)/f(old)) for an exponential(lambda) prior density
            self.logPriorRatio = theProposal.spaQExpPriorLambda * (originally - mt)
        else:
            raise P4Error("this should not happen! wxyzz")
    elif theProposal.name == 'polytomy':
        self.proposePolytomy(theProposal)
        if not self.propTree.preAndPostOrderAreValid:
            self.propTree.setPreAndPostOrder()
    else:
        gm.append('Unlisted proposal.name=%s Fix me.' % theProposal.name)
        raise P4Error(gm)

    if theProposal.doAbort:
        return 0.0

    # Recalculate the likelihood of the proposed tree.
    modelName = self.stMcmc.modelName
    if modelName.startswith('SR2008_rf'):
        stRFCalc = self.stMcmc.stRFCalc
        if stRFCalc == 'fastReducedRF':
            self.getTreeLogLike_fastReducedRF()
        elif stRFCalc == 'purePython1':
            self.getTreeLogLike_ppy1()
        elif stRFCalc == 'bitarray':
            self.refreshBitarrayPropTree()
            self.getTreeLogLike_bitarray()
        else:
            # Previously fell through silently, leaving a stale logLike.
            gm.append('Unknown stRFCalc %s' % stRFCalc)
            raise P4Error(gm)
    elif modelName == 'SPA':
        self.refreshBitarrayPropTree()
        if var.stmcmc_useFastSpa:
            checkFastSpa = False  # set True to cross-check against the bitarray calc
            if checkFastSpa:
                self.getTreeLogLike_spa_bitarray()
                fspaLike = self.stMcmc.fspa.calcLogLike(self.chNum)
                diff = math.fabs(self.propTree.logLike - fspaLike)
                if diff > 1e-13:
                    gm.append("gen %i bad fastspa likelihood calc, %f vs %f, diff %f" % (
                        self.stMcmc.gen, self.propTree.logLike, fspaLike, diff))
                    raise P4Error(gm)
            else:
                self.propTree.logLike = self.stMcmc.fspa.calcLogLike(self.chNum)
        else:
            self.getTreeLogLike_spa_bitarray()
    elif modelName == 'QPA':
        self.getTreeLogLike_qpa_slow()
    else:
        gm.append('Unknown model %s' % self.stMcmc.modelName)
        raise P4Error(gm)

    logLikeRatio = self.propTree.logLike - self.curTree.logLike

    # Mcmcmc: heat the likelihood and prior terms for this chain.
    if self.stMcmc.nChains > 1:
        if self.stMcmc.swapVector:
            heatBeta = 1.0 / (1.0 + self.stMcmc.chainTemps[self.tempNum])
        else:
            heatBeta = 1.0 / (1.0 + self.stMcmc.chainTemp * self.tempNum)
        logLikeRatio *= heatBeta
        self.logPriorRatio *= heatBeta

    # Experimental heating hack
    if self.stMcmc.doHeatingHack:
        heatFactor = 1.0 / (1.0 + self.stMcmc.heatingHackTemperature)
        logLikeRatio *= heatFactor
        self.logPriorRatio *= heatFactor

    return logLikeRatio + self.logProposalRatio + self.logPriorRatio
def gen(self, aProposal):
    """Run one generation: propose, accept or reject, and resync the trees.

    self.propose(aProposal) supplies the log acceptance ratio.  Unless the
    proposal aborted, the move is accepted with probability
    min(1, exp(ratio)); ratios below -100 are treated as probability 0.
    The proposal's counters for self.tempNum are always incremented, so an
    aborted proposal is counted as a rejected one.  Afterwards the losing
    tree (propTree on rejection, curTree on acceptance) is overwritten from
    the winning one, so both trees agree again.

    Raises P4Error for a proposal name it does not know how to resync.
    """
    gm = ['STChain.gen()']
    logRatio = self.propose(aProposal)

    accepted = False
    if not aProposal.doAbort:
        if logRatio >= 0.0:
            acceptProb = 1.0
        elif logRatio < -100.0:  # math.exp(-100.) is about 3.72e-44
            acceptProb = 0.0
        else:
            acceptProb = math.exp(logRatio)
        # random.random() is drawn only when acceptance is not certain.
        accepted = acceptProb == 1.0 or random.random() < acceptProb

    t = self.tempNum
    aProposal.nProposals[t] += 1
    aProposal.tnNSamples[t] += 1
    if accepted:
        aProposal.accepted = True
        aProposal.nAcceptances[t] += 1
        aProposal.tnNAccepts[t] += 1

    if accepted:
        winner, loser = self.propTree, self.curTree
    else:
        winner, loser = self.curTree, self.propTree

    name = aProposal.name
    if name in ('nni', 'spr', 'polytomy'):
        loser.logLike = winner.logLike
        winner.copyToTree(loser)
    elif name == 'SR2008beta_uniform':
        loser.logLike = winner.logLike
        loser.beta = winner.beta
    elif name == 'spaQ_uniform':
        loser.logLike = winner.logLike
        loser.spaQ[0] = winner.spaQ[0]
    else:
        gm.append('Unlisted proposal.name = %s Fix me.' % aProposal.name)
        raise P4Error(gm)
# Per-proposal fudge factors for proposal probabilities, keyed by proposal
# name.  (They are presumably combined with the user-settable relative
# probabilities elsewhere in this module.)
fudgeFactor = {
    'nni': 1.0,
    'spr': 1.0,
    'SR2008beta_uniform': 0.1,
    'spaQ_uniform': 0.1,
    'polytomy': 0.5,
}
class STMcmcTunings(object):
    """Holder for default proposal tunings.

    Attributes:
        default: dict mapping a tunable proposal name
            ('SR2008beta_uniform', 'spaQ_uniform') to its default tuning.
    """

    def __init__(self):
        self.default = dict([
            ('SR2008beta_uniform', 0.2),
            ('spaQ_uniform', 0.1),
        ])
class STMcmcProposalProbs(dict):
    """User-settable relative proposal probabilities.

    An instance of this class is made as STMcmc.prob, where you can
    do, for example,

        yourSTMcmc.prob.nni = 2.0

    These are relative proposal probs, which do not sum to 1.0.  They
    affect the calculation of the final proposal probabilities (ie the
    kind that do sum to 1).  It is a relative setting, and the default
    is 1.0.  Setting it to 0 turns it off.  For small probabilities,
    setting it to 2.0 doubles it.  For bigger probabilities, setting it
    to 2.0 makes it somewhat bigger.

    Check the effect that it has by doing a

        yourSTMcmc.writeProposalIntendedProbs()

    which prints out the final calculated probabilities.

    Only the attributes created in __init__ may be set.  Values are
    converted to float, and values below 1e-9 (including negatives) are
    stored as 0.
    """

    def __init__(self):
        # object.__setattr__ bypasses our own __setattr__, which only
        # accepts names that already exist.
        object.__setattr__(self, 'nni', 1.0)
        object.__setattr__(self, 'spr', 1.0)
        object.__setattr__(self, 'SR2008beta_uniform', 1.0)
        object.__setattr__(self, 'spaQ_uniform', 1.0)
        object.__setattr__(self, 'polytomy', 0.0)

    def __setattr__(self, item, val):
        """Set an existing proposal's relative probability.

        Raises P4Error if item is not a known proposal, or if val cannot
        be converted to a float.
        """
        gm = ["\nSTMcmcProposalProbs(). (set %s to %s)" % (item, val)]
        if item not in self.__dict__:
            self.dump()
            gm.append(" Can't set '%s'-- no such proposal." % item)
            raise P4Error(gm)
        try:
            fVal = float(val)
        except (ValueError, TypeError):
            # Narrowed from a bare except, which also swallowed things
            # like KeyboardInterrupt.
            gm.append("Should be a float. Got '%s'" % val)
            raise P4Error(gm)
        if fVal < 1e-9:
            fVal = 0
        object.__setattr__(self, item, fVal)

    def reprString(self):
        """Return a human-readable listing of the current settings."""
        stuff = ["\nUser-settable relative proposal probabilities, from yourStMcmc.prob"]
        stuff.append(" To change it, do eg ")
        stuff.append(" yourSTMcmc.prob.spaQ_uniform = 0.0 # turns spaQ_uniform proposals off")
        stuff.append(" Current settings:")
        for k in sorted(self.__dict__):
            stuff.append(" %20s: %s" % (k, getattr(self, k)))
        return '\n'.join(stuff)

    def dump(self):
        print(self.reprString())

    def __repr__(self):
        return self.reprString()
class STProposal(object):
    """One kind of STMcmc proposal, with per-chain counters and tuning state.

    theSTMcmc is the owning STMcmc.  Despite the None default it is
    required, because its nChains sizes every per-chain list.
    """

    def __init__(self, theSTMcmc=None):
        self.name = None
        self.stMcmc = theSTMcmc  # reference loop
        nCh = theSTMcmc.nChains
        self.nChains = nCh
        self.pNum = -1
        self.weight = 1.0
        self.tuning = None
        self.tunings = {}
        # Per-chain (per-temperature) counters; each list is independent.
        self.nProposals = [0] * nCh
        self.nAcceptances = [0] * nCh
        self.accepted = 0
        self.doAbort = False
        self.nAborts = [0] * nCh
        # Auto-tuning state.  The thresholds and factors are filled in
        # elsewhere before tune() is used.
        self.tnSampleSize = 250
        self.tnNSamples = [0] * nCh
        self.tnNAccepts = [0] * nCh
        for attr in ('tnAccVeryHi', 'tnAccHi', 'tnAccLo', 'tnAccVeryLo',
                     'tnFactorVeryHi', 'tnFactorHi', 'tnFactorLo',
                     'tnFactorVeryLo', 'tnFactorZero'):
            setattr(self, attr, None)

    def dump(self):
        print("proposal name=%-10s pNum=%2i, weight=%5.1f, tuning=%s" % (
            self.name, self.pNum, self.weight, self.tuning))

    def tune(self, tempNum):
        """Rescale self.tuning[tempNum] from the recent acceptance rate.

        If the acceptance rate over the last tnNSamples[tempNum] proposals
        is above tnAccHi, the tuning is multiplied by tnFactorHi, or by
        tnFactorVeryHi above tnAccVeryHi.  If it is below tnAccLo, the
        tuning is multiplied by tnFactorLo, or by tnFactorVeryLo below
        tnAccVeryLo.  The sample counters for tempNum are then reset, and
        any change is logged via self.stMcmc.logger.
        """
        assert self.tnSampleSize >= 100.
        assert self.tnNSamples[tempNum] >= self.tnSampleSize
        acc = float(self.tnNAccepts[tempNum]) / self.tnNSamples[tempNum]  # float() for Py2
        changed = False
        if acc > self.tnAccHi:
            oldTn = self.tuning[tempNum]
            factor = self.tnFactorVeryHi if acc > self.tnAccVeryHi else self.tnFactorHi
            self.tuning[tempNum] *= factor
            changed = True
        elif acc < self.tnAccLo:
            oldTn = self.tuning[tempNum]
            factor = self.tnFactorVeryLo if acc < self.tnAccVeryLo else self.tnFactorLo
            self.tuning[tempNum] *= factor
            changed = True
        self.tnNSamples[tempNum] = 0
        self.tnNAccepts[tempNum] = 0
        if not changed:
            return
        message = "%s tune gen=%i tempNum=%i acceptance=%.3f " % (self.name, self.stMcmc.gen, tempNum, acc)
        message += "(target %.3f -- %.3f) " % (self.tnAccLo, self.tnAccHi)
        message += "Adjusting tuning from %g to %g" % (oldTn, self.tuning[tempNum])
        self.stMcmc.logger.info(message)
class Proposals(object):
    """The proposals used by an STMcmc, and their selection weights."""

    def __init__(self):
        self.proposals = []
        self.proposalsDict = {}
        self.propWeights = []
        self.cumPropWeights = []
        self.totalPropWeights = 0.0
        self.intended = None

    def summary(self):
        print("There are %i proposals" % len(self.proposals))
        for p in self.proposals:
            print("proposal name=%-10s pNum=%2s, weight=%s, tuning=%s" % (
                '%s,' % p.name, p.pNum, p.weight, p.tuning))

    def calculateWeights(self):
        """Recompute weights, cumulative weights, and intended probabilities.

        Fills self.propWeights from each proposal's weight, and
        self.cumPropWeights with their running sums.  Also sets
        self.totalPropWeights, and self.intended to the weights
        normalized to sum to 1.

        Raises P4Error if there are no proposals or the total weight is
        (nearly) zero.
        """
        gm = ["Proposals.calculateWeights()"]
        self.propWeights = [p.weight for p in self.proposals]
        if not self.propWeights:
            # Previously this crashed with IndexError on propWeights[0].
            gm.append("No proposal weights?")
            raise P4Error(gm)
        self.cumPropWeights = [self.propWeights[0]]
        for w in self.propWeights[1:]:
            self.cumPropWeights.append(self.cumPropWeights[-1] + w)
        self.totalPropWeights = sum(self.propWeights)
        if self.totalPropWeights < 1e-9:
            gm.append("No proposal weights?")
            raise P4Error(gm)
        self.intended = [w / self.totalPropWeights for w in self.propWeights]
        # Sanity check in both directions.  The old code took fabs() of the
        # comparison result, so it only caught sums that were too large.
        if math.fabs(sum(self.intended) - 1.0) > 1e-14:
            raise P4Error("bad sum of intended proposal probs. %s" % sum(self.intended))

    def chooseProposal(self, equiProbableProposals):
        """Pick a proposal, uniformly or in proportion to its weight.

        When weighted, this uses the cumulative weights from the last
        calculateWeights() call.
        """
        if equiProbableProposals:
            return random.choice(self.proposals)
        theRan = random.uniform(0.0, self.totalPropWeights)
        for i, cum in enumerate(self.cumPropWeights):
            if theRan < cum:
                return self.proposals[i]
        # random.uniform() may return its upper bound; fall back to the last.
        return self.proposals[len(self.cumPropWeights) - 1]

    def writeProposalIntendedProbs(self):
        """Tabulate the intended proposal probabilities"""
        print("\nIntended proposal probabilities (%)")
        print("There are %i proposals" % len(self.proposals))
        print("%2s %11s %30s %5s %12s" % ('', 'intended(%)', 'proposal', 'part', 'tuning'))
        for i, p in enumerate(self.proposals):
            print("%2i" % i, end=' ')
            print(" %6.2f " % (100. * self.intended[i]), end=' ')
            print(" %27s" % p.name, end=' ')
            if p.pNum != -1:
                print(" %3i " % p.pNum, end=' ')
            else:
                print(" - ", end=' ')
            if p.tuning is None:
                print(" %12s " % ' - ', end=' ')
            else:
                tn = p.tuning[0]
                if tn < 0.1:
                    print(" %12.4g" % tn, end=' ')
                elif tn < 1.0:
                    print(" %12.4f" % tn, end=' ')
                elif tn < 10.0:
                    print(" %12.3f" % tn, end=' ')
                elif tn < 1000.0:
                    print(" %12.1f" % tn, end=' ')
                else:
                    print(" %12.2g " % tn, end=' ')
            print()

    def writeTunings(self):
        """Print each proposal's name and tuning, or '-' if it has none."""
        print("Proposal tunings:")
        print("%20s %12s" % ("proposal name", "tuning"))
        for p in self.proposals:
            print("%20s" % p.name, end=' ')
            if p.tuning:
                print(p.tuning)
            else:
                print(" %4s " % '-', end=' ')
                print()
class SwapTuner(object):
    """Continuous tuning for swap temperature.

    Counts swap attempts and successes between chains 0 and 1.  Once
    sampleSize attempts have been made, tune() scales theMcmc.chainTemp
    up when swaps are accepted too often, and down when too rarely.
    """

    def __init__(self, sampleSize):
        assert sampleSize >= 100
        self.sampleSize = sampleSize
        self.swaps01_nAttempts = 0
        self.swaps01_nSwaps = 0
        # Acceptance thresholds
        self.tnAccVeryHi = 0.18
        self.tnAccHi = 0.12
        self.tnAccLo = 0.04
        self.tnAccVeryLo = 0.01
        # Multiplicative adjustments to chainTemp
        self.tnFactorVeryHi = 1.4
        self.tnFactorHi = 1.2
        self.tnFactorLo = 0.9
        self.tnFactorVeryLo = 0.6
        self.tnFactorZero = 0.4

    def tune(self, theMcmc):
        """Adjust theMcmc.chainTemp from the recent 0<->1 swap rate, then
        reset the counters.  Any change is logged via theMcmc.logger."""
        assert self.swaps01_nAttempts >= self.sampleSize
        acc = float(self.swaps01_nSwaps) / self.swaps01_nAttempts  # float() for Py2
        direction = None
        if acc > self.tnAccHi:
            oldTn = theMcmc.chainTemp
            factor = self.tnFactorVeryHi if acc > self.tnAccVeryHi else self.tnFactorHi
            theMcmc.chainTemp *= factor
            direction = 'Increase'
        elif acc < self.tnAccLo:
            oldTn = theMcmc.chainTemp
            if acc == 0.0:  # no swaps at all
                factor = self.tnFactorZero
            elif acc < self.tnAccVeryLo:
                factor = self.tnFactorVeryLo
            else:
                factor = self.tnFactorLo
            theMcmc.chainTemp *= factor
            direction = 'Decrease'
        self.swaps01_nAttempts = 0
        self.swaps01_nSwaps = 0
        if direction is None:
            return
        message = "%s tune gen=%i acceptance=%.3f " % ('chainTemp', theMcmc.gen, acc)
        message += "(target %.3f -- %.3f) " % (self.tnAccLo, self.tnAccHi)
        message += "%s chainTemp from %g to %g" % (direction, oldTn, theMcmc.chainTemp)
        theMcmc.logger.info(message)
class STSwapTunerV(object):
    """Continuous tuning of the temperature gaps between adjacent chains.

    Swap statistics are kept per adjacent pair of chains, indexed by the
    lower chain number of the pair.  For example, attempts between chains
    0 and 1 are counted in index 0.  tune() scales the matching entry of
    mcmc.chainTempDiffs from the observed swap rate, then rebuilds
    mcmc.chainTemps as the running sum of the gaps, starting from 0.0.  No
    increase is made once a gap is already >= tnLimitHi, and no decrease
    once it is <= tnLimitLo.
    """

    def __init__(self, theMcmc):
        assert var.mcmc_swapTunerSampleSize >= 100
        self.mcmc = theMcmc
        self.nChains = self.mcmc.nChains
        # Indexed by the lower chain number of each adjacent pair.
        self.nAttempts = [0] * self.nChains
        self.nSwaps = [0] * self.nChains
        # Acceptance thresholds
        self.tnAccVeryHi = 0.30
        self.tnAccHi = 0.25
        self.tnAccLo = 0.10
        self.tnAccVeryLo = 0.05
        # Multiplicative adjustments to a temperature gap
        self.tnFactorVeryHi = 1.4
        self.tnFactorHi = 1.2
        self.tnFactorLo = 0.9
        self.tnFactorVeryLo = 0.6
        self.tnFactorZero = 0.4
        # Gap values beyond which no further change in that direction is made
        self.tnLimitHi = 10.0
        self.tnLimitLo = 0.2

    def tune(self, theTempNum):
        """Adjust the gap between chains theTempNum and theTempNum + 1.

        Resets that pair's counters.  Unless the gap was left alone because
        it was already at a limit, chainTemps is rebuilt.  Changes are
        logged via mcmc.logger.
        """
        assert self.nAttempts[theTempNum] >= var.mcmc_swapTunerSampleSize
        acc = float(self.nSwaps[theTempNum]) / self.nAttempts[theTempNum]  # float() for Py2
        diffs = self.mcmc.chainTempDiffs
        oldTn = diffs[theTempNum]
        doMessage = False
        direction = None
        if acc > self.tnAccHi:
            if diffs[theTempNum] >= self.tnLimitHi:
                direction = "no change"
            else:
                factor = self.tnFactorVeryHi if acc > self.tnAccVeryHi else self.tnFactorHi
                diffs[theTempNum] *= factor
                doMessage = True
                direction = 'Increase'
        elif acc < self.tnAccLo:
            if diffs[theTempNum] <= self.tnLimitLo:
                direction = "no change"
            else:
                if acc == 0.0:  # no swaps at all
                    factor = self.tnFactorZero
                elif acc < self.tnAccVeryLo:
                    factor = self.tnFactorVeryLo
                else:
                    factor = self.tnFactorLo
                diffs[theTempNum] *= factor
                doMessage = True
                direction = 'Decrease'
        self.nAttempts[theTempNum] = 0
        self.nSwaps[theTempNum] = 0

        if direction == "no change":
            return
        if doMessage:
            message = "%s tune gen=%i tempNum=%i acceptance=%.3f " % ('chainTemp', self.mcmc.gen, theTempNum, acc)
            message += "(target %.3f -- %.3f) " % (self.tnAccLo, self.tnAccHi)
            message += "%s chainTempDiff from %g to %g" % (direction, oldTn, diffs[theTempNum])
            self.mcmc.logger.info(message)
        # Make chainTemps from chainTempDiffs (a running sum from 0.0).
        temps = [0.0]
        for dNum in range(self.mcmc.nChains - 1):
            temps.append(diffs[dNum] + temps[-1])
        self.mcmc.chainTemps = temps
        if doMessage:
            message = "new chainTemps gen=%i " % (self.mcmc.gen)
            message += "".join("%10.2f" % cT for cT in temps)
            self.mcmc.logger.info(message)
class BigTSplitStuff(object):
    """Split bookkeeping for one internal node of STMcmc.tree (bigT).

    Intended only for use with the bitarray calculation.  Every attribute
    starts out as None and is filled in by the caller.
    """

    _FIELDS = ('spl', 'spl2', 'theSpl', 'maskedSplitWithFirstTaxOne',
               'onesCount', 'bytes')

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, None)

    def dump(self):
        print("ss: spl=%s, spl2=%s, masked=%s, onesCount=%s" % (
            self.spl, self.spl2, self.maskedSplitWithFirstTaxOne, self.onesCount))
[docs]
class STMcmc(object):
"""An MCMC for making supertrees from a set of input trees.
This week, it implements the Steel and Rodrigo 2008 model, with the
alpha calculation using the approximation in Bryant and Steel 2009.
**Arguments**
inTrees
A list of p4 tree objects. You could just use ``var.trees``.
modelName
The SR2008 models implemented here are based on the Steel and
Rodrigo 2008 description of a likelihood model, "Maximum
likelihood supertrees" Syst. Biol. 57(2):243--250, 2008. At
the moment, they are all SR2008_rf, meaning that they use
Robinson-Foulds distances.
SR2008_rf_ia
Here 'ia' means 'ignore alpha'. The alpha values are not
calculated at all, as they are presumed (erroneously, but
not too badly) to cancel out.
SR2008_rf_aZ
This uses the approximation for Z_T = alpha^{-1} as described
in Equation 30 in the Bryant and Steel paper "Computing the
distribution of a tree metric" in IEEE/ACM Transactions on
computational biology and bioinformatics, VOL. 6, 2009.
SR2008_rf_aZ_fb
This is as SR2008_rf_aZ above, but additionally it allows
beta to be a free parameter, and it is sampled. Samples
are written to mcmc_prams* files.
beta
This only applies to SR2008. The beta is the weight as
given in Steel and Rodrigo 2008. By default it is 1.0.
stRFCalc
There are three ways to calculate the RF distances and
likelihood, for these SR2008_rf models above --- all giving
the same answer.
1. purePython1. Slow.
2. bitarray, using the bitarray module. About twice as fast
as purePython1
3. fastReducedRF, written in C++ using boost and ublas.
About 10 times faster than purePython1, but perhaps a bit
of a bother to get going. It needs the fastReducedRF
module, included in the p4 source code.
It is under control of the argument stRFCalc, which can be one
of 'purePython1', 'bitarray', and 'fastReducedRF'. By default
it is purePython1, so you may want to at least install
bitarray.
runNum
You may want to do more than one 'run' in the same directory,
to facilitate convergence testing. The first runNum would be
0, and samples, likelihoods, and checkPoints are written to
files with that number.
sampleInterval
Interval at which the chain is sampled, including writing a tree,
and the logLike. Plan to get perhaps 1000 samples; so if you are
planning to make a run of 10000 generations then you might set
sampleInterval=10.
checkPointInterval
Interval at which checkpoints are made. If set to None (the
default) it means don't make checkpoints. My taste is to aim to
make perhaps 2 to 4 per run. So if you are planning to start out
with a run of 10000 generations, you could set
checkPointInterval=5000, which will give you 2 checkpoints. See
more about checkpointing below.
To prepare for a run, instantiate an STMcmc object, for example::
m = STMcmc(treeList, modelName='SR2008_rf_aZ_fb', stRFCalc='fastReducedRF', sampleInterval=10)
To start it running, do this::
# Tell it the number of generations to do
m.run(10000)
As it runs, it saves trees and likelihoods at sampleInterval
intervals (actually whenever the current generation number is
evenly divisible by the sampleInterval).
**CheckPoints**
Whenever the current generation number is evenly divisible by the
checkPointInterval it will write a checkPoint file. A checkPoint
file is the whole MCMC, pickled. Using a checkPoint, you can
re-start an STMcmc from the point you left off. Or, in the event
of a crash, you can restart from the latest checkPoint. But the
most useful thing about them is that you can query checkPoints to
get information about how the chain has been running, and about
convergence diagnostics.
In order to restart the MCMC from the end of a previous run::
# read the last checkPoint file
m = func.unPickleSTMcmc(0) # runNum 0
m.run(20000)
It's that easy if your previous run finished properly. However, if
your previous run has crashed and you want to restart it from a
checkPoint, then you will need to repair the sample output files
to remove samples that were taken after the last checkPoint, but
before the crash. Fix the trees, likelihoods, prams, and sims.
(You probably do not need to beware of confusing gen (eg 9999) and
gen+1 (eg 10000) issues.) When you remove trees from the tree
files be sure to leave the 'end;' at the end-- p4 needs it, and
will deal with it.
The checkPoints can help with convergence testing. To help with
that, you can use the STMcmcCheckPointReader class. It will print
out a table of average standard deviations of split supports
between 2 runs, or between 2 checkPoints from the same run. It
will print out tables of proposal acceptances to show whether they
change over the course of the MCMC.
**Making a consensus tree**
See :class:`TreePartitions`.
"""
def __init__(self, inTrees, bigT=None, modelName='SR2008_rf_aZ',
beta=1.0, spaQ=0.5, stRFCalc='purePython1',
nChains=1, runNum=0, sampleInterval=100,
checkPointInterval=None, useSplitSupport=False, verbose=True,
checkForOutputFiles=True, swapTuner=250):
import p4.func # This should not be needed, but it is. Why?
#print(p4.func)
gm = ['STMcmc.__init__()']
assert inTrees
for t in inTrees:
assert isinstance(t, Tree)
if bigT:
assert isinstance(bigT, Tree)
assert bigT.taxNames
bigT.stripBrLens()
for n in bigT.iterInternalsNoRoot():
n.name = None
goodModelNames = ['SR2008_rf_ia', 'SR2008_rf_aZ', 'SR2008_rf_aZ_fb',
'SPA', 'QPA']
if modelName not in goodModelNames:
gm.append("Arg modelName '%s' is not recognized. " % modelName)
gm.append("Good modelNames are %s" % goodModelNames)
raise P4Error(gm)
self.modelName = modelName
self.tree = None
self.stRFCalc = None
if modelName.startswith("SR2008"):
try:
fBeta = float(beta)
except ValueError:
gm.append("Arg beta (%s) should be a float" % beta)
raise P4Error(gm)
self.beta = fBeta
for t in inTrees:
if t.isFullyBifurcating():
pass
else:
gm.append("The SR2008 model wants trees that are fully bifurcating.")
raise P4Error(gm)
goodSTRFCalcNames = ['purePython1', 'bitarray', 'fastReducedRF']
if stRFCalc not in goodSTRFCalcNames:
gm.append("Arg stRFCalc '%s' is not recognized. " % modelName)
gm.append("Good stRFCalc names are %s" % goodSTRFCalcNames)
raise P4Error(gm)
self.stRFCalc = stRFCalc
try:
nChains = int(nChains)
except (ValueError, TypeError):
gm.append("nChains should be an int, 1 or more. Got %s" % nChains)
raise P4Error(gm)
if nChains < 1:
gm.append("nChains should be an int, 1 or more. Got %s" % nChains)
raise P4Error(gm)
self.nChains = nChains
self.chains = []
self.gen = -1
self.startMinusOne = -1
self.chainTemp = 1.0
self.constraints = None
self.simulate = None
# spaQ is a property. Whenever it is set, it is propagated to all the chains.
self._spaQ = None
if modelName in ['SPA', 'QPA']:
try:
self._spaQ = float(spaQ)
except ValueError:
gm.append("Arg spaQ (%s) should be a float" % spaQ)
raise P4Error(gm)
self.spaQ = self._spaQ
try:
runNum = int(runNum)
except (ValueError, TypeError):
gm.append("runNum should be an int, 0 or more. Got %s" % runNum)
raise P4Error(gm)
if runNum < 0:
gm.append("runNum should be an int, 0 or more. Got %s" % runNum)
raise P4Error(gm)
self.runNum = runNum
self._setLogger()
if checkForOutputFiles:
# Check that we are not going to over-write good stuff
ff = os.listdir(os.getcwd())
hasPickle = False
for fName in ff:
if fName.startswith("mcmc_checkPoint_%i." % self.runNum):
hasPickle = True
break
if hasPickle:
gm.append("runNum is set to %i" % self.runNum)
gm.append("There is at least one mcmc_checkPoint_%i.xxx file in this directory." % self.runNum)
gm.append("This is a new STMcmc, and I am refusing to over-write exisiting files.")
gm.append("Maybe you want to re-start from the latest mcmc_checkPoint_%i file?" % self.runNum)
gm.append("Otherwise, get rid of the existing mcmc_xxx_%i.xxx files and start again." % self.runNum)
raise P4Error(gm)
if var.strictRunNumberChecking:
# We want to start runs with number 0, so if runNum is more than
# that, check that there are other runs.
if self.runNum > 0:
for runNum2 in range(self.runNum):
hasTrees = False
for fName in ff:
if fName.startswith("mcmc_trees_%i" % runNum2):
hasTrees = True
break
if not hasTrees:
gm.append("runNum is set to %i" % self.runNum)
gm.append("runNums should go from zero up.")
gm.append("There are no mcmc_trees_%i.nex files to show that run %i has been done." % (
runNum2, runNum2))
gm.append("Set the runNum to that, first.")
gm.append("Or else turn var.strictRunNumberChecking off to prevent checking.")
raise P4Error(gm)
self.sampleInterval = sampleInterval
self.checkPointInterval = checkPointInterval
self.props = Proposals()
self.tunableProps = """SR2008beta_uniform spaQ_uniform""".split()
# maybeTunableButNotNow polytomy
self.treePartitions = None
self.likesFileName = "mcmc_likes_%i" % runNum
self.treeFileName = "mcmc_trees_%i.nex" % runNum
#self.simFileName = "mcmc_sims_%i" % runNum
self.pramsFileName = "mcmc_prams_%i" % runNum
self.writePrams = False
if self.modelName in ['SR2008_rf_aZ_fb', "SPA", "QPA"]:
self.writePrams = True
self.lastTimeCheck = None
if self.nChains > 1:
self.swapMatrix = []
for i in range(self.nChains):
self.swapMatrix.append([0] * self.nChains)
# if self.swapVector:
# self.swapTuner = STSwapTunerV(self)
# else:
# if swapTuner: # a kwarg
# myST = int(swapTuner)
# if myST >= 100:
# self.swapTuner = SwapTuner(myST)
# else:
# gm.append("The swapTuner kwarg, the sample size, should be at least 100. Got %i." % myST)
# raise P4Error(gm)
# else:
# self.swapTuner = None
else:
self.swapMatrix = None
self.swapTuner = None
self._tunings = STMcmcTunings()
self.polytomyUseResolutionClassPrior = False
self.polytomyPriorLogBigC = 0.0
self.prob = STMcmcProposalProbs()
if self.modelName in ['SPA', 'QPA']:
self.prob.polytomy = 1.0
self.prob.spr = 0.0
# Zap internal node names
# for n in aTree.root.iterInternals():
# if n.name:
# n.name = None
if not bigT:
allNames = set()
for t in inTrees:
t.unsorted_taxNames = [n.name for n in t.iterLeavesNoRoot()]
# Get the union of a set and other stuff using set.update(stuff).
allNames.update(t.unsorted_taxNames)
self.taxNames = list(allNames)
# not needed, but nice for debugging
self.taxNames.sort()
else:
for t in inTrees:
t.unsorted_taxNames = [n.name for n in t.iterLeavesNoRoot()]
self.taxNames = bigT.taxNames
# print self.taxNames
self.nTax = len(self.taxNames)
if self.modelName in ['SPA'] or self.stRFCalc == 'bitarray':
# print "self.taxNames = ", self.taxNames
for t in inTrees:
# print "-" * 50
# t.draw()
sorted_taxNames = []
t.baTaxBits = []
for tNum in range(self.nTax):
tN = self.taxNames[tNum]
if tN in t.unsorted_taxNames:
sorted_taxNames.append(tN)
t.baTaxBits.append(True)
else:
t.baTaxBits.append(False)
t.taxNames = sorted_taxNames
t.baTaxBits = bitarray.bitarray(t.baTaxBits)
t.firstTax = t.baTaxBits.index(1)
# print "intree baTaxBits is %s" % t.baTaxBits
# print "intree firstTax is %i" % t.firstTax
# Can't use Tree.makeSplitKeys(), unfortunately. So
# make split keys here. STMcmc.tBits is only used for
# the leaves, here and in
# STChain.setupBitarrayCalcs(), and there only once,
# during STChain.__init__(). So probably does not
# need to be an instance attribute. Maybe delete?
self.tBits = [False] * self.nTax
for n in t.iterPostOrder():
if n == t.root:
break
if n.isLeaf:
spot = self.taxNames.index(n.name)
self.tBits[spot] = True
n.stSplitKey = bitarray.bitarray(self.tBits)
self.tBits[spot] = False
else:
n.stSplitKey = n.leftChild.stSplitKey.copy()
p = n.leftChild.sibling
while p:
n.stSplitKey |= p.stSplitKey # "or", in-place
p = p.sibling
# print "setting node %i stSplitKey to %s" %
# (n.nodeNum, n.stSplitKey)
if self.stRFCalc == 'bitarray':
t.splSet = set()
for n in t.iterInternalsNoRoot():
# make sure splitKey[firstTax] is a '1'
if not n.stSplitKey[t.firstTax]:
n.stSplitKey.invert()
n.stSplitKey &= t.baTaxBits # 'and', in-place
# print "inverting and and-ing node %i stSplitKey
# to %s" % (n.nodeNum, n.stSplitKey)
# bytes so that I can use it as a set element
t.splSet.add(n.stSplitKey.tobytes())
if self.modelName in ['SPA']:
t.internals = []
for n in t.iterInternalsNoRoot():
# make sure splitKey[firstTax] is a '1'
if not n.stSplitKey[t.firstTax]:
n.stSplitKey.invert()
n.stSplitKey &= t.baTaxBits # 'and', in-place
# print "inverting and and-ing node %i stSplitKey
# to %s" % (n.nodeNum, n.stSplitKey)
# bytes so that I can use it as a set element
n.stSplitKeyBytes = n.stSplitKey.tobytes()
t.internals.append(n)
self.fspa = None
if self.modelName == 'SPA' and var.stmcmc_useFastSpa:
import p4.fastspa as fastspa
self.fspa = fastspa.FastSpa(useSplitSupport)
for tNum, t in enumerate(inTrees):
self.fspa.setInTr(tNum, t.nTax, self.nTax, t.baTaxBits.to01(), t.firstTax)
for n in t.internals:
if n.br and hasattr(n.br, "support"):
support = n.br.support
else:
support = -1.0
self.fspa.setInTrNo(tNum, n.stSplitKey.to01(), support)
#self.fspa.summarizeInTrs()
if self.modelName in ['QPA']:
for t in inTrees:
sorted_taxNames = []
t.taxBits = []
for tNum in range(self.nTax):
tN = self.taxNames[tNum]
if tN in t.unsorted_taxNames:
sorted_taxNames.append(tN)
t.taxBits.append(1 << tNum)
else:
t.taxBits.append(0)
t.taxNames = sorted_taxNames
# print "intree taxBits is %s" % t.taxBits
# Can't use Tree.makeSplitKeys(), unfortunately. So
# make split keys here. STMcmc.tBits is only used for
# the leaves, here and in
# STChain.setupBitarrayCalcs(), and there only once,
# during STChain.__init__(). So probably does not
# need to be an instance attribute. Maybe delete?
#self.tBits = [False] * self.nTax
for n in t.iterPostOrder():
if n == t.root:
break
if n.isLeaf:
spot = self.taxNames.index(n.name)
#self.tBits[spot] = True
n.stSplitKey = 1 << spot
#self.tBits[spot] = False
else:
n.stSplitKey = n.leftChild.stSplitKey
p = n.leftChild.sibling
while p:
n.stSplitKey |= p.stSplitKey # "or", in-place
p = p.sibling
# print "setting node %i stSplitKey to %s" % (n.nodeNum, n.stSplitKey)
# t.splSet = set()
# for n in t.iterInternalsNoRoot():
# if not n.stSplitKey[t.firstTax]: # make sure splitKey[firstTax] is a '1'
# n.stSplitKey.invert()
# n.stSplitKey &= t.baTaxBits # 'and', in-place
# #print "inverting and and-ing node %i stSplitKey to %s" % (n.nodeNum, n.stSplitKey)
# t.splSet.add(n.stSplitKey.tobytes()) # bytes so that I can
# use it as a set element
t.skk = [n.stSplitKey for n in t.iterInternalsNoRoot()]
t.qSet = set()
for sk in t.skk:
ups = [txBit for txBit in t.taxBits if (sk & txBit)]
downs = [txBit for txBit in t.taxBits if not (sk & txBit)]
for down in itertools.combinations(downs, 2):
if down[0] > down[1]:
down = (down[1], down[0])
for up in itertools.combinations(ups, 2):
if up[0] > up[1]:
up = (up[1], up[0])
if down[0] < up[0]:
t.qSet.add(down + up)
else:
t.qSet.add(up + down)
# print t.qSet
t.nQuartets = len(t.qSet)
self.trees = inTrees
if bigT:
self.tree = bigT
else:
self.tree = p4.func.randomTree(taxNames=self.taxNames, name='stTree', randomBrLens=False)
if self.stRFCalc in ['purePython1', 'fastReducedRF']:
for t in inTrees:
sorted_taxNames = []
t.taxBits = 0
for tNum in range(self.nTax):
tN = self.taxNames[tNum]
if tN in t.unsorted_taxNames:
sorted_taxNames.append(tN)
adder = 1 << tNum
t.taxBits += adder
t.taxNames = sorted_taxNames
t.allOnes = 2 ** (t.nTax) - 1
t.makeSplitKeys()
t.skSet = set([n.br.splitKey for n in t.iterInternalsNoRoot()])
if self.stRFCalc in ['purePython1', 'fastReducedRF']:
self.tree.makeSplitKeys()
self.Frrf = None
if self.stRFCalc == 'fastReducedRF':
try:
import p4.fastReducedRF
self.Frrf = p4.fastReducedRF.Frrf
# not explicitly used--but makes converters available
import pyublas
except ImportError:
gm.append("var.stRFCalc is set to 'fastReducedRF', but I could not import")
gm.append("at least one of fastReducedRF or pyublas.")
gm.append("Make sure they are installed.")
raise P4Error(gm)
if self.modelName in ['QPA']:
t = self.tree
t.taxBits = [1 << i for i in range(t.nTax)]
for n in t.iterPostOrder():
if n == t.root:
break
if n.isLeaf:
spot = self.taxNames.index(n.name)
n.stSplitKey = 1 << spot
else:
n.stSplitKey = n.leftChild.stSplitKey
p = n.leftChild.sibling
while p:
n.stSplitKey |= p.stSplitKey # "or", in-place
p = p.sibling
t.skk = [n.stSplitKey for n in t.iterInternalsNoRoot()]
t.qSet = set()
for sk in t.skk:
ups = [txBit for txBit in t.taxBits if (sk & txBit)]
downs = [txBit for txBit in t.taxBits if not (sk & txBit)]
for down in itertools.combinations(downs, 2):
assert down[0] < down[1] # probably not needed
for up in itertools.combinations(ups, 2):
assert up[0] < up[1] # probably not needed
if down[0] < up[0]:
t.qSet.add(down + up)
else:
t.qSet.add(up + down)
# print t.qSet
t.nQuartets = len(t.qSet)
self.useSplitSupport = False
if useSplitSupport:
if self.modelName.startswith("SR2008"):
gm.append(
"Arg useSplitSupport is turned on, but it is not implemented with SR2008")
raise P4Error(gm)
assert useSplitSupport in [True, 'percent']
self.useSplitSupport = True
hasSplitInfo = False
for it in self.trees:
for n in it.iterInternalsNoRoot():
if not hasattr(n.br, 'support'):
n.br.support = None
else:
assert n.br.support == None
if n.name:
flName = float(n.name)
hasSplitInfo = True
if useSplitSupport == 'percent':
flName *= 0.01
if flName < 0.0 or flName > 1.0:
gm.append("Input tree %s" %
it.writeNewick(toString=True))
gm.append(
"Got support value %s, outside of range 0 to 1" % n.name)
if flName > 1.0 and useSplitSupport == True:
gm.append(
"Maybe it is percent support? If so, set useSplitSupport to 'percent' rather than True")
raise P4Error(gm)
n.br.support = flName
#n.br.logSupport = math.log(n.br.support)
else:
#n.br.logSupport = None
pass
if not hasSplitInfo:
gm.append(
"Arg useSplitSupport is turned on, but none of the trees seem to have split info.")
raise P4Error(gm)
splash = p4.func.splash2(verbose=False)
for aLine in splash:
self.logger.info(aLine)
self.swapVector = True
if self.nChains > 1:
self.swapTuner = STSwapTunerV(self)
# Hidden experimental hacking
self.doHeatingHack = False
self.heatingHackTemperature = 5.0
#self.heatingHackProposalNames = ['nni', 'spr']
if verbose:
self.loggerPrinter.info("Initializing STMcmc")
self.loggerPrinter.info("%-16s: %s" % ('modelName', modelName))
if self.modelName.startswith("SR2008"):
self.loggerPrinter.info("%-16s: %s" % ('stRFCalc', self.stRFCalc))
if self.modelName in ["SPA", "QPA"]:
self.loggerPrinter.info("%-16s: %s" % ('useSplitSupport', self.useSplitSupport))
self.loggerPrinter.info("%-16s: %s" % ('inTrees', len(self.trees)))
self.loggerPrinter.info("%-16s: %s" % ('nTax', self.nTax))
if self.nChains == 1:
self.loggerPrinter.info("%-16s: %s" % ('mcmcmc', "off: 1 chain"))
elif self.nChains > 1:
self.loggerPrinter.info("%-16s: %s" % ('mcmcmc', "on -- %i chains" % self.nChains))
if self.swapVector:
self.loggerPrinter.info("%-16s: %s" % ('swapVector', "on"))
self.loggerPrinter.info("%-16s: %s" % ('swapTuner', "on"))
def _del_nothing(self):
    """Refuse deletion; used as the deleter of the spaQ property."""
    gm = ["Don't/Can't delete this property."]
    raise P4Error(gm)

def _get_spaQ(self):
    """Return the current spaQ, which is stored in self._spaQ."""
    return self._spaQ

def _set_spaQ(self, newVal):
    """Set spaQ, and copy the new value to the trees of every chain.

    newVal is converted with float() and stored in self._spaQ.  If
    there are chains, the value is also assigned to the spaQ of each
    chain's propTree and curTree, so that they stay in step.

    Raises:
        P4Error: if newVal cannot be converted to a float.
    """
    try:
        newVal = float(newVal)
    except (TypeError, ValueError):
        # Catch only conversion failures; a bare 'except:' would also
        # turn KeyboardInterrupt and unrelated bugs into this message.
        gm = ['This property should be set to a float.']
        raise P4Error(gm)
    self._spaQ = newVal
    if self.chains:
        for ch in self.chains:
            ch.propTree.spaQ = newVal
            ch.curTree.spaQ = newVal

spaQ = property(_get_spaQ, _set_spaQ, _del_nothing,
                "(property) The current spaQ")
[docs]
def _setLogger(self):
    """Make two loggers; one that writes to a file and to stderr, and one that writes only to a file.

    The file handler ("mcmc_log_<runNum>", appending) is put on the
    root logger by logging.basicConfig(), and both named loggers
    propagate to it.  self.loggerPrinter ('withPrint') additionally
    gets a stderr handler; self.logger ('logFileOnly') does not.

    Named loggers are process-wide, so the stderr handler is added only
    once.  Otherwise every STMcmc made in the same session would add
    another one, and each message would be printed once per instance.
    """
    # NOTE: basicConfig() does nothing if the root logger already has
    # handlers, so only the first call in a process chooses the log file.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s',
                        datefmt='[%Y-%m-%d %H:%M]',
                        filename="mcmc_log_%i" % self.runNum,
                        filemode='a')
    # Using named loggers allows me to keep them separate.
    self.loggerPrinter = logging.getLogger('withPrint')
    hasConsole = any(getattr(h, '_stmcmcConsole', False)
                     for h in self.loggerPrinter.handlers)
    if not hasConsole:
        # A handler which writes INFO messages or higher to sys.stderr,
        # with a format that is simpler for console use.
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter('%(message)s'))
        console._stmcmcConsole = True  # marks it as ours, for the check above
        self.loggerPrinter.addHandler(console)
    # This logger only logs to the file, not to stderr.
    self.logger = logging.getLogger("logFileOnly")
[docs]
def _makeProposals(self):
    """Make proposals for the STMcmc.

    Appends an STProposal to self.props.proposals for each move whose
    probability in self.prob is non-zero:

    - 'nni' and 'spr', for any model;
    - 'SR2008beta_uniform', only for the 'SR2008_rf_aZ_fb' model;
    - 'spaQ_uniform', only for the 'SPA' and 'QPA' models;
    - 'polytomy'.

    Each weight is the prob times the matching fudgeFactor entry.  Then
    self.props.proposalsDict is filled by name and
    self.props.calculateWeights() is called.

    Raises:
        P4Error: if no proposal was made.
    """
    gm = ['STMcmc._makeProposals()']

    def newProposal(name):
        # A fresh STProposal with its name already set.
        prop = STProposal(self)
        prop.name = name
        return prop

    def setStandardTuningLimits(prop):
        # Thresholds and factors shared by the tunable proposals.
        # NOTE(review): presumably acceptance-rate bounds and tuning
        # multipliers read by STProposal.tune(); confirm there.
        prop.tnAccVeryHi = 0.7
        prop.tnAccHi = 0.6
        prop.tnAccLo = 0.05
        prop.tnAccVeryLo = 0.03
        prop.tnFactorVeryHi = 1.6
        prop.tnFactorHi = 1.2
        prop.tnFactorLo = 0.8
        prop.tnFactorVeryLo = 0.7

    # Topology moves, for any model.
    if self.prob.nni:
        prop = newProposal('nni')
        prop.weight = self.prob.nni * fudgeFactor['nni']
        self.props.proposals.append(prop)
    if self.prob.spr:
        prop = newProposal('spr')
        prop.weight = self.prob.spr * fudgeFactor['spr']
        self.props.proposals.append(prop)

    # Model-parameter moves, each only for the models that have it.
    if self.modelName in ['SR2008_rf_aZ_fb'] and self.prob.SR2008beta_uniform:
        prop = newProposal('SR2008beta_uniform')
        prop.tuning = [self._tunings.default[prop.name]] * self.nChains
        prop.weight = self.prob.SR2008beta_uniform * fudgeFactor['SR2008beta_uniform']
        setStandardTuningLimits(prop)
        self.props.proposals.append(prop)
    if self.modelName in ['SPA', 'QPA'] and self.prob.spaQ_uniform:
        prop = newProposal('spaQ_uniform')
        prop.tuning = [self._tunings.default[prop.name]] * self.nChains
        prop.spaQPriorType = 'flat'
        prop.spaQExpPriorLambda = 100.
        prop.weight = self.prob.spaQ_uniform * fudgeFactor['spaQ_uniform']
        setStandardTuningLimits(prop)
        self.props.proposals.append(prop)

    if self.prob.polytomy:
        prop = newProposal('polytomy')
        prop.polytomyUseResolutionClassPrior = self.polytomyUseResolutionClassPrior
        prop.polytomyPriorLogBigC = self.polytomyPriorLogBigC
        prop.weight = self.prob.polytomy * fudgeFactor['polytomy']
        self.props.proposals.append(prop)

    if not self.props.proposals:
        gm.append("No proposals?")
        raise P4Error(gm)
    for prop in self.props.proposals:
        self.props.proposalsDict[prop.name] = prop
    self.props.calculateWeights()
[docs]
def _refreshProposalProbsAndTunings(self):
    """Adjust proposals after a restart.

    The user may have changed self.prob since the proposals were made.
    Each proposal's weight is therefore recomputed the way
    _makeProposals() does it: prob * fudgeFactor[name].  Proposals with
    no matching attribute in self.prob keep their weight.  Then
    self.propWeights, self.cumPropWeights and self.totalPropWeights are
    recomputed, and self.props.calculateWeights() is called, as
    _makeProposals() does after setting weights.

    Raises:
        P4Error: if there are no proposals, or if the weights sum to
            (nearly) zero.
    """
    gm = ['STMcmc._refreshProposalProbsAndTunings()']
    for p in self.props.proposals:
        prob = getattr(self.prob, p.name, None)
        if prob is not None:
            # Same formula as _makeProposals(), fudgeFactor included, so
            # a restart does not change the relative weights.
            p.weight = prob * fudgeFactor[p.name]
    self.propWeights = [p.weight for p in self.props.proposals]
    if not self.propWeights:
        # Check before indexing; an empty list would give an IndexError.
        gm.append("No proposals?")
        raise P4Error(gm)
    self.cumPropWeights = list(itertools.accumulate(self.propWeights))
    self.totalPropWeights = sum(self.propWeights)
    if self.totalPropWeights < 1e-9:
        gm.append("No proposal weights?")
        raise P4Error(gm)
    self.props.calculateWeights()
[docs]
def writeProposalAcceptances(self):
    """Pretty-print the proposal acceptances.

    Covers gens self.startMinusOne + 1 to self.gen, inclusive.

    For each proposal it prints the number of proposals, the percent
    accepted and the tuning, taken from index 0 (temperature 0, the cold
    chain) of the per-temperature tallies.  With more than one chain it
    also prints per-temperature tables for the 'nni' and 'spr'
    proposals, and for every proposal.

    When no gens have been done since startMinusOne (eg the tallies were
    just emptied after a checkpoint), it prints a notice and returns.
    Zero proposal counts are shown as '-' rather than divided by.
    """
    if (self.gen - self.startMinusOne) <= 0:
        print("\nSTMcmc.writeProposalAcceptances() There is no info in memory. ")
        print(" Maybe it was just emptied after writing to a checkpoint? ")
        print("If so, read the checkPoint and get the proposalAcceptances from there.")
        return

    def printAcceptPercent(nAccepts, nProps):
        # Percent accepted, or '-' when nothing was proposed (no 0/0).
        if nProps:
            print(" %5.1f " % (100.0 * float(nAccepts) / float(nProps)), end=' ')
        else:
            print(" - ", end=' ')

    def printTuning(tuning, tempNum):
        # Fewer decimals for bigger tunings, so they fit the 8-wide column.
        if tuning is None:
            print(" -", end=' ')
        elif tuning[tempNum] < 2.0:
            print(" %8.4f" % tuning[tempNum], end=' ')
        elif tuning[tempNum] < 20.0:
            print(" %8.3f" % tuning[tempNum], end=' ')
        elif tuning[tempNum] < 200.0:
            print(" %8.1f" % tuning[tempNum], end=' ')
        else:
            print(" %8.3g" % tuning[tempNum], end=' ')

    spacer = ' ' * 8
    print("\nProposal acceptances, run %i, for %i gens, from gens %i to %i, inclusive." % (
        self.runNum, (self.gen - self.startMinusOne), self.startMinusOne + 1, self.gen))
    print("%s %20s %10s %13s%8s" % (spacer, 'proposal', 'nProposals', 'acceptance(%)', 'tuning'))
    for p in self.props.proposals:
        print("%s" % spacer, end=' ')
        print("%20s" % p.name, end=' ')
        print("%10i" % p.nProposals[0], end=' ')
        printAcceptPercent(p.nAcceptances[0], p.nProposals[0])
        printTuning(p.tuning, 0)
        print()

    # Tabulate topology changes by temperature
    if self.nChains > 1:
        for propName in ['nni', 'spr']:
            p = self.props.proposalsDict.get(propName)
            if p:
                print("'%s' proposal-- topology changes by temperature" % (propName))
                print("%s tempNum nProps nAccepts percent" % spacer)
                for tNum in range(self.nChains):
                    print("%s" % spacer, end=' ')
                    print("%4i " % tNum, end=' ')
                    print("%9i" % p.nProposals[tNum], end=' ')
                    print("%8i" % p.nAcceptances[tNum], end=' ')
                    # A hot chain may never have tried this move.
                    if p.nProposals[tNum]:
                        print(" %5.1f" % (100.0 * float(p.nAcceptances[tNum]) / float(p.nProposals[tNum])))
                    else:
                        print(" %5s" % '-')

    p = self.props.proposalsDict.get('polytomy')
    if p and hasattr(p, 'nAborts'):
        print("The %s proposal had %5i aborts." % (p.name, p.nAborts[0]))

    if self.nChains > 1:
        print("\n\nAcceptances and tunings by temperature")
        print("%s %30s %5s %5s %10s %13s%10s" % (
            spacer, 'proposal', 'part', 'tempNum', 'nProposals', 'acceptance(%)', 'tuning'))
        for p in self.props.proposals:
            for tempNum in range(self.nChains):
                print("%s" % spacer, end=' ')
                print("%30s" % p.name, end=' ')
                if p.pNum != -1:
                    print(" %3i " % p.pNum, end=' ')
                else:
                    print(" - ", end=' ')
                print(" %3i " % tempNum, end=' ')
                print("%10i" % p.nProposals[tempNum], end=' ')
                printAcceptPercent(p.nAcceptances[tempNum], p.nProposals[tempNum])
                printTuning(p.tuning, tempNum)
                print()
[docs]
def writeSwapMatrix(self):
    """Print the chain-swap tallies as an nChains x nChains table.

    Covers gens self.startMinusOne + 1 to self.gen.  Rows and columns
    are temperature numbers.  Above the diagonal is the number of swaps
    proposed between two temperatures.  Below the diagonal is the
    percent of those proposals that were accepted, or '-' when none
    were proposed.
    """
    nGens = self.gen - self.startMinusOne
    print("\nChain swapping, for %i gens, from gens %i to %i, inclusive." % (
        nGens, self.startMinusOne + 1, self.gen))
    print(" Upper triangle is the number of swaps proposed between two chains.")
    print(" Lower triangle is the percent swaps accepted.")

    indexes = range(self.nChains)
    margin = " " * 11
    print(margin + "".join("%7i " % k for k in indexes))
    print(margin + "".join(" ---- " for _ in indexes))

    sm = self.swapMatrix
    for row in indexes:
        cells = []
        for col in indexes:
            if row == col:
                cell = " -"
            elif row < col:
                # upper triangle: number proposed
                cell = "%7i" % sm[row][col]
            elif sm[col][row] == 0:
                # nothing proposed, so no percent
                cell = " -"
            else:
                # lower triangle: accepted / proposed, as a percent
                cell = " %5.1f" % (100.0 * float(sm[row][col]) / float(sm[col][row]))
            cells.append(cell + " ")
        print(" " * 8 + "%2i " % row + "".join(cells))
[docs]
def _makeChainsAndProposals(self):
    """Make chains and proposals.

    Makes self.nChains STChain objects if there are no chains yet, and
    the proposals (via _makeProposals()) if there are none yet.

    Then, if there is a 'polytomy' proposal and
    self.polytomyUseResolutionClassPrior is set, stores on that proposal
    logBigT, the logs of T_{n,m} for the resolution class prior.  It is
    a list of nTax - 1 floats indexed by m, with index 0 left at 0.0.
    """
    if not self.chains:
        self.chains = []
        # chNum is used by fastspa; it is also the starting tempNum
        # (tempNums are exchanged when chains swap).
        for chNum in range(self.nChains):
            self.chains.append(STChain(self, chNum))

    if not self.props.proposals:
        self._makeProposals()

    polytomyProp = self.props.proposalsDict.get('polytomy')
    if not polytomyProp or not self.polytomyUseResolutionClassPrior:
        return
    nTax = self.tree.nTax
    bigT = p4.func.nUnrootedTreesWithMultifurcations(nTax)
    polytomyProp.logBigT = logBigT = [0.0] * (nTax - 1)
    for m in range(1, nTax - 1):
        logBigT[m] = math.log(bigT[m])
[docs]
def _setOutputTreeFile(self):
    """Setup the (output) tree file for the STMcmc.

    (Re)writes self.treeFileName with a Nexus taxa block and the start
    of a trees block.  The trees block includes a translate table that
    numbers self.tree.taxNames from 1.

    Also sets self.translationHash (taxon name -> translate number),
    which is passed to writeNewick() when sampled trees are appended.
    The trees block is left open, with no 'end;', because trees are
    appended later.
    """
    taxNames = self.tree.taxNames
    nTax = self.tree.nTax
    fixName = p4.func.nexusFixNameIfQuotesAreNeeded
    # 'with' so the file is closed even if a write fails part-way.
    with open(self.treeFileName, 'w') as treeFile:
        # Write the preamble for the trees outfile.
        treeFile.write('#nexus\n\n')
        treeFile.write('begin taxa;\n')
        treeFile.write(' dimensions ntax=%s;\n' % nTax)
        treeFile.write(' taxlabels')
        for tN in taxNames:
            treeFile.write(' %s' % fixName(tN))
        treeFile.write(';\nend;\n\n')
        treeFile.write('begin trees;\n')
        self.translationHash = {tName: i for i, tName in enumerate(taxNames, start=1)}
        treeFile.write(' translate\n')
        # All but the last entry end with a comma.
        for i in range(nTax - 1):
            treeFile.write(' %3i %s,\n' % (i + 1, fixName(taxNames[i])))
        treeFile.write(' %3i %s\n' % (nTax, fixName(taxNames[-1])))
        treeFile.write(' ;\n')
        treeFile.write(' [Tree numbers are gen+1]\n')
[docs]
def run(self, nGensToDo, verbose=True, equiProbableProposals=False, writeSamples=True):
"""Start the STMcmc running."""
gm = ['STMcmc.run()']
# Hidden experimental hack
if self.doHeatingHack:
print("Heating hack is turned on.")
assert self.nChains == 1, "MCMCMC does not work with the heating hack"
print("Heating hack temperature is %.2f" % self.heatingHackTemperature)
#print("Heating hack affects proposals %s" % self.heatingHackProposalNames)
# Keep track of the first gen of this call to run(), maybe restart
firstGen = self.gen + 1
if self.checkPointInterval:
# We want a couple of things:
# 1. The last gen should be on checkPointInterval. For
# example, if the checkPointInterval is 200, then doing
# 100 or 300 generations will not be allowed cuz the
# chain would continue past the checkPoint-- bad. Or if
# you re-start after 500 gens and change to a
# checkPointInterval of 200, then you won't be allowed to
# do 500 gens.
# if ((self.gen + 1) + nGensToDo) % self.checkPointInterval == 0:
if nGensToDo % self.checkPointInterval == 0:
pass
else:
gm.append(
"With the current settings, the last generation won't be on a checkPointInterval.")
gm.append("self.gen+1=%i, nGensToDo=%i, checkPointInterval=%i" % ((self.gen + 1),
nGensToDo, self.checkPointInterval))
raise P4Error(gm)
# 2. We also want the checkPointInterval to be evenly
# divisible by the sampleInterval.
if self.checkPointInterval % self.sampleInterval == 0:
pass
else:
gm.append(
"The checkPointInterval (%i) should be evenly divisible" % self.checkPointInterval)
gm.append("by the sampleInterval (%i)." % self.sampleInterval)
raise P4Error(gm)
if self.props.proposals:
# Its either a re-start, or it has been thru autoTune().
# I can tell the difference by self.gen, which is -1 after
# autoTune()
if self.gen == -1:
self._makeChainsAndProposals()
self._setOutputTreeFile()
# if self.simulate:
# self.writeSimFileHeader(self.tree)
# The probs and tunings may have been changed by the user.
self._refreshProposalProbsAndTunings()
# This stuff below should be the same as is done after pickling,
# see below.
self.startMinusOne = self.gen
# Start the tree partitions over.
self.treePartitions = None
# Zero the proposal counts
for p in self.props.proposals:
p.nProposals = [0] * self.nChains
p.nAcceptances = [0] * self.nChains
#p.nTopologyChangeAttempts = [0] * self.nChains
#p.nTopologyChanges = [0] * self.nChains
# Zero the swap matrix
if self.nChains > 1:
self.swapMatrix = []
for i in range(self.nChains):
self.swapMatrix.append([0] * self.nChains)
else:
self._makeChainsAndProposals()
self._setOutputTreeFile()
# if self.simulate:
# self.writeSimFileHeader(self.tree)
# The swap vector is just the diagonal of the swap matrix
if self.swapVector and self.nChains > 1:
# These are differences in temperatures between adjacent chains. The last one is not used.
self.chainTempDiffs = [self.chainTemp] * self.nChains
# These are cumulative, summed over the diffs. This needs to be done whenever the diffs change
self.chainTemps = [0.0]
for dNum in range(self.nChains - 1):
self.chainTemps.append(self.chainTempDiffs[dNum] + self.chainTemps[-1])
if verbose:
self.props.writeProposalIntendedProbs()
sys.stdout.flush()
coldChainNum = 0
# If polytomy is turned on, then it is possible to get a star
# tree, in which case local will not work. So if we have both
# polytomy and local proposals, we should also have brLen.
# if "polytomy" in self.proposalsHash and 'local' in self.proposalsHash:
# if 'brLen' not in self.proposalsHash:
# gm.append("If you have polytomy and local proposals, you should have a brLen proposal as well.")
# gm.append("It can have a low proposal probability, but it needs to be there.")
# gm.append("Turn it on by eg yourMcmc.prob.brLen = 0.001")
# raise P4Error(gm)
if self.gen > -1:
# it is a re-start, so we need to back over the "end;" in the tree
# files.
f2 = open(self.treeFileName, 'r+b')
pos = -1
while 1:
f2.seek(pos, 2)
c = f2.read(1)
if c == b';':
break
pos -= 1
# print "pos now %i" % pos
pos -= 3 # end;
f2.seek(pos, 2)
c = f2.read(4)
# print "got c = '%s'" % c
if c != b"end;":
gm.append("Stmcmc.run(). Failed to find and remove the 'end;' at the end of the tree file.")
raise P4Error(gm)
else:
f2.seek(pos, 2)
f2.truncate()
f2.close()
self.logger.info("Re-starting the ST MCMC run %i from gen=%i" % (self.runNum, self.gen))
if verbose:
print()
print("Re-starting the ST MCMC run %i from gen=%i" % (self.runNum, self.gen))
if not writeSamples:
print("Arg 'writeSamples' is off" )
print("Set to do %i more generations." % nGensToDo)
# if self.writePrams:
# if self.chains[0].curTree.model.nFreePrams == 0:
# print "There are no free prams in the model, so I am turning writePrams off."
# self.writePrams = False
sys.stdout.flush()
self.startMinusOne = self.gen
else:
self.logger.info("Starting the ST MCMC %s run %i" % ((self.constraints and "(with constraints)" or ""), self.runNum))
self.logger.info("nChains %i" % (self.nChains))
ret = self.props.proposalsDict.get('polytomy')
if ret:
message = "Doing polytomy proposal, with polytomyUseResolutionClassPrior=%s" % ret.polytomyUseResolutionClassPrior
self.loggerPrinter.info(message)
message = "polytomy: polytomyPriorLogBigC=%f" % ret.polytomyPriorLogBigC
self.loggerPrinter.info(message)
if verbose:
if self.nChains > 1:
print("Using Metropolis-coupled MCMC, with %i chains. Temperature %f." % (self.nChains, self.chainTemp))
else:
print("Not using Metropolis-coupled MCMC.")
self.loggerPrinter.info("Starting the ST MCMC %s run %i" % ((self.constraints and "(with constraints)" or ""), self.runNum))
self.loggerPrinter.info("Set to do %i generations." % nGensToDo)
if self.writePrams:
# if self.chains[0].curTree.model.nFreePrams == 0:
# print "There are no free prams in the model, so I am turning writePrams off."
# self.writePrams = False
# else:
pramsFile = open(self.pramsFileName, 'a')
if self.modelName.startswith("SR2008"):
pramsFile.write(" genPlus1 beta\n")
elif self.modelName.startswith("SPA"):
pramsFile.write(" genPlus1 spaQ\n")
elif self.modelName.startswith("QPA"):
pramsFile.write(" genPlus1 spaQ\n")
pramsFile.close()
sys.stdout.flush()
if not writeSamples:
self.logger.info("STMcmc.run() arg 'writeSamples' is off, so samples are not being written")
if equiProbableProposals:
self.logger.info("STMcmc.run() arg 'equiProbableProposals' is turned on")
if verbose:
if writeSamples:
print("Sampling every %i." % self.sampleInterval)
else:
print("Arg 'writeSamples' is off, so samples are not being written")
if equiProbableProposals:
print("Arg 'equiProbableProposals' is turned on")
if self.checkPointInterval:
print("CheckPoints written every %i." % self.checkPointInterval)
if nGensToDo <= 20000:
print("One dot is 100 generations.")
else:
print("One dot is 1000 generations.")
sys.stdout.flush()
self.treePartitions = None
realTimeStart = time.time()
self.lastTimeCheck = time.time()
abortableProposals = ['nni', 'spr', 'polytomy']
##################################################
############### Main loop ########################
##################################################
self.swapInterval = 2
for gNum in range(nGensToDo):
self.gen += 1
# Do an initial time estimate based on 100 gens
if nGensToDo > 100 and self.gen - firstGen == 100:
diff_secs = time.time() - realTimeStart
total_secs = (float(nGensToDo) / float(100)) * float(diff_secs)
deltaTime = datetime.timedelta(seconds=int(round(total_secs)))
print("Estimated completion time: %s days, %s" % (
deltaTime.days, time.strftime("%H:%M:%S", time.gmtime(deltaTime.seconds))))
# Above is a list of proposals where it is possible to abort.
# When a gen(aProposal) is made, below, aProposal.doAbort
# might be set, in which case we want to skip it for this
# gen. But we want to start each 'for chNum' loop with
# doAborts all turned off.
for chNum in range(self.nChains):
failure = True
nAttempts = 0
while failure:
# Get the next proposal
gotIt = False
safety = 0
while not gotIt:
# equiProbableProposals is True or False. Usually False.
aProposal = self.props.chooseProposal(equiProbableProposals)
if aProposal:
gotIt = True
if aProposal.name == 'nni':
# Can't do nni on a star tree.
if self.chains[chNum].curTree.nInternalNodes == 1:
#aProposal = self.props.proposalsDict['polytomy']
gotIt = False
if aProposal.doAbort:
gotIt = False
safety += 1
if safety > 1000:
gm.append(
"Could not find a proposal after %i attempts." % safety)
gm.append("Possibly a programming error.")
gm.append(
"Or possibly it is just a pathologically frustrating Mcmc.")
raise P4Error(gm)
if 0:
print("==== gNum=%i, chNum=%i, aProposal=%s" % (
gNum, chNum, aProposal.name), end=' ')
sys.stdout.flush()
# print gNum,
# success returns None
failure = self.chains[chNum].gen(aProposal)
if failure:
myWarn = "STMcmc.run() main loop. Proposal %s generated a 'failure'. Why?" % aProposal.name
self.logger.warning(myWarn)
if 0:
if failure:
print(" failure")
else:
print()
nAttempts += 1
if nAttempts > 1000:
gm.append("Was not able to do a successful generation after %i attempts." % nAttempts)
raise P4Error(gm)
# Continuous tuning. We have a tuning, and propose/accept
# tallies for each temperature, kept separately. Note that
# since chNum does not equal tempNum, the most recently
# incremented values (and the ones we want to tune now) will
# be aProposal.tnNSamples[tempNum] and
# aProposal.tnNAccepts[tempNum], where tempNum will most
# likely not be chNum. So we get the tempNum from this
# chNum, and tune it.
if aProposal.name in self.tunableProps:
tempNum = self.chains[chNum].tempNum
if aProposal.tnNSamples[tempNum] >= aProposal.tnSampleSize:
aProposal.tune(tempNum)
# print " Mcmc.run(). finished a gen on chain %i" % (chNum)
for prNm in abortableProposals:
ret = self.props.proposalsDict.get(prNm)
if ret:
ret.doAbort = False
# Do swap, if there is more than 1 chain.
if (self.gen + 1) % self.swapInterval == 0:
if self.nChains == 1:
coldChain = 0
else:
if self.swapVector:
rTempNum1 = random.randrange(self.nChains - 1)
rTempNum2 = rTempNum1 + 1
chain1 = None
chain2 = None
for ch in self.chains:
if ch.tempNum == rTempNum1:
chain1 = ch
elif ch.tempNum == rTempNum2:
chain2 = ch
assert chain1 and chain2
# Use the upper triangle of swapMatrix for nAttempts
self.swapMatrix[chain1.tempNum][chain2.tempNum] += 1
lnR = (1.0 / (1.0 + (self.chainTemps[chain1.tempNum]))
) * chain2.curTree.logLike
lnR += (1.0 / (1.0 + (self.chainTemps[chain2.tempNum]))
) * chain1.curTree.logLike
lnR -= (1.0 / (1.0 + (self.chainTemps[chain1.tempNum]))
) * chain1.curTree.logLike
lnR -= (1.0 / (1.0 + (self.chainTemps[chain2.tempNum]))
) * chain2.curTree.logLike
# # An alternative calculation
# heatBeta1 = 1.0 / (1.0 + self.chainTemps[chain1.tempNum])
# heatBeta2 = 1.0 / (1.0 + self.chainTemps[chain2.tempNum])
# likeRatio12 = (chain2.curTree.logLike - chain1.curTree.logLike) * heatBeta1
# likeRatio21 = (chain1.curTree.logLike - chain2.curTree.logLike) * heatBeta2
# lnR2 = likeRatio12 + likeRatio21
# rDiff = math.fabs(lnR - lnR2)
# if rDiff > 1e-12:
# print("bad swap rDiff %f (%g) lnR=%f, lnR2=%f" % (rDiff, rDiff, lnR, lnR2))
# lnR = lnR2
if lnR < -100.0:
r = 0.0
elif lnR >= 0.0:
r = 1.0
else:
r = math.exp(lnR)
acceptSwap = 0
if random.random() < r:
acceptSwap = 1
# self.logger.info("swap proposed gen=%i between tempNum1=%i chNum1=%i temp1=%f and tempNum2=%i chNum2=%i temp2=%f acceptSwap=%s" % (
# self.gen, rTempNum1, chain1.chNum, self.chainTemps[chain1.tempNum],
# rTempNum2, chain2.chNum, self.chainTemps[chain2.tempNum], acceptSwap))
# for continuous temperature tuning with self.swapTuner
if self.swapTuner:
# Index the nAttempts and nSwaps with the lower of the two tempNum's, which would be chain1.tempNum
self.swapTuner.nAttempts[chain1.tempNum] += 1
if acceptSwap:
self.swapTuner.nSwaps[chain1.tempNum] += 1
if self.swapTuner.nAttempts[chain1.tempNum] >= var.mcmc_swapTunerSampleSize:
self.swapTuner.tune(chain1.tempNum)
# tune() zeros nAttempts and nSwaps counters
if acceptSwap:
# Use the lower triangle of swapMatrix to keep track of
# nAccepted's
assert chain1.tempNum < chain2.tempNum
self.swapMatrix[chain2.tempNum][chain1.tempNum] += 1
# Do the swap
chain1.tempNum, chain2.tempNum = chain2.tempNum, chain1.tempNum
else: # swap matrix
# Chain swapping stuff was lifted from MrBayes. Thanks again.
chain1, chain2 = random.sample(self.chains, 2)
# Use the upper triangle of swapMatrix for nProposed's
if chain1.tempNum < chain2.tempNum:
self.swapMatrix[chain1.tempNum][chain2.tempNum] += 1
thisCh1Temp = chain1.tempNum
thisCh2Temp = chain2.tempNum
else:
self.swapMatrix[chain2.tempNum][chain1.tempNum] += 1
thisCh1Temp = chain2.tempNum
thisCh2Temp = chain1.tempNum
lnR = (1.0 / (1.0 + (self.chainTemp * chain1.tempNum))
) * chain2.curTree.logLike
lnR += (1.0 / (1.0 + (self.chainTemp * chain2.tempNum))
) * chain1.curTree.logLike
lnR -= (1.0 / (1.0 + (self.chainTemp * chain1.tempNum))
) * chain1.curTree.logLike
lnR -= (1.0 / (1.0 + (self.chainTemp * chain2.tempNum))
) * chain2.curTree.logLike
if lnR < -100.0:
r = 0.0
elif lnR >= 0.0:
r = 1.0
else:
r = math.exp(lnR)
acceptSwap = 0
if random.random() < r:
acceptSwap = 1
# for continuous temperature tuning with self.swapTuner
if self.swapTuner and thisCh1Temp == 0 and thisCh2Temp == 1:
self.swapTuner.swaps01_nAttempts += 1
if acceptSwap:
self.swapTuner.swaps01_nSwaps += 1
if self.swapTuner.swaps01_nAttempts >= self.swapTuner.sampleSize:
self.swapTuner.tune(self)
# tune() zeros nAttempts and nSwaps counters
if acceptSwap:
# Use the lower triangle of swapMatrix to keep track of
# nAccepted's
if chain1.tempNum < chain2.tempNum:
self.swapMatrix[chain2.tempNum][chain1.tempNum] += 1
else:
self.swapMatrix[chain1.tempNum][chain2.tempNum] += 1
# Do the swap
chain1.tempNum, chain2.tempNum = chain2.tempNum, chain1.tempNum
# Find the cold chain, the one where tempNum is 0
coldChainNum = -1
for i in range(len(self.chains)):
if self.chains[i].tempNum == 0:
coldChainNum = i
break
if coldChainNum == -1:
gm.append("Unable to find which chain is the cold chain. Bad.")
raise P4Error(gm)
# If it is a writeInterval, write stuff
if (self.gen + 1) % self.sampleInterval == 0:
if writeSamples:
likesFile = open(self.likesFileName, 'a')
likesFile.write(
'%11i %f\n' % (self.gen + 1, self.chains[coldChainNum].curTree.logLike))
likesFile.close()
treeFile = open(self.treeFileName, 'a')
treeFile.write(" tree t_%i = [&U] " % (self.gen + 1))
self.chains[coldChainNum].curTree.writeNewick(treeFile,
withTranslation=1,
translationHash=self.translationHash,
doMcmcCommandComments=False)
treeFile.close()
if writeSamples and self.writePrams:
pramsFile = open(self.pramsFileName, 'a')
#pramsFile.write("%12i " % (self.gen + 1))
pramsFile.write("%12i" % (self.gen + 1))
if self.modelName.startswith("SR2008"):
pramsFile.write(
" %f\n" % self.chains[coldChainNum].curTree.beta)
elif self.modelName in ["SPA", "QPA"]:
pramsFile.write(
" %f\n" % self.chains[coldChainNum].curTree.spaQ)
pramsFile.close()
# Do a simulation
if self.simulate:
# print "about to simulate..."
self.doSimulate(self.chains[coldChainNum].curTree)
# print "...finished simulate."
# Do other stuff.
if hasattr(self, 'hook'):
self.hook(self.chains[coldChainNum].curTree)
if 0 and self.constraints:
print("Mcmc x1c")
print(self.chains[0].verifyIdentityOfTwoTreesInChain())
print("b checking curTree ..")
self.chains[0].curTree.checkSplitKeys()
print("b checking propTree ...")
self.chains[0].propTree.checkSplitKeys()
print("Mcmc xxx")
# Add curTree to treePartitions
if self.treePartitions:
self.treePartitions._getSplitsFromTree(
self.chains[coldChainNum].curTree)
else:
self.treePartitions = TreePartitions(
self.chains[coldChainNum].curTree)
# After _getSplitsFromTree, need to follow, at some point,
# with _finishSplits(). Do that when it is pickled, or at the
# end of the run.
# Checking and debugging constraints
if 0 and self.constraints:
print("Mcmc x1d")
print(self.chains[coldChainNum].verifyIdentityOfTwoTreesInChain())
print("c checking curTree ...")
self.chains[coldChainNum].curTree.checkSplitKeys()
print("c checking propTree ...")
self.chains[coldChainNum].propTree.checkSplitKeys()
# print "c checking that all constraints are present"
#theSplits = [n.br.splitKey for n in self.chains[0].curTree.iterNodesNoRoot()]
# for sk in self.constraints.constraints:
# if sk not in theSplits:
# gm.append("split %i is not present in the curTree." % sk)
# raise P4Error(gm)
print("Mcmc zzz")
# Check that the curTree has all the constraints
if self.constraints:
splitsInCurTree = [
n.br.splitKey for n in self.chains[coldChainNum].curTree.iterInternalsNoRoot()]
for sk in self.constraints.constraints:
if sk not in splitsInCurTree:
gm.append("Programming error.")
gm.append(
"The current tree (the last tree sampled) does not contain constraint")
gm.append(
"%s" % p4.func.getSplitStringFromKey(sk, self.tree.nTax))
raise P4Error(gm)
# If it is a checkPointInterval, pickle
if self.checkPointInterval and (self.gen + 1) % self.checkPointInterval == 0:
self.checkPoint()
# The stuff below needs to be done in a re-start as well.
# See above "if self.proposals:"
self.startMinusOne = self.gen
# Start the tree partitions over.
self.treePartitions = None
# Zero the proposal counts
for p in self.props.proposals:
p.nProposals = [0] * self.nChains
p.nAcceptances = [0] * self.nChains
#p.nTopologyChangeAttempts = [0] * self.nChains
#p.nTopologyChanges = [0] * self.nChains
p.nAborts = [0] * self.nChains
# Zero the swap matrix
if self.nChains > 1:
self.swapMatrix = []
for i in range(self.nChains):
self.swapMatrix.append([0] * self.nChains)
# Reassuring pips ...
# We want to skip the first gen of every call to run()
if firstGen != self.gen:
if nGensToDo <= 20000:
if (self.gen - firstGen) % 1000 == 0:
if verbose:
deltaTime = self._doTimeCheck(
nGensToDo, firstGen, 1000)
if deltaTime.days:
timeString = "%s days, %s" % (
deltaTime.days, time.strftime("%H:%M:%S", time.gmtime(deltaTime.seconds)))
else:
timeString = time.strftime(
"%H:%M:%S", time.gmtime(deltaTime.seconds))
print("%10i - %s" % (self.gen, timeString))
else:
sys.stdout.write(".")
sys.stdout.flush()
elif (self.gen - firstGen) % 100 == 0:
sys.stdout.write(".")
sys.stdout.flush()
else:
if (self.gen - firstGen) % 50000 == 0:
if verbose:
deltaTime = self._doTimeCheck(
nGensToDo, firstGen, 50000)
if deltaTime.days:
timeString = "%s days, %s" % (
deltaTime.days, time.strftime("%H:%M:%S", time.gmtime(deltaTime.seconds)))
else:
timeString = time.strftime(
"%H:%M:%S", time.gmtime(deltaTime.seconds))
print("%10i - %s" % (self.gen, timeString))
else:
sys.stdout.write(".")
sys.stdout.flush()
elif (self.gen - firstGen) % 1000 == 0:
sys.stdout.write(".")
sys.stdout.flush()
# Gens finished. Clean up.
print()
if verbose:
print("Finished %s generations." % nGensToDo)
treeFile = open(self.treeFileName, 'a')
treeFile.write('end;\n\n')
treeFile.close()
[docs]
def _doTimeCheck(self, nGensToDo, firstGen, genInterval):
    """Estimate the wall-clock time left in the current call to run().

    The seconds elapsed since self.lastTimeCheck are taken to have been
    spent on genInterval generations; that rate is extrapolated over the
    generations of this call still to do.  firstGen is the first
    generation of this call to Mcmc.run() -- without it the estimate
    would be wrong after a restart.

    Side effect: self.lastTimeCheck is reset to the current time.

    Returns:
        datetime.timedelta rounded to whole seconds.
    """
    now = time.time()
    elapsedSecs = float(now - self.lastTimeCheck)
    gensLeft = float(nGensToDo - (self.gen - firstGen))
    secsLeft = elapsedSecs * (gensLeft / float(genInterval))
    self.lastTimeCheck = now
    return datetime.timedelta(seconds=int(round(secsLeft)))
[docs]
def checkPoint(self):
    """Pickle a deep copy of this STMcmc object to a checkpoint file.

    The file is named "mcmc_checkPoint_<runNum>.<gen+1>", is written to
    the current directory, and uses pickle.HIGHEST_PROTOCOL.

    Some attributes cannot be deep-copied or pickled:
    - the per-chain frrf and bigTr, when stRFCalc is 'fastReducedRF';
    - the logger and loggerPrinter;
    - fspa, for "SPA" models when var.stmcmc_useFastSpa is set.

    They are set to None on self only while copy.deepcopy() runs. They
    are restored in a finally clause, so a failure while copying
    propagates to the caller but leaves this live object usable. Pickling
    works on the independent copy, after self has been restored. If
    pickling fails, the partly written file is removed before the
    exception is re-raised.

    The copy's treePartitions, if it has one, gets _finishSplits() called
    on it before pickling. self.treePartitions is not touched.
    """
    # Maybe we should not save the inTrees? -- would make it more
    # lightweight.
    if 0:
        for chNum in range(self.nChains):
            ch = self.chains[chNum]
            print("chain %i ==================" % chNum)
            ch.curTree.summarizeModelComponentsNNodes()

    useFrrf = self.stRFCalc == 'fastReducedRF'
    useFastSpa = self.modelName.startswith("SPA") and var.stmcmc_useFastSpa

    # Remember everything that has to be hidden from deepcopy, before
    # touching anything, so the finally clause can always put it back.
    savedFrrfs = []
    savedBigTrs = []
    if useFrrf:
        for chNum in range(self.nChains):
            ch = self.chains[chNum]
            savedFrrfs.append(ch.frrf)
            savedBigTrs.append(ch.bigTr)
    savedLogger = self.logger
    savedLoggerPrinter = self.loggerPrinter
    savedFspa = self.fspa if useFastSpa else None

    try:
        # The Frrf object does not pickle.
        if useFrrf:
            for chNum in range(self.nChains):
                ch = self.chains[chNum]
                ch.frrf = None
                ch.bigTr = None
        # The logger does not pickle.
        self.logger = None
        self.loggerPrinter = None
        # The FastSpa stuff does not pickle.
        if useFastSpa:
            self.fspa = None
        # _io.TextIOWrapper objects (as returned by open(fileName)) cannot
        # be copied ("serialized"), even if they are closed.  That is not
        # an issue at the moment, as no open file is kept as an attribute.
        theCopy = copy.deepcopy(self)
    finally:
        self.logger = savedLogger
        self.loggerPrinter = savedLoggerPrinter
        if useFrrf:
            for chNum in range(self.nChains):
                ch = self.chains[chNum]
                ch.frrf = savedFrrfs[chNum]
                ch.bigTr = savedBigTrs[chNum]
        if useFastSpa:
            self.fspa = savedFspa

    # run() sets treePartitions to None after each checkpoint, so there
    # may be nothing to finish yet.
    if theCopy.treePartitions is not None:
        theCopy.treePartitions._finishSplits()

    fName = "mcmc_checkPoint_%i.%i" % (self.runNum, self.gen + 1)
    f = open(fName, 'wb')
    try:
        with f:
            pickle.dump(theCopy, f, pickle.HIGHEST_PROTOCOL)
    except Exception:
        # Do not leave a truncated checkpoint behind; reading it back
        # with STMcmcCheckPointReader would fail.
        os.remove(fName)
        raise
[docs]
def writeProposalProbs(self):
    """(Another) Pretty-print the proposal probabilities.

    For each proposal, prints one row with:
    - the number of times it was proposed;
    - that number as a percentage of all proposals;
    - its acceptance percentage;
    - its tuning (the first element of p.tuning, or '-' if tuning is None);
    - its name.

    Only element 0 of each proposal's nProposals and nAcceptances lists
    is used. The counts cover gens self.startMinusOne + 1 to self.gen.
    They are zeroed at each checkpoint, so just after one there may be
    nothing to show.

    See also STMcmc.writeProposalAcceptances().

    Raises:
        P4Error: if the proposed percentages do not sum to 100.
    """
    proposals = self.props.proposals
    if not proposals:
        print("STMcmc.writeProposalProbs(). No proposals (yet?).")
        return
    nAttained = [p.nProposals[0] for p in proposals]
    nAccepted = [p.nAcceptances[0] for p in proposals]
    sumAttained = float(sum(nAttained))  # should be zero or nGen
    if not sumAttained:
        print("STMcmc.writeProposalProbs(). No proposals have been made.")
        print("Possibly, due to it being a checkPoint interval, nProposals have all been set to zero.")
        return
    probAttained = [100.0 * float(n) / sumAttained for n in nAttained]
    # Sanity check.  fsum gives a correctly rounded total, so ordinary
    # rounding in the individual percentages stays well inside 1e-13.
    # (Previously the parenthesis was misplaced -- fabs(bool) -- so a
    # total below 100 was never caught.)
    totalProb = math.fsum(probAttained)
    if math.fabs(totalProb - 100.0) > 1e-13:
        raise P4Error(
            "bad sum of attained proposal probs. %s" % totalProb)

    print("\nProposal probabilities (%)")
    print("For %i gens, from gens %i to %i, inclusive." % (
        (self.gen - self.startMinusOne), self.startMinusOne + 1, self.gen))
    print("%2s %11s %11s %11s %10s %23s" % ('', 'nProposals', 'proposed(%)',
                                            'accepted(%)', 'tuning', 'proposal'))
    for i, p in enumerate(proposals):
        print("%2i" % i, end=' ')
        print(" %7i " % p.nProposals[0], end=' ')
        print(" %5.1f " % probAttained[i], end=' ')
        if nAttained[i]:
            print(" %5.1f " % (100.0 * float(nAccepted[i]) / float(nAttained[i])), end=' ')
        else:
            print(" - ", end=' ')
        if p.tuning is None:
            print(" - ", end=' ')
        elif p.tuning[0] < 2.0:
            print(" %8.4f" % p.tuning[0], end=' ')
        elif p.tuning[0] < 20.0:
            print(" %8.3f" % p.tuning[0], end=' ')
        elif p.tuning[0] < 200.0:
            print(" %8.1f" % p.tuning[0], end=' ')
        else:
            print(" %8.3g" % p.tuning[0], end=' ')
        print("%23s " % p.name, end=' ')
        print()
# def writeProposalIntendedProbs(self):
# """Tabulate the intended proposal probabilities.
# """
# nProposals = len(self.proposals)
# if not nProposals:
# print("STMcmc.writeProposalIntendedProbs(). No proposals (yet?).")
# return
# intended = self.propWeights[:]
# for i in range(len(intended)):
# intended[i] /= self.totalPropWeights
# if math.fabs(sum(intended) - 1.0 > 1e-14):
# raise P4Error(
# "bad sum of intended proposal probs. %s" % sum(intended))
# spacer = ' ' * 4
# print("\nIntended proposal probabilities (%)")
# # print "There are %i proposals" % len(self.proposals)
# print("%2s %11s %23s %5s %5s" % ('', 'intended(%)', 'proposal', 'part', 'num'))
# for i in range(len(self.proposals)):
# print("%2i" % i, end=' ')
# p = self.proposals[i]
# print(" %6.2f " % (100. * intended[i]), end=' ')
# print(" %20s" % p.name, end=' ')
# if p.pNum != -1:
# print(" %3i " % p.pNum, end=' ')
# else:
# print(" - ", end=' ')
# if p.mtNum != -1:
# print(" %3i " % p.mtNum, end=' ')
# else:
# print(" - ", end=' ')
# print()
class STMcmcCheckPointReader(object):
    """Read in and display mcmc_checkPoint files.

    There are three ways to choose what is read:

    - To read in a specific checkpoint file, specify the file name by
      fName=whatever
    - To read in the most recent (by os.path.getmtime()) checkpoint
      file, say last=True
    - If you specify neither of the above, it will read in all the
      checkPoint files that it finds, sorted on gen and runNum.

    Where it looks is determined by theGlob, which by default is '*',
    ie everything in the current directory.  Only files whose basename
    starts with "mcmc_checkPoint" are used.  If you want to look
    somewhere else, you can specify eg

        theGlob='SomeWhereElse/*'

    or, if it is unambiguous, just

        theGlob='S*/*'

    So you might say

        cpr = STMcmcCheckPointReader(theGlob='*_0.*')

    to get all the checkpoints from the first run, run 0.  Then you
    can tell the cpr object to do various things, eg

        cpr.writeProposalAcceptances()

    But perhaps the most powerful thing about it is that it allows
    easy access to the checkpointed STMcmc objects, in the list mm.
    Eg to get the first one, ask for

        m = cpr.mm[0]

    and m is an STMcmc object, complete with all its records of
    proposals and acceptances and so on, and its TreePartitions object.

    Checkpoints are read with pickle, which can execute arbitrary code
    while loading.  Only read checkpoint files from a trusted source.
    """

    def __init__(self, fName=None, theGlob='*', last=False, verbose=True):
        self.mm = []
        if not fName:
            fList = [fn for fn in glob.glob(theGlob) if
                     os.path.basename(fn).startswith("mcmc_checkPoint")]
            if not fList:
                raise P4Error("No checkpoints found in this directory.")
            if last:
                # The most recently modified file.  On ties, max() keeps
                # the first one in fList.
                mostRecentFileName = max(fList, key=os.path.getmtime)
                self.mm.append(self._loadCheckPoint(mostRecentFileName))
            else:
                # get all the files
                for fn in fList:
                    self.mm.append(self._loadCheckPoint(fn))
                self.mm = p4.func.sortListOfObjectsOn2Attributes(
                    self.mm, "gen", 'runNum')
        else:
            # get the file by name
            self.mm.append(self._loadCheckPoint(fName))
        if verbose:
            self.dump()

    @staticmethod
    def _loadCheckPoint(fName):
        """Unpickle and return the object stored in checkpoint file fName."""
        # NOTE: pickle.load can run arbitrary code; see the class docstring.
        with open(fName, 'rb') as f:
            return pickle.load(f)

    def dump(self):
        """Print a table of index, runNum and gen+1 for each checkpoint in mm."""
        print("STMcmcCheckPoints (%i checkPoints read)" % len(self.mm))
        print("%12s %12s %12s %12s" % (" ", "index", "run", "gen+1"))
        print("%12s %12s %12s %12s" % (" ", "-----", "---", "-----"))
        for i, m in enumerate(self.mm):
            print("%12s %12s %12s %12s" % (" ", i, m.runNum, m.gen + 1))

    def compareSplits(self, mNum1, mNum2, verbose=True, minimumProportion=0.1):
        """Do the TreePartitions.compareSplits() method between two checkpoints.

        Args:
            mNum1, mNum2 (int): indices to STMcmc checkpoints in self.mm
            verbose (bool): print a header and the digested results
            minimumProportion (float): passed to TreePartitions.compareSplits()

        Returns:
            a tuple (asdoss, maxDiff, meanDiff), ie the average standard
            deviation of split supports, and the maximum and mean of the
            differences in split supports.  All three are None when no
            split reaches minimumProportion.
        """
        # Should we be only looking at splits within the 95% ci of the topologies?
        tp1 = self.mm[mNum1].treePartitions
        tp2 = self.mm[mNum2].treePartitions
        if verbose:
            print("\nSTMcmcCheckPointReader.compareSplits(%i,%i)" % (mNum1, mNum2))
            print("%12s %12s %12s %12s %12s" % ("mNum", "runNum", "start", "gen+1", "nTrees"))
            for i in range(5):
                print(" ---------", end=' ')
            print()
            for mNum in [mNum1, mNum2]:
                m = self.mm[mNum]
                print(" %10i " % mNum, end=' ')
                print(" %10i " % m.runNum, end=' ')
                print(" %10i " % (m.startMinusOne + 1), end=' ')
                print(" %10i " % (m.gen + 1), end=' ')
                print(" %10i " % m.treePartitions.nTrees)

        noSplits = (None, None, None)
        # compareSplitsBetweenTwoTreePartitions() returns a bare None
        # when there are no splits; unpacking that directly would raise.
        ret = self.compareSplitsBetweenTwoTreePartitions(
            tp1, tp2, minimumProportion, verbose=verbose)
        asdos, maxDiff, meanDiff = ret if ret is not None else noSplits
        ret2 = self.compareSplitsBetweenTwoTreePartitions(
            tp2, tp1, minimumProportion, verbose=verbose)
        asdos2 = ret2[0] if ret2 is not None else None

        if (asdos is None) != (asdos2 is None) or (
                asdos is not None and math.fabs(asdos - asdos2) > 0.000001):
            print("Reciprocal assdos differs: %s %s" % (asdos, asdos2))
        if asdos is None and verbose:
            print("No splits > %s" % minimumProportion)
        return asdos, maxDiff, meanDiff

    @staticmethod
    def compareSplitsBetweenTwoTreePartitions(tp1, tp2, minimumProportion, verbose=False):
        """Returns a tuple of asdoss, maximum of the differences and mean of the differences.

        This calls the method TreePartitions.compareSplits(), and digests
        the results returned from that.

        Args:
            tp1, tp2 (TreePartition): TreePartition objects
            minimumProportion (float): passed to TreePartitions.compareSplits()
            verbose (bool): print the digested results

        Returns:
            (asdoss, maxOfDiffs, meanOfDiffs), or None if
            TreePartitions.compareSplits() returned nothing.
        """
        ret = tp1.compareSplits(tp2, minimumProportion=minimumProportion)
        # ret is a list of 3-item lists:
        #   1. The split key
        #   2. The split string
        #   3. A list of the 2 supports
        if not ret:
            return None
        nSplits = len(ret)
        sumOfStdDevs = 0.0
        diffs = []
        for item in ret:
            supports = item[2]
            sumOfStdDevs += math.sqrt(p4.func.variance(supports))
            diffs.append(math.fabs(supports[0] - supports[1]))
        asdoss = sumOfStdDevs / nSplits
        maxOfDiffs = max(diffs)
        meanOfDiffs = sum(diffs) / nSplits
        if verbose:
            print(" nSplits=%i, average of std devs of split supports %.4f " % (nSplits, asdoss))
            print(" max of differences %f, mean of differences %f" % (maxOfDiffs, meanOfDiffs))
        return (asdoss, maxOfDiffs, meanOfDiffs)

    def compareSplitsAll(self, precision=3, linewidth=120, minimumProportion=0.1):
        """Do the compareSplits() method between all pairs of checkpoints.

        Output is verbose.  Shows
        - average standard deviation of split frequencies (or supports), like MrBayes
        - maximum difference between split supports from each pair of checkpoints, like PhyloBayes

        A pair with no split above minimumProportion is recorded as 0.0
        in both matrices.  Fewer than 2 checkpoints gives just a message.

        Args:
            precision, linewidth (int): numpy print options used for the matrices
            minimumProportion (float): passed on to compareSplits()

        Returns:
            None
        """
        nM = len(self.mm)
        nItems = (nM * (nM - 1)) // 2
        if not nItems:
            print("STMcmcCheckPointReader.compareSplitsAll(). Need at least 2 checkPoints, have %i." % nM)
            return
        asdosses = np.zeros((nM, nM), dtype=np.float64)
        vect = np.zeros(nItems, dtype=np.float64)
        maxDiffs = np.zeros((nM, nM), dtype=np.float64)
        vCounter = 0
        for mNum1 in range(1, nM):
            for mNum2 in range(mNum1):
                thisAsdoss, thisMaxDiff, thisMeanDiff = self.compareSplits(
                    mNum1, mNum2, verbose=False, minimumProportion=minimumProportion)
                if thisAsdoss is None:
                    thisAsdoss = 0.0
                if thisMaxDiff is None:
                    thisMaxDiff = 0.0
                asdosses[mNum1][mNum2] = thisAsdoss
                asdosses[mNum2][mNum1] = thisAsdoss
                vect[vCounter] = thisAsdoss
                vCounter += 1
                maxDiffs[mNum1][mNum2] = thisMaxDiff
                maxDiffs[mNum2][mNum1] = thisMaxDiff

        # Save current numpy printoptions, and restore them even on error.
        curr = np.get_printoptions()
        np.set_printoptions(precision=precision, linewidth=linewidth)
        try:
            print("Pairwise asdoss values ---")
            print(asdosses)
            print()
            print("For the %i values in one triangle," % nItems)
            print("max = ", vect.max())
            print("min = ", vect.min())
            print("mean = ", vect.mean())
            print("var = ", vect.var())
            print()
            print("Pairwise maximum differences in split supports between the two runs ---")
            print(maxDiffs)
        finally:
            np.set_printoptions(
                precision=curr['precision'], linewidth=curr['linewidth'])

    def writeProposalAcceptances(self):
        """Call writeProposalAcceptances() on every checkpoint in mm."""
        for m in self.mm:
            m.writeProposalAcceptances()

    def writeSwapMatrices(self):
        """Call writeSwapMatrix() on every checkpoint in mm that has more than one chain."""
        for m in self.mm:
            if m.nChains > 1:
                m.writeSwapMatrix()

    def writeProposalProbs(self):
        """Call writeProposalProbs() on every checkpoint in mm."""
        for m in self.mm:
            m.writeProposalProbs()
class QpaML(object):
    """Maximum-likelihood helper for the 'QPA' supertree model.

    A one-chain STMcmc is built only to borrow its likelihood code.
    Install a supertree with setSuperTree(), then evaluate calcP() (the
    negative log likelihood at a given spaQ) or call optimizeQ() to
    minimise it over spaQ with scipy's Nelder-Mead.
    """

    def __init__(self, inTrees, bigT):
        """Duplicate inTrees and bigT and build the STMcmc and STChain.

        Raises P4Error if bigT has no taxNames.
        """
        assert inTrees
        ttDupes = [tree.dupe() for tree in inTrees]
        assert bigT
        bigTDupe = bigT.dupe()
        if not bigTDupe.taxNames:
            raise P4Error('The bigT needs taxNames')
        stm = STMcmc(ttDupes, bigT=bigTDupe, modelName='QPA',
                     beta=1.0, spaQ=0.5, stRFCalc='purePython1',
                     nChains=1, runNum=0, sampleInterval=100,
                     checkPointInterval=None, useSplitSupport=False,
                     verbose=False, checkForOutputFiles=False)
        self.ch = STChain(stm, 0)

    def setSuperTree(self, st):
        """Install a copy of st, with its quartets, as the chain's propTree.

        The copy gets the chain's taxNames if it has none.  If it does
        have them, they must equal the chain's.  These are then set on it:
        - taxBits: one bit per taxon;
        - stSplitKey: on every non-root node;
        - skk: the split keys of the internal non-root nodes;
        - qSet: the 4-tuples of taxon bits implied by those splits;
        - nQuartets: the size of qSet.
        """
        assert self.ch.propTree.taxNames
        st = st.dupe()
        if not st.taxNames:
            st.taxNames = self.ch.propTree.taxNames
        else:
            assert st.taxNames == self.ch.propTree.taxNames
        st.setPreAndPostOrder()
        st.taxBits = [1 << i for i in range(st.nTax)]

        # Post-order visits children before parents, so each internal
        # node can OR together its children's keys.
        for node in st.iterPostOrder():
            if node == st.root:
                break
            if node.isLeaf:
                node.stSplitKey = 1 << st.taxNames.index(node.name)
                continue
            kid = node.leftChild
            key = kid.stSplitKey
            kid = kid.sibling
            while kid:
                key |= kid.stSplitKey
                kid = kid.sibling
            node.stSplitKey = key

        st.skk = [node.stSplitKey for node in st.iterInternalsNoRoot()]
        st.qSet = set()
        for sk in st.skk:
            ups = [bit for bit in st.taxBits if sk & bit]
            downs = [bit for bit in st.taxBits if not sk & bit]
            upPairs = list(itertools.combinations(ups, 2))
            for down in itertools.combinations(downs, 2):
                assert down[0] < down[1]  # probably not needed
                for up in upPairs:
                    assert up[0] < up[1]  # probably not needed
                    # Canonical order: the pair holding the smaller bit first.
                    st.qSet.add(down + up if down[0] < up[0] else up + down)
        st.nQuartets = len(st.qSet)
        self.ch.propTree = st

    def calcP(self, Q):
        """Return -logLike at spaQ = Q, or 10000000. if Q is outside (0, 1)."""
        if Q >= 1. or Q <= 0.:
            return 10000000.
        self.ch.propTree.spaQ = Q
        self.ch.getTreeLogLike_qpa_slow()
        return -self.ch.propTree.logLike

    def optimizeQ(self, x0=0.3):
        """Minimise calcP from x0 with Nelder-Mead; return (res.x, res.fun)."""
        result = minimize(self.calcP, x0, method='Nelder-Mead')
        return (result.x, result.fun)
class SpaML(object):
    """Maximum-likelihood helper for the 'SPA' supertree model, using STMcmc.

    Install a supertree with setSuperTree(), then evaluate calcP() (the
    negative log likelihood at a given spaQ, via the bitarray code) or
    call optimizeQ() to minimise it with scipy's Nelder-Mead.
    """

    def __init__(self, inTrees, bigT):
        """Duplicate inTrees and bigT and build the STMcmc and STChain.

        Raises P4Error if bigT has no taxNames.
        """
        assert inTrees
        ttDupes = [tree.dupe() for tree in inTrees]
        assert bigT
        bigTDupe = bigT.dupe()
        if not bigTDupe.taxNames:
            raise P4Error('The bigT needs taxNames')
        stm = STMcmc(ttDupes, bigT=bigTDupe, modelName='SPA',
                     beta=1.0, spaQ=0.5, stRFCalc='purePython1',
                     nChains=1, runNum=0, sampleInterval=100,
                     checkPointInterval=None, useSplitSupport=False,
                     verbose=False, checkForOutputFiles=False)
        self.ch = STChain(stm, 0)

    def setSuperTree(self, st):
        """Install a copy of st as the chain's propTree and set up bitarrays.

        The copy gets the chain's taxNames if it has none.  If it does
        have them, they must equal the chain's.
        """
        assert self.ch.propTree.taxNames
        st = st.dupe()
        if not st.taxNames:
            st.taxNames = self.ch.propTree.taxNames
        else:
            assert st.taxNames == self.ch.propTree.taxNames
        st.setPreAndPostOrder()
        self.ch.propTree = st
        self.ch.setupBitarrayCalcs()

    def calcP(self, Q):
        """Return -logLike at spaQ = Q, or 10000000. if Q is outside (0, 1).

        A scalar Q is wrapped in a one-element numpy array before it is
        stored as spaQ.  An ndarray Q is stored as is.
        """
        if Q >= 1. or Q <= 0.:
            return 10000000.
        qArray = Q if isinstance(Q, np.ndarray) else np.array([Q])
        self.ch.propTree.spaQ = qArray
        self.ch.getTreeLogLike_spa_bitarray()
        return -self.ch.propTree.logLike

    def optimizeQ(self, x0=0.3):
        """Minimise calcP from x0 with Nelder-Mead; return (res.x, res.fun)."""
        result = minimize(self.calcP, x0, method='Nelder-Mead')
        return (result.x, result.fun)
class SR2008ML(object):
    """Maximum-likelihood helper for model SR2008_rf_aZ_fb, using STMcmc.

    Install a supertree with setSuperTree(), then evaluate calcP() (the
    negative log likelihood at a given beta, computed in pure Python) or
    call optimizeBeta() to minimise it with scipy's Nelder-Mead.
    """

    def __init__(self, inTrees, bigT):
        """Duplicate inTrees and bigT and build the STMcmc and STChain.

        Raises P4Error if bigT has no taxNames.
        """
        assert inTrees
        ttDupes = [tree.dupe() for tree in inTrees]
        assert bigT
        bigTDupe = bigT.dupe()
        if not bigTDupe.taxNames:
            raise P4Error('SR2008ML: The bigT needs taxNames')
        stm = STMcmc(ttDupes, bigT=bigTDupe, modelName='SR2008_rf_aZ_fb',
                     beta=1.0, spaQ=0.5,
                     stRFCalc='purePython1',  # alternatively 'bitarray'
                     nChains=1, runNum=0, sampleInterval=100,
                     checkPointInterval=None, useSplitSupport=False,
                     verbose=False, checkForOutputFiles=False)
        self.ch = STChain(stm, 0)

    def setSuperTree(self, st):
        """Install a copy of st as the chain's propTree.

        The copy gets the chain's taxNames if it has none.  If it does
        have them, they must equal the chain's.
        """
        assert self.ch.propTree.taxNames
        st = st.dupe()
        if not st.taxNames:
            st.taxNames = self.ch.propTree.taxNames
        else:
            assert st.taxNames == self.ch.propTree.taxNames
        st.setPreAndPostOrder()
        self.ch.propTree = st

    def calcP(self, beta):
        """Return -logLike at this beta, or 10000000. at or beyond the bounds.

        The bounds are 1e-10 and 1e+10.  They are two separate tests, so a
        NaN beta is not rejected and is passed on to the likelihood.
        """
        lowest, highest = 1.e-10, 1.e+10
        if beta >= highest or beta <= lowest:
            return 10000000.
        self.ch.propTree.beta = beta
        self.ch.getTreeLogLike_ppy1()  # pure python
        return -self.ch.propTree.logLike

    def optimizeBeta(self, x0=1.0):
        """Minimise calcP from x0 with Nelder-Mead; return (res.x, res.fun)."""
        result = minimize(self.calcP, x0, method='Nelder-Mead')
        return (result.x, result.fun)