from p4.tree import Tree
from p4.node import Node
from p4.func import read
from p4.var import var
from p4.p4exceptions import P4Error
import sys
import csv
import operator
import time
multiProcessing = False
try:
from multiprocessing import Process as Process
from multiprocessing import Queue, cpu_count
cpu_count = cpu_count()
multiProcessing = True
except ImportError:
from threading import Thread as Process
from Queue import Queue
def printVerticalNumbers(noOfNumbers):
if noOfNumbers > 99999:
print('Ops, to many taxa in the tree, more than 99999 of them')
return
Tens = Hundreds = Thousands = Tenthousands = False
base = ''
if noOfNumbers >= 10:
Tens = True
tens = ''
tensIndex = 0
if noOfNumbers >= 100:
Hundreds = True
hundreds = ''
hundredsIndex = 0
if noOfNumbers >= 1000:
Thousands = True
thousands = ''
thousandsIndex = 0
if noOfNumbers >= 10000:
Tenthousands = True
tenthousands = ''
tenthousandsIndex = 0
for i in range(1, noOfNumbers + 1):
base += str(i % 10)
if Tens:
if i < 10:
tens += ' '
else:
if i % 10 == 0:
tensIndex += 1
if tensIndex % 10 == 0:
tensIndex = 0
tens += str(tensIndex)
if Hundreds:
if i < 100:
hundreds += ' '
else:
if i % 100 == 0:
hundredsIndex += 1
if hundredsIndex % 10 == 0:
hundredsIndex = 0
hundreds += str(hundredsIndex)
if Thousands:
if i < 1000:
thousands += ' '
else:
if i % 1000 == 0:
thousandsIndex += 1
if thousandsIndex % 10 == 0:
thousandsIndex = 0
thousands += str(thousandsIndex)
if Tenthousands:
if i < 10000:
tenthousands += ' '
else:
if i % 10000 == 0:
tenthousandsIndex += 1
if tenthousandsIndex % 10 == 0:
tenthousandsIndex = 0
tenthousands += str(tenthousandsIndex)
if Tenthousands:
print(tenthousands)
if Thousands:
print(thousands)
if Hundreds:
print(hundreds)
if Tens:
print(tens)
print(base)
class ConcurrentIntersections(Process):
def __init__(self, bitKeys, minimumProportion, minNoOfTaxa, weights, listOfSplits, rooted, queue, verbose=1):
Process.__init__(self)
self.weights = weights
self.minProp = minimumProportion
self.minNoOfTaxa = minNoOfTaxa
self.bitkeys = bitKeys
self.bits = len(self.bitkeys)
self.dict = {}
# print 'Thread len(): ', len(listOfSplits)
self.listOfSplits = listOfSplits
self.rooted = rooted
self.noTrees = float(len(listOfSplits))
self.cutoff = sum(self.weights) * self.minProp
self.weightedNoTrees = sum(self.weights)
self.weight = 0
self.queue = queue
self.verbose = verbose
def run(self):
if self.verbose:
sampler = int(float(len(self.listOfSplits)) / 20)
if sampler < 1:
sampler = 1
sys.stdout.write('0% ')
sys.stdout.flush()
weightSoFar = 0.0
for q in range(len(self.listOfSplits)):
completeList = self.filteredList(
self.weightedNoTrees - weightSoFar)
weightSoFar += self.weights[q]
dict = {}
for s in self.listOfSplits[q]:
for i in completeList:
t = i[0] & s[0], (i[0] ^ s[0]) | i[1] | s[1]
dict[t] = i[2]
if not self.rooted:
t = (i[0] ^ s[0]) & i[
0], (i[0] ^ (~s[0] & ((self.bitkeys[-1] << 1) - 1))) | i[0] | s[0]
dict[t] = i[2]
for i, j in dict.items():
self.addAndUpdate(long(i[0]), long(i[1]), self.weights[q], j)
for s in self.listOfSplits[q]:
self.add(s[0], s[1], self.weights[q])
if self.verbose:
if (q + 1) % sampler == 0:
sys.stdout.write('.')
sys.stdout.flush()
if self.verbose:
sys.stdout.write(' 100%\n')
sys.stdout.flush()
self.filteredList(self.weightedNoTrees - weightSoFar)
self.weight = weightSoFar
self.queue.put((self.dict, self.weight))
def filteredList(self, remaining):
list = []
newDict = {}
for key, value in self.dict.items():
dict = {}
for k, v in value.items():
if remaining + v >= self.cutoff and self.enoughTaxa(k):
list.append([key, k, v])
dict[k] = v
if len(dict.values()) > 0:
newDict[key] = dict
self.dict = newDict
return list
def add(self, index1, index2, weight):
if not self._isTrivial(index1, index2):
if index1 in self.dict:
if index2 in self.dict[index1]:
# self.dict[index1][index2] = weight + self.dict[index1][index2]
return
else:
self.dict[index1][index2] = weight
return
else:
self.dict[index1] = {}
self.dict[index1][index2] = weight
return
def addAndUpdate(self, index1, index2, weight1, weight2):
if not self._isTrivial(index1, index2):
if index1 in self.dict:
if index2 in self.dict[index1]:
self.dict[index1][index2] = weight1 + \
self.dict[index1][index2]
return
else:
self.dict[index1][index2] = weight1 + weight2
return
else:
self.dict[index1] = {}
self.dict[index1][index2] = weight1 + weight2
return
def _isTrivial(self, index1, index2):
if index1 == 0:
return True
if self._popcount(index1) <= 1:
return True
if self._popcount(index1 | index2) == len(self.bitkeys):
return True
if index1 | index2 == 0:
return True
return False
def enoughTaxa(self, countie):
# print 'self.minNoOfTaxa: ', self.minNoOfTaxa
count = 0
for bk in self.bitkeys:
if not bk & countie:
count += 1
if count == self.minNoOfTaxa:
# print 'True'
return True
# print 'False'
return False
def _popcount(self, countie):
count = 0
for bk in self.bitkeys:
if countie < bk:
return count
if bk & countie:
count += 1
return count
def getWeight(self):
return self.weight
def getDict(self):
return self.dict
def getResults(self):
return self.results
class ConcurrentCombineIntersectionDicts(Process):
def __init__(self, bitKeys, minimumProportion, dict1, dict2, rooted, weights, weight, queue):
Process.__init__(self)
self.weights = weights
self.weight = weight
self.minProp = minimumProportion
self.bitkeys = bitKeys
self.dict = {}
self.dict1 = dict1
self.dict2 = dict2
self.rooted = rooted
self.cutoff = round(sum(self.weights) * self.minProp)
self.weightedNoTrees = sum(self.weights)
self.queue = queue
def run(self):
# print 'Started thread: ', self.getName()
list = []
for key, value in self.dict2.items():
for k, v in value.items():
list.append([key, k, v])
self.dict = self.dict1
# list = self.combine()
sampler = float(len(list)) / 10
weightSoFar = self.weight
dict = {}
for q in range(len(list)):
completeList = self.filteredList(
self.weightedNoTrees - weightSoFar)
for i in completeList:
t = i[0] & list[q][0], (i[0] ^ list[q][0]) | i[1] | list[q][1]
if t in dict:
if dict[t] < i[2]:
dict[t] = i[2]
else:
dict[t] = i[2]
if not self.rooted:
t = (i[0] ^ list[q][0]) & i[
0], (i[0] ^ (~list[q][0] & ((self.bitkeys[-1] << 1) - 1))) | i[0] | list[q][0]
if t in dict:
if dict[t] < i[2]:
dict[t] = i[2]
else:
dict[t] = i[2]
if q + 1 % sampler == 0:
print('Thread %s processed %s combinations' % (self.getName(), q + 1))
sys.stdout.flush()
for i, j in dict.items():
self.addAndUpdate(long(i[0]), long(i[1]), j, list[q][2])
for q in range(len(list)):
self.add(list[q][0], list[q][1], list[q][2])
self.printDict(self.dict)
self.queue.put(self.dict)
#======================================================================
# completeList = self.filteredList(self.weightedNoTrees - weightSoFar)
#======================================================================
def add(self, index1, index2, weight):
if not self._isTrivial(index1, index2):
if index1 in self.dict:
if index2 in self.dict[index1]:
# self.dict[index1][index2] = weight + self.dict[index1][index2]
return
else:
self.dict[index1][index2] = weight
return
else:
self.dict[index1] = {}
self.dict[index1][index2] = weight
return
def addAndUpdate(self, index1, index2, weight1, weight2):
if not self._isTrivial(index1, index2):
if index1 in self.dict:
if index2 in self.dict[index1]:
self.dict[index1][index2] = weight2 + \
self.dict[index1][index2]
return
else:
self.dict[index1][index2] = weight1 + weight2
return
else:
self.dict[index1] = {}
self.dict[index1][index2] = weight1 + weight2
return
def _isTrivial(self, index1, index2):
if index1 == 0:
return True
if self._popcount(index1) <= 1:
return True
if self._popcount(index1 | index2) == len(self.bitkeys):
return True
if index1 | index2 == 0:
return True
return False
def _popcount(self, countie):
count = 0
for bk in self.bitkeys:
if countie < bk:
return count
if bk & countie:
count += 1
return count
def filteredList(self, remaining):
list = []
newDict = {}
for key, value in self.dict.items():
dict = {}
for k, v in value.items():
# if remaining + v >= self.cutoff:
list.append([key, k, v])
dict[k] = v
# else:
# print 'Cut: ', self.cutoff
# print 'Rem + v: ', remaining +v
# print 'DOOOOOOOGH'
if len(dict.values()) > 0:
newDict[key] = dict
self.dict = newDict
return list
def printDict(self, dict):
for k, v in dict.items():
for key, value in v.items():
print('%s, %s, %s ' % (k, key, value))
def combine(self):
self.dict = self.dict1
list = []
for k, v in self.dict2.items():
for k2, v2 in v.items():
list.append([k, k2, v2])
return list
# list = []
# for k1, v1 in self.dict1.items():
# if k1 in self.dict2:
# dict = {}
# for k2, v2 in v1.items():
# if k2 in self.dict2[k1]:
# dict[k2] = self.dict2[k1][k2] + v2
# else:
# list.append([k1,k2,v2])
# if len(dict.keys()) > 0:
# self.dict[k1] = dict
# else:
# for k2, v2 in v1.items():
# list.append([k1,k2,v2])
# print list
# for k1, v1 in self.dict2.items():
# if k1 in self.dict1:
# for k2, v2 in v1.items():
# if not k2 in self.dict1[k1]:
# list.append([k1,k2,v2])
# else:
# for k2, v2 in v1.items():
# list.append([k1,k2,v2])
# print list
# return list
def getCombinedDict(self):
return self.dict
class SimpleIntersections(object):
def __init__(self, bitKeys, minimumProportion, minNoOfTaxa, weights, listOfSplits, rooted):
self.weights = weights
self.minProp = minimumProportion
self.minNoOfTaxa = minNoOfTaxa
self.bitkeys = bitKeys
self.bits = len(self.bitkeys)
self.dict = {}
self.inputSplits = listOfSplits
self.splits = []
self.rooted = rooted
self.cutoff = sum(self.weights) * self.minProp
self.weightedNoTrees = sum(self.weights)
self.allOnes = 2 ** (len(self.bitkeys)) - 1
def createIntersections(self):
sampler = int(float(len(self.listOfSplits)) / 20)
if sampler < 1:
sampler = 1
weightSoFar = 0.0
sys.stdout.write('0% ')
sys.stdout.flush()
for inputList in self.inputSplits:
for split in inputList:
self.splits.append()
for q in range(len(self.listOfSplits)):
completeList = self.filteredList(
self.weightedNoTrees - weightSoFar)
weightSoFar += self.weights[q]
dict = {}
for s in self.listOfSplits[q]:
for i in completeList:
t = i[0] & s[0], (i[0] ^ s[0]) | i[1] | s[1]
dict[t] = i[2]
if not self.rooted:
t = (i[0] ^ s[0]) & i[
0], (i[0] ^ (~s[0] & ((self.bitkeys[-1] << 1) - 1))) | i[0] | s[0]
dict[t] = i[2]
for i, j in dict.items():
self.addAndUpdate(long(i[0]), long(i[1]), self.weights[q], j)
for s in self.listOfSplits[q]:
self.add(s[0], s[1], self.weights[q])
if (q + 1) % sampler == 0:
sys.stdout.write('.')
sys.stdout.flush()
sys.stdout.write(' 100%\n')
sys.stdout.flush()
self.filteredList(self.weightedNoTrees - weightSoFar)
self.weight = weightSoFar
self.queue.put((self.dict, self.weight))
class BuildIntersections(object):
def __init__(self, bitKeys, minimumProportion, minNoOfTaxa, weights, listOfSplits, rooted, sortByNoSplits, verbose=1):
self.verbose = verbose
self.sortByNoSplits = sortByNoSplits
self.weights = weights
self.minProp = minimumProportion
self.minNoOfTaxa = minNoOfTaxa
self.bitkeys = bitKeys
self.bits = len(self.bitkeys)
self.dict = {}
self.listOfSplits = listOfSplits
self.rooted = rooted
self.cutoff = sum(self.weights) * self.minProp
self.weightedNoTrees = sum(self.weights)
self.allOnes = 2 ** (len(self.bitkeys)) - 1
# print 'self.weightedNoTrees:', self.weightedNoTrees
def buildIntersections(self, verbose=1):
if multiProcessing and False:
print('Starting dual cpu computations')
queue = Queue()
# print 'CPUS: ', cpu_count
threads = []
# cpus = cpu_count
# divider = len(self.listOfSplits)/cpus*4
# print 'Len(): ', len(self.listOfSplits)
# print 'Divider: ', divider
# for i in range(cpus):
# start = i*divider
# stop = (i+1)*divider
# print 'List of splits %s:%s ' % (start, stop)
# print 'Len(): ', len(self.listOfSplits[start:stop])
# thread = ConcurrentIntersections(self.bitkeys, self.minProp, self.minNoOfTaxa, self.weights[start:stop], self.listOfSplits[start:stop], self.rooted, queue)
# thread.start()
# threads.append(thread)
# results = []
# start = 0
# stop = 25
# thread = ConcurrentIntersections(self.bitkeys, self.minProp, self.weights[start:stop], self.listOfSplits[start:stop], self.rooted, queue)
# thread.start()
# threads.append(thread)
# results.append(queue.get())
# thread.join()
# print 'List of splits %s:%s ' % (start,stop)
# print 'Len(): ', len(self.listOfSplits[start:stop])
# print self.listOfSplits[start:stop]
#
# start = 25
# stop = 50
# thread = ConcurrentIntersections(self.bitkeys, self.minProp, self.weights[start:stop], self.listOfSplits[start:stop], self.rooted, queue)
# thread.start()
# threads.append(thread)
# results.append(queue.get())
# thread.join()
# print 'List of splits %s:%s ' % (start, stop)
# print 'Len(): ', len(self.listOfSplits[start:stop])
# print self.listOfSplits[start:stop]
#
# results = []
# for i in range(cpus):
# results.append(queue.get())
# for t in threads:
# t.join()
thread1 = ConcurrentIntersections(self.bitkeys, self.minProp, self.minNoOfTaxa, self.weights, self.listOfSplits[
len(self.listOfSplits) / 2:], self.rooted, queue)
thread2 = ConcurrentIntersections(self.bitkeys, self.minProp, self.minNoOfTaxa, self.weights, self.listOfSplits[
:len(self.listOfSplits) / 2], self.rooted, queue)
# thread1.setName('Worker 1')
# thread2.setName('Worker 2')
results = []
thread1.start()
results.append(queue.get())
thread1.join()
dict1 = results[0][0]
print(dict1)
weight1 = results[0][1]
print('weight1: ', weight1)
self.dict = dict1
thread2.start()
results.append(queue.get())
thread2.join()
dict2 = results[0][0]
print(dict2)
weight2 = results[1][1]
print('weight2: ', weight2)
self.dict = dict2
#
combine = ConcurrentCombineIntersectionDicts(
self.bitkeys, self.minProp, dict1, dict2, self.rooted, self.weights, weight1 + weight2, queue)
#
# combine.setName('Combinator')
#
combine.start()
#
combine.join()
#
# self.dict = combine.getCombinedDict()
self.dict = queue.get()
# self.dict = dict2
print('self.dict: ', self.dict)
else:
# print 'Starting single cpu computations'
queue = Queue()
threads = []
thread = ConcurrentIntersections(
self.bitkeys, self.minProp, self.minNoOfTaxa, self.weights, self.listOfSplits, self.rooted, queue, verbose=0)
thread.start()
threads.append(thread)
results = []
results.append(queue.get())
dict1 = results[0][0]
weight1 = results[0][1]
self.dict = dict1
self.finalList = self.informativeList()
return self.finalList
def buildExtendedProfile(self, generations=2):
alreadyAdded = {}
newlyAdded = {}
completeList = []
for list in self.finalList:
for i in list:
if i[2] == 1.0 and self._addToDict(alreadyAdded, i[0], i[1]):
completeList.append(i)
for counter in range(generations):
list = []
for i in range(0, len(completeList) - 1):
for j in range(i + 1, len(completeList)):
extend = self._extend(completeList[i], completeList[j])
if len(extend) > 0:
if self._addToDict(alreadyAdded, extend[0][0], extend[0][1]):
self._addToDict(
newlyAdded, extend[0][0], extend[0][1])
list.append(extend[0])
if self._addToDict(alreadyAdded, extend[1][0], extend[1][1]):
self._addToDict(
newlyAdded, extend[1][0], extend[1][1])
list.append(extend[1])
completeList.extend(list)
list = []
for key in newlyAdded.keys():
list.extend(self._getInformativeList(newlyAdded[key], key))
list = self._sortByLeafSet(list)
tempList = []
for l in list:
tempList.extend(self._buildUnconflictingLists(l))
return tempList
def _addToDict(self, dict, i, j):
if i in dict:
if j not in dict[i]:
dict[i][j] = self.weightedNoTrees
return True
else:
return False
else:
dict[i] = {}
dict[i][j] = self.weightedNoTrees
return True
def _extend(self, i, j):
if i[0] & j[0] == 0:
if i[0] & j[1] == 0:
if i[1] & j[0] == 0:
return [[i[0], i[1] | j[1], self.weightedNoTrees], [j[0], i[1] | j[1], self.weightedNoTrees]]
if i[0] | j[0] == i[0] or i[0] | j[0] == j[0]:
if i[0] & j[1] == 0 and j[0] & i[1] == 0:
return [[i[0], i[1] | j[1], self.weightedNoTrees], [j[0], i[1] | j[1], self.weightedNoTrees]]
return []
def popcountLimit(self, n, limit):
count = 0
for bk in self.bitkeys:
if bk & n:
count += 1
if n >= limit:
return True
return False
def isCompatible(self, i, j, excluded):
def isPowerOf2(number):
return (not(number & (number - 1)))
if not self.popcountLimit(i & j, 2) or not self.popcountLimit(i ^ j, 2):
return True
else:
ones = i & j
one1 = 0
if ones > 0:
if not isPowerOf2(ones):
one1 = 2
else:
one1 = 1
# print 'one1: ', one1
zeroOnes = (i ^ j) & j
zero1 = 0
if zeroOnes > 0:
if not isPowerOf2(zeroOnes):
zero1 = 2
else:
zero1 = 1
# print 'zero1: ', zero1
oneZeroes = (i ^ j) & i
one0 = 0
if oneZeroes > 0:
if not isPowerOf2(oneZeroes):
one0 = 2
else:
one0 = 1
# print 'one0: ', one0
zeroZeroes = ((j ^ self.allOnes) & (i ^ self.allOnes)) ^ excluded
zero0 = 0
if zeroZeroes > 0:
if not isPowerOf2(zeroZeroes):
zero0 = 2
else:
zero0 = 1
# print 'zero0: ', zero0
if one1 and zero1 and one0 and zero0:
return False
return True
def quickList(self, list, bitkeys):
# print 'Initial list'
# for i in list:
# i.printIntersection(bitkeys)
# print
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].frequency < list[j].frequency:
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].frequency == list[j].frequency:
pop1 = self._pop(list[i].right, bitkeys)
pop2 = self._pop(list[j].right, bitkeys)
if pop1 > pop2:
temp = list[i]
list[i] = list[j]
list[j] = temp
elif pop1 == pop2:
if list[i].firstHit(bitkeys) > list[j].firstHit(bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].firstHit(bitkeys) == list[j].firstHit(bitkeys):
if self._pop(list[i].left, bitkeys) < self._pop(list[j].left, bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
# print 'Sorted'
# for i in list:
# i.printIntersection(bitkeys)
nonRedundantList = []
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].left | list[j].right == list[j].left | list[j].right:
if list[i].right | list[j].right == list[j].right:
list[j].hits = 0
if self.verbose:
print('Splits')
for intersection in list:
if intersection.hits > 0:
nonRedundantList.append(intersection)
if self.verbose:
intersection.printIntersection(
bitkeys, self.weightedNoTrees)
list = nonRedundantList
self.fullList = []
for i in nonRedundantList:
copy = i.copy()
copy.frequency = copy.frequency / self.weightedNoTrees
self.fullList.append(copy)
lists = []
while len(list) > 0 and list[0].frequency >= self.cutoff:
split = list.pop(0)
exclude = split.right
compatibleList = [split]
for i in range(len(list)):
added = False
wannabe = list.pop(0)
if wannabe.right | exclude == exclude:
temp = wannabe.copy()
if wannabe.right != exclude:
added = True
list.append(temp)
wannabe.right = wannabe.right | exclude
wannabe.left = wannabe.left ^ (wannabe.left & exclude)
compatible = True
for j in range(len(compatibleList)):
if compatibleList[j].left == wannabe.left:
compatible = False
break
if not self.isCompatible(compatibleList[j].left, wannabe.left, exclude):
compatible = False
break
if compatible:
if self._pop(wannabe.left, bitkeys) > 1:
compatibleList.append(wannabe)
else:
if not added:
list.append(temp)
else:
if not added:
list.append(temp)
else:
list.append(wannabe)
if len(compatibleList) > 0:
lists.append(compatibleList)
# print 'Lists'
# for list in lists:
# for i in list:
# i.printIntersection(bitkeys)
# print
# print 'len(lists): ', len(lists)
for index in range(1, len(lists)):
excluded = lists[index][0].right
for i in lists[0]:
j = i.copy()
j.right = excluded
j.left = j.left ^ (j.left & excluded)
if j.left == 0 or j.popcount(bitkeys) <= 1:
pass
else:
append = True
for k in lists[index]:
if k.left == j.left:
append = False
if k.frequency < j.frequency:
k.frequency = j.frequency
break
if append:
lists[index].append(j)
if self.sortByNoSplits:
for i in range(0, len(lists) - 1):
for j in range(i, len(lists)):
if len(lists[i]) > len(lists[j]):
temp = lists[i]
lists[i] = lists[j]
lists[j] = temp
elif len(lists[i]) == len(lists[j]):
if lists[i][0].popcountExcluded(bitkeys) > lists[j][0].popcountExcluded(bitkeys):
temp = lists[i]
lists[i] = lists[j]
lists[j] = temp
# print 'Partitioned splits'
for list in lists:
for i in list:
i.setExcluded(bitkeys)
i.frequency = i.frequency / self.weightedNoTrees
# i.printIntersection(bitkeys)
# print ''
return lists
def informativeList(self):
self._remove()
list = []
for key, value in self.dict.items():
for k, v in value.items():
if v >= self.cutoff:
list.append(Intersection(key, k, v))
# for i in list:
# i.printIntersection(self.bitkeys)
# print ''
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].frequency < list[j].frequency:
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].frequency == list[j].frequency:
if self._popcount(list[i].right) > self._popcount(list[j].right):
temp = list[i]
list[i] = list[j]
list[j] = temp
elif self._popcount(list[i].right) == self._popcount(list[j].right):
if list[i].firstHit(self.bitkeys) > list[j].firstHit(self.bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].firstHit(self.bitkeys) == list[j].firstHit(self.bitkeys):
if self._popcount(list[i].left) < self._popcount(list[j].left):
temp = list[i]
list[i] = list[j]
list[j] = temp
# print 'Sorted'
# for i in list:
# i.printIntersection(self.bitkeys)
nonRedundantList = []
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].left | list[j].right == list[j].left | list[j].right:
if list[i].right | list[j].right == list[j].right:
list[j].hits = 0
if self.verbose:
print('Splits')
for intersection in list:
if intersection.hits > 0:
nonRedundantList.append(intersection)
if self.verbose:
intersection.printIntersection(
self.bitkeys, self.weightedNoTrees)
list = nonRedundantList
self.fullList = []
for i in nonRedundantList:
copy = i.copy()
copy.frequency = copy.frequency / self.weightedNoTrees
self.fullList.append(copy)
lists = []
while len(list) > 0 and list[0].frequency >= self.cutoff:
split = list.pop(0)
exclude = split.right
compatibleList = [split]
for i in range(len(list)):
added = False
wannabe = list.pop(0)
if wannabe.right | exclude == exclude:
temp = wannabe.copy()
if wannabe.right != exclude:
added = True
list.append(temp)
wannabe.right = wannabe.right | exclude
wannabe.left = wannabe.left ^ (wannabe.left & exclude)
compatible = True
for j in range(len(compatibleList)):
if compatibleList[j].left == wannabe.left:
compatible = False
break
if not self.isCompatible(compatibleList[j].left, wannabe.left, exclude):
compatible = False
break
if compatible:
if self._popcount(wannabe.left) > 1:
compatibleList.append(wannabe)
else:
if not added:
list.append(temp)
else:
if not added:
list.append(temp)
else:
list.append(wannabe)
if len(compatibleList) > 0:
lists.append(compatibleList)
# print 'len(lists): ', len(lists)
for index in range(1, len(lists)):
excluded = lists[index][0].right
for i in lists[0]:
j = i.copy()
j.right = excluded
j.left = j.left ^ (j.left & excluded)
if j.left == 0 or j.popcount(self.bitkeys) <= 1:
pass
else:
append = True
for k in lists[index]:
if k.left == j.left:
append = False
if k.frequency < j.frequency:
k.frequency = j.frequency
break
if append:
lists[index].append(j)
if self.sortByNoSplits:
for i in range(0, len(lists) - 1):
for j in range(i, len(lists)):
if len(lists[i]) > len(lists[j]):
temp = lists[i]
lists[i] = lists[j]
lists[j] = temp
elif len(lists[i]) == len(lists[j]):
if lists[i][0].popcountExcluded(self.bitkeys) > lists[j][0].popcountExcluded(self.bitkeys):
temp = lists[i]
lists[i] = lists[j]
lists[j] = temp
# print 'Partitioned splits'
for list in lists:
for i in list:
i.setExcluded(self.bitkeys)
i.frequency = i.frequency / self.weightedNoTrees
# i.printIntersection(self.bitkeys)
# print ''
return lists
def _informativeList(self):
self._remove()
list = []
for key, value in self.dict.items():
for k, v in value.items():
list.append(Intersection(key, k, v))
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].frequency < list[j].frequency:
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].frequency == list[j].frequency and self._popcount(list[i].right) < self._popcount(list[j].right):
temp = list[i]
list[i] = list[j]
list[j] = temp
# for i in list:
# i.printIntersection(self.bitkeys)
# print 'Consolidate within: '
list = self._consolidateWithinFrequency(list)
# for i in list:
# i.printIntersection(self.bitkeys)
list = self._consolidateBetweenFrequencies(list)
# print 'Consolidate between:'
# for i in list:
# i.printIntersection(self.bitkeys)
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].frequency == list[j].frequency:
if list[i].popcountExcluded(self.bitkeys) > list[j].popcountExcluded(self.bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].firstHit(self.bitkeys) > list[j].firstHit(self.bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].firstHit(self.bitkeys) == list[j].firstHit(self.bitkeys):
if list[i].popcount(self.bitkeys) > list[j].popcount(self.bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
print('Splits')
printVerticalNumbers(len(self.bitkeys))
for i in list:
i.printIntersection(self.bitkeys)
smallestList, remainder = self._smallestList(list)
# for i in remainder:
# i.printIntersection(self.bitkeys)
lists = []
lists.append(smallestList)
oldList = smallestList
added = {}
for intersection in remainder:
if intersection.right not in added:
temp = []
addable = self._addable2List(intersection, oldList)
if addable == 2:
temp = self._add2List(intersection, oldList)
for i in remainder:
addableWithRestraint = self._addable2ListWithRestraint(
i, temp)
if addableWithRestraint == 2:
added[i.right] = 1
self._add2ListWithRestraint(i, temp)
# temp.append(i)
elif addableWithRestraint == 1:
added[i.right] = 1
self._replaceWithRestraint(i, temp)
elif addable == 1:
temp = self._replace(intersection, oldList)
for i in remainder:
addableWithRestraint = self._addable2ListWithRestraint(
i, temp)
if addableWithRestraint == 2:
self._add2ListWithRestraint(i, temp)
# temp.append(i)
added[i.right] = 1
elif addableWithRestraint == 1:
self._replaceWithRestraint(i, temp)
added[i.right] = 1
if len(temp) > 0:
lists.append(temp)
for i in range(0, len(lists) - 1):
for j in range(i + 1, len(lists)):
if lists[i][0].popcountExcluded(self.bitkeys) > lists[j][0].popcountExcluded(self.bitkeys):
temp = lists[i]
lists[i] = lists[j]
lists[j] = temp
for list in lists:
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i].popcount(self.bitkeys) > list[j].popcount(self.bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
elif list[i].popcount(self.bitkeys) == list[j].popcount(self.bitkeys):
if list[i].firstHit(self.bitkeys) > list[j].firstHit(self.bitkeys):
temp = list[i]
list[i] = list[j]
list[j] = temp
for i in list:
i.setExcluded(self.bitkeys)
i.frequency = i.frequency / self.weightedNoTrees
for list in lists:
for i in list:
i.printIntersection(self.bitkeys)
print('')
return lists
def _replaceWithRestraint(self, intersection, intersections):
counter = 0
for i in intersections:
if intersection.left == i.left:
# print 'replaced %s with %s' % (i.frequency,
# intersection.frequency)
i.frequency = intersection.frequency
counter += 1
if counter > 1:
intersection.printIntersection(self.bitkeys)
def _replace(self, intersection, intersections):
copy = intersection.copy()
jointExcluded = intersection.right | intersections[0].right
copy.left = (copy.left & jointExcluded) ^ copy.left
copy.right = jointExcluded
first = True
dict = {}
for i in intersections:
c = i.copy()
c.left = (c.left & jointExcluded) ^ c.left
if self._popcount(c.left) > 1:
c.right = jointExcluded
if copy.left == c.left:
if copy.frequency > c.frequency:
c.frequency = copy.frequency
# if first:
# first = False
if c.left in dict:
if c.frequency > dict[c.left].frequency:
dict[c.left] = c
else:
dict[c.left] = c
else:
if c.left in dict:
if c.frequency > dict[c.left].frequency:
dict[c.left] = c
else:
dict[c.left] = c
return dict.values()
def _add2ListWithRestraint(self, intersection, intersections):
for i in intersections:
if intersection.left == i.left:
print('Adding identical intersection to list')
intersections.append(intersection)
def _add2List(self, intersection, intersections):
jointExcluded = intersection.right | intersections[0].right
copy = intersection.copy()
copy.left = (copy.left & jointExcluded) ^ copy.left
copy.right = jointExcluded
add = True
dict = {}
for i in intersections:
temp = i.copy()
temp.left = (i.left & jointExcluded) ^ i.left
temp.right = jointExcluded
if self._popcount(temp.left) > 1:
if temp.left in dict:
if dict[temp.left].frequency < temp.frequency:
if temp.left == copy.left and temp.frequency >= copy.frequency:
add = False
# list.append(temp)
dict[temp.left] = temp
else:
if temp.left == copy.left and temp.frequency >= copy.frequency:
add = False
# list.append(temp)
dict[temp.left] = temp
list = dict.values()
if add:
list.append(copy)
return list
def _addable2ListWithRestraint(self, proposed, intersections):
if not proposed.right == intersections[0].right:
return 0
higher = False
for intersection in intersections:
# temp = proposed.right & intersection.left
# if proposed.left == intersection.left ^ temp:
if proposed.left == intersection.left:
if proposed.frequency <= intersection.frequency:
return 0
else:
higher = True
if higher:
return 1
return 2
def _addable2List(self, proposed, intersections):
higher = False
for intersection in intersections:
temp = proposed.right & intersection.left
if proposed.left == intersection.left ^ temp:
if proposed.frequency <= intersection.frequency:
return 0
else:
higher = True
if higher:
return 1
return 2
def _consolidateBetweenFrequencies(self, intersections):
list = []
for index in range(len(intersections) - 1, -1, -1):
if index == 0:
list.append(intersections[index])
list.reverse()
return list
# intersections[index].printIntersection(self.bitkeys)
pop = self._popcount(intersections[index].right)
left = intersections[index].left
frequency = intersections[index].frequency
replace = False
for jndex in range(index - 1, -1, -1):
# intersections[jndex].printIntersection(self.bitkeys)
if intersections[jndex].left == left:
# print 'true'
if frequency < intersections[jndex].frequency:
# print 'True'
if pop > self._popcount(intersections[jndex].right):
# print 'TRUE'
replace = True
# print ''
if not replace:
list.append(intersections[index])
return list
def _consolidateWithinFrequency(self, intersections):
list = []
for index in range(len(intersections)):
if index == len(intersections) - 1:
list.append(intersections[index])
return list
pop = self._popcount(intersections[index].right)
if pop > 0:
replace = False
makeup = intersections[index].left | intersections[index].right
hits = intersections[index].frequency
# print '&', intersections[index].left | intersections[index].right
for jndex in range(index + 1, len(intersections)):
# print '.', intersections[jndex].left
if intersections[jndex].left == makeup:
# print 'true'
if intersections[jndex].frequency == hits:
replace = True
if not replace:
list.append(intersections[index])
else:
list.append(intersections[index])
return list
def _smallestList(self, list):
smallest = sys.maxint
smallestList = []
remainder = []
for intersection in list:
pop = self._popcount(intersection.right)
if pop < smallest:
smallest = pop
remainder.extend(smallestList)
smallestList = [intersection]
elif pop == smallest:
smallestList.append(intersection)
else:
remainder.append(intersection)
return smallestList, remainder
def _remove(self):
dict = {}
for key, value in self.dict.items():
temp = {}
for v in value.values():
temp[v] = 1
tempDict = {}
for k in temp.keys():
list = []
smallest = sys.maxint
for ke, v in value.items():
if v == k:
list.append((ke, v))
pop = self._popcount(ke)
if pop < smallest:
smallest = pop
for tuple in list:
pop = self._popcount(tuple[0])
if pop == smallest:
tempDict[tuple[0]] = tuple[1]
dict[key] = tempDict
self.dict = dict
def removeRedundant(self):
leafsets = {}
for d1 in self.dict.keys():
for k, v in self.dict[d1].items():
leafset = d1 | k
if leafset in leafsets:
if v in leafsets[leafset]:
leafsets[leafset][v].append((d1, k))
else:
leafsets[leafset][v] = [(d1, k)]
else:
leafsets[leafset] = {}
leafsets[leafset][v] = [(d1, k)]
dict = {}
for k, v in leafsets.items():
for hits, list in v.items():
biggestList = []
biggest = 0
for i in range(len(list)):
pop = self._popcount(list[i][0])
if pop > biggest:
biggestList = [list[i]]
biggest = pop
elif pop == biggest:
biggestList.append(list[i])
for i in biggestList:
if i[0] in dict:
if i[1] not in dict[i[0]]:
dict[i[0]][i[1]] = hits
else:
dict[i[0]] = {}
dict[i[0]][i[1]] = hits
self.dict = dict
def _getInformativeList(self, dict, left):
freqDict = {}
for v in dict.keys():
if dict[v] >= self.cutoff:
if dict[v] in freqDict:
freqDict[dict[v]].append(v)
else:
freqDict[dict[v]] = [v]
informativeList = []
for v in freqDict.keys():
for i in self._getMostInformative(freqDict[v]):
# print 'v:', v
# print 'self.weightedNoTrees: ', self.weightedNoTrees
informativeList.append([left, i, v / self.weightedNoTrees])
return informativeList
def _getMostInformative(self, list):
smallest = self._popcount(list[0])
index = [list[0]]
for i in range(1, len(list)):
pop = self._popcount(list[i])
if pop == smallest:
index.append(list[i])
elif pop < smallest:
smallest = pop
index = [list[i]]
return index
def _pop(self, countie, bitkeys):
count = 0
for bk in bitkeys:
if countie < bk:
return count
if bk & countie:
count += 1
return count
def _popcount(self, countie):
count = 0
for bk in self.bitkeys:
if countie < bk:
return count
if bk & countie:
count += 1
return count
def _sortByLeafSet(self, list):
leafDict = {}
for i in list:
if i[1] in leafDict:
leafDict[i[1]].append(i)
else:
leafDict[i[1]] = [i]
list = []
for v in leafDict.values():
list.append(v)
return list
def _buildUnconflictingLists(self, list, sort=True):
if sort:
self._sortListByHits(list)
unconflictingLists, unresolved = self._unConflictingList(list)
if len(unresolved) > 0:
lenght = len(unresolved) + 1
while len(unresolved) < lenght:
unconflictingLists, unresolved = self._unconflictingList(
unconflictingLists, unresolved, False)
lenght = len(unresolved)
if lenght > 0:
unconflictingLists, unresolved = self._unconflictingList(
unconflictingLists, unresolved, True)
if len(unresolved) > 0:
print('Damn, unconlictingList creation gone wrong')
return unconflictingLists
def _unConflictingList(self, list):
return self._unconflictingList([[list.pop(0)]], list, False)
def _unconflictingList(self, unconflictingLists, list, addUnresolvedToMultiple=False):
unresolved = []
for i in range(0, len(list)):
temp = []
for ii in unconflictingLists:
temp.append(ii[:])
insert = []
for j in range(0, len(temp)):
added = False
conflict = False
for k in temp[j]:
if list[i][0] & k[0] != 0:
if list[i][0] & k[0] != list[i][0] and list[i][0] & k[0] != k[0]:
conflict = True
if not conflict:
insert.append(j)
if len(insert) > 0:
if len(insert) > 1:
if addUnresolvedToMultiple:
for index in insert:
unconflictingLists[index].append(list[i])
else:
unresolved.append(list[i])
else:
for index in insert:
unconflictingLists[index].append(list[i])
else:
unconflictingLists.append([list[i]])
return unconflictingLists, unresolved
def _sortListByHits(self, list):
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if list[i][2] < list[j][2]:
temp = list[i]
list[i] = list[j]
list[j] = temp
class TreeHandler(object):
def __init__(self, inThing=None, skip=0, max=None):
self.trees = []
self.intersections = []
self.nTrees = 0
self.taxNames = None
self.nTax = 0 # Merely len(taxNames)
self.splits = []
self.bitkeys = []
self.splitsDict = {}
# biSplits and biSplitsDict for the "even" side of the biRoot
# bifurcation.
self.biSplits = []
self.biSplitsDict = {}
self.isBiRoot = None # Is it a bifurcating root?
self.doModelComments = 0
self.modelInfo = None
self.treesBuilt = False
self.tfl = None
self.verbose = 1
if inThing:
self.read(inThing, skip, max)
def read(self, inThing, skip=0, max=None):
"""Read in a tree, or some trees, or a file of trees.
Arg inThing can be a file name, a Tree instance, or a
Trees instance. Args skip and max only come into play if
inThing is a file or a Trees object."""
gm = ['TreePartitions.read()']
if not inThing:
gm.append("No input?")
raise P4Error(gm)
# print "inThing = %s, type %s" % (inThing, type(inThing))
if isinstance(inThing, str):
var.trees = []
read(inThing)
if len(var.trees) < 1:
gm.append(
'Sorry, at least one tree must be supplied as input tree')
raise P4Error(gm)
self.trees = var.trees
# self.tfl = TreeFileLite(inThing)
elif isinstance(inThing, list):
for t in inThing:
if not isinstance(t, Tree):
gm.append(
"Input trees should be a list of p4 Tree objects. Got %s" % t)
raise P4Error(gm)
self.trees = inThing
else:
gm.append(
"Sorry-- I can't grok your input. What is '%s'?" % inThing)
raise P4Error(gm)
def _printSplitList(self, list):
for split in list:
# for split in list:
intersection = ''
for bk in self.bitkeys:
if bk & split[0]:
intersection = intersection + '1'
elif bk & split[1]:
intersection = intersection + '?'
else:
intersection = intersection + '*'
print('Split: %s' % (intersection))
print(' ')
def buildTreesFromIntersections(self, intersections, minimumSplitsToBuildTree, treeLabel):
print(treeLabel)
# print 'Taxnumber to taxname translation: '
# for i in range(0, len(self.taxNames)):
# print '%s,%s' % ( i+1, self.taxNames[i])
treeBuilder = TreeBuilderFromSplits(self.bitkeys, self.taxNames)
conTrees = []
for l in intersections:
if len(l) >= minimumSplitsToBuildTree:
l.sort()
t = treeBuilder.buildTreeFromInformativeList(l)
ex = 'Excluded taxa: '
for j in l[0].excluded:
ex += self.taxNames[j] + ', '
if len(l[0].excluded) > 0:
print(ex[0:len(ex) - 2])
else:
print(ex + 'None')
# sp = 'Included splits: '
t.draw(showInternalNodeNames=1, addToBrLen=0.2,
width=None, showNodeNums=0, partNum=0, model=None)
print('')
conTrees.append(t)
# for i in range(len(conTrees)):
# if isinstance(conTrees[i], Tree):
# ex = 'Exluded taxa: '
# print 'len(conTrees[i].excluded): ', len(conTrees[i].excluded)
# for j in conTrees[i].excluded:
# ex += self.taxNames[j] +', '
# print ex[0:len(ex)-1]
# self.printVerticalNumbers(len(self.taxNames))
# for j in intersections[i]:
# j.printIntersection(self.bitkeys)
# conTrees[i].draw(showInternalNodeNames=1, addToBrLen=0.2, width=None, showNodeNums=0, partNum=0, model=None)
# print ''
if len(conTrees) == 0:
print('No trees where produced')
return []
else:
return conTrees
def buildConsensusTreeHash(self, contree):
if not contree._taxNames:
contree._setTaxNamesFromLeaves()
# t.taxNames.sort()
dict = {}
self.bitkeys = []
for i in range(len(contree.taxNames)):
self.bitkeys.append(1 << i)
dict[contree.taxNames[i]] = 1 << i
self.taxNames = contree.taxNames
def commonLeafSetTrees(self, trees):
list = []
if self.verbose:
print('Processing input trees')
sys.stdout.flush()
sampler = int(float(len(trees) / 20))
if sampler < 1:
sampler = 1
t = trees[0]
if not t._taxNames:
t._setTaxNamesFromLeaves()
t.taxNames.sort()
dict = {}
self.bitkeys = []
for i in range(len(t.taxNames)):
self.bitkeys.append(1 << i)
dict[t.taxNames[i]] = 1 << i
self.taxNames = t.taxNames
self.weights = []
splits = []
treeNames = []
index = 0
if self.verbose:
sys.stdout.write('0% ')
sys.stdout.flush()
for t in trees:
index += 1
if self.verbose:
if index % sampler == 0:
sys.stdout.write('.')
sys.stdout.flush()
treeNames.append(t.name)
for n in t.iterLeavesNoRoot():
n.br.rc = dict[n.name]
t.splits = []
t._makeRCSplitKeys(t.splits)
weight = 1.0
if hasattr(t, "weight"):
if t.weight != None:
weight = t.weight
elif hasattr(t, "recipWeight"):
if t.recipWeight != None:
weight = 1.0 / int(t.recipWeight)
self.weights.append(weight)
splits.append(t.splits)
if self.verbose:
sys.stdout.write(' 100%\n')
sys.stdout.flush()
return splits, treeNames
def commonLeafSet(self, tfl):
list = []
print('Processing input trees')
sys.stdout.flush()
sampler = int(float(tfl.nSamples) / 20)
if sampler < 1:
sampler = 1
t = tfl.getTree(1)
if not t._taxNames:
t._setTaxNamesFromLeaves()
t.taxNames.sort()
dict = {}
self.bitkeys = []
for i in range(len(t.taxNames)):
self.bitkeys.append(1 << i)
dict[t.taxNames[i]] = 1 << i
self.taxNames = t.taxNames
self.weights = []
splits = []
treeNames = []
index = 0
sys.stdout.write('0% ')
sys.stdout.flush()
for i in range(tfl.nSamples):
index += 1
if index % sampler == 0:
sys.stdout.write('.')
sys.stdout.flush()
t = tfl.getTree(i)
treeNames.append(t.name)
for n in t.iterLeavesNoRoot():
n.br.rc = dict[n.name]
t.splits = []
t._makeRCSplitKeys(t.splits)
weight = 1.0
if hasattr(t, "weight"):
if t.weight != None:
weight = t.weight
elif hasattr(t, "recipWeight"):
if t.recipWeight != None:
weight = 1.0 / int(t.recipWeight)
self.weights.append(weight)
splits.append(t.splits)
sys.stdout.write(' 100%\n')
sys.stdout.flush()
return splits, treeNames
def updateToCommonLeafSet(self, tfl):
uniqueSet = set()
print('Processing input trees')
sampler = int(float(tfl.nSamples) / 20)
if sampler < 1:
sampler = 1
index = 0.0
sys.stdout.write('0% ')
for i in range(tfl.nSamples):
index += 0.5
if index % sampler == 0:
sys.stdout.write('.')
# print '%s trees processed ' % (int (index))
sys.stdout.flush()
t = tfl.getTree(i)
# t.setPreAndPostOrder()
if not t._taxNames:
t._setTaxNamesFromLeaves()
for name in t.taxNames:
uniqueSet.add(name)
self.list = []
for name in uniqueSet:
self.list.append(name)
self.list.sort()
dict = {}
self.bitkeys = []
for i in range(len(self.list)):
self.bitkeys.append(1 << i)
dict[self.list[i]] = 1 << i
self.taxNames = self.list
self.weights = []
splits = []
treeNames = []
for i in range(tfl.nSamples):
index += 0.5
if index % sampler == 0:
sys.stdout.write('.')
# print '%s trees processed' % (int(index))
sys.stdout.flush()
t = tfl.getTree(i)
treeNames.append(t.name)
if not t._taxNames:
t._setTaxNamesFromLeaves()
t.missingTaxa = 0
for name in self.list:
if not t.taxNames.count(name):
t.missingTaxa = t.missingTaxa | dict[name]
for n in t.iterLeavesNoRoot():
n.br.rc = dict[n.name]
# t.setPreAndPostOrder()
t._makeRCSplitKeys()
weight = 1.0
if hasattr(t, "weight"):
if t.weight != None:
weight = t.weight
elif hasattr(t, "recipWeight"):
if t.recipWeight != None:
weight = 1.0 / int(t.recipWeight)
self.weights.append(weight)
t.splits = []
for n in t.iterInternalsNoRoot():
for m in n.br.rcList:
t.splits.append([m, t.missingTaxa])
splits.append(t.splits)
sys.stdout.write(' 100%\n')
return splits, treeNames
def writeNexus(self, trees, filename='rcTrees'):
"""
This method will simply write the trees that were produced by calling reduced()
If a filename is provided the trees will be written to a number of files starting
with filename1.nex and then filename2.nex and so on and so forth. This is done to avoid
mixing trees with different leafsets in the same nexus file as this will make it hard
to open the file in other software.
Beware that if no filename is provided the default filename will be used and might therefore overwrite
already completed analysis.
"""
if len(trees) == 0:
print('No trees created, can not write to file')
return
for i in range(len(trees)):
trees[i].writeNexus('%s%d%s' % (filename, i + 1, '.nex'))
def first(self, split):
first = 0.0
for bk in self.bitkeys:
if bk & split:
return first
first += 1
return first
def popcount(self, n):
count = 0
for bk in self.bitkeys:
if n < bk:
return count
if bk & n:
count += 1
return count
def _reduceSplitsPair(self, reduceRules, splits1, splits2):
newSplits1 = []
newSplits1.extend(splits1)
newSplits2 = []
newSplits2.extend(splits2)
splitList = []
splitList.append(newSplits1)
splitList.append(newSplits2)
excludedTaxa = []
for row in csvReader:
# print row
for name in row:
if name not in taxa2bitkey:
print('Taxname not in any tree: ', name)
return
if len(row) > 1:
reduce2 = taxa2bitkey[row[0]]
substitute = 0
for i in range(1, len(row)):
excludedTaxa.append(taxa2bitkey[row[i]])
substitute = substitute | taxa2bitkey[row[i]]
for i in range(len(splitList)):
# if i == 239:
# print 'Before'
# for split in splitList[i]:
# Intersection(split[0], split[1]).printIntersection(self.bitkeys)
for split in splitList[i]:
if split[0] & substitute != 0:
split[0] = (split[0] & substitute) ^ split[0]
split[0] = split[0] | reduce2
if split[1] & reduce2 != 0:
split[1] = split[1] ^ reduce2
split[1] = split[1] | substitute
# if i == 239:
# print 'After'
# for split in splitList[i]:
# Intersection(split[0], split[1]).printIntersection(self.bitkeys)
# sys.exit()
excluded = 0
for bitkey in excludedTaxa:
excluded = excluded | bitkey
return excluded
return newSplits1, newSplits2
def _reduceSplits(self, csvReader, splitList):
bitkey2taxa = {}
taxa2bitkey = {}
for i in range(0, len(self.taxNames)):
taxa2bitkey[self.taxNames[i]] = self.bitkeys[i]
bitkey2taxa[self.bitkeys[i]] = self.taxNames[i]
excludedTaxa = []
for row in csvReader:
# print row
for name in row:
if name not in taxa2bitkey:
print('Taxname not in any tree: ', name)
return
if len(row) > 1:
reduce2 = taxa2bitkey[row[0]]
substitute = 0
for i in range(1, len(row)):
excludedTaxa.append(taxa2bitkey[row[i]])
substitute = substitute | taxa2bitkey[row[i]]
for i in range(len(splitList)):
# if i == 239:
# print 'Before'
# for split in splitList[i]:
# Intersection(split[0], split[1]).printIntersection(self.bitkeys)
for split in splitList[i]:
if split[0] & substitute != 0:
split[0] = (split[0] & substitute) ^ split[0]
split[0] = split[0] | reduce2
if split[1] & reduce2 != 0:
split[1] = split[1] ^ reduce2
split[1] = split[1] | substitute
# if i == 239:
# print 'After'
# for split in splitList[i]:
# Intersection(split[0], split[1]).printIntersection(self.bitkeys)
# sys.exit()
excluded = 0
for bitkey in excludedTaxa:
excluded = excluded | bitkey
return excluded
def _sortCompatibleSplits(self, compatibleSplits):
pops = []
i = 0
for split in compatibleSplits:
pop = self.popcount(split[0])
if pop > 1:
first = self.first(split[0]) / len(self.bitkeys)
pop = pop - first
pops.append([pop, i])
i += 1
# print 'unsorted', pops
pops = sorted(pops, key=operator.itemgetter(0), reverse=True)
# print 'sorted', pops
sortedList = []
# print 'Sorting indices'
for i in range(0, len(pops)):
# print pops[i][1]
sortedList.append(compatibleSplits[pops[i][1]])
return sortedList
class SemiStrictSuperTree(TreeHandler):
def __init__(self, inThing=None, skip=0, max=None):
TreeHandler.__init__(self, inThing, skip, max)
def semiStrictSuperTree(self, semiStrictOnly=True, treatMultifurcatingRootsAsUnrooted=0, readCSV=False, csvFile='clades.csv'):
"""
Creates a semistrict supertree from a set of trees. Setting the semiStricOnly to False will
cause the algorithm to try and resolve any unresolved nodes by using majority rule consensus.
Usage:
rd = Reduced('filename')
rd.semiStrictSuperTree()
to safe supertree to file:
rd.writeNexus('outFilename')
"""
self.conTrees = []
self.splits, treeNames = self.updateToCommonLeafSet(self.tfl)
if readCSV:
csvReader = csv.reader(open(csvFile), delimiter=',', quotechar='"')
excludedTaxa = self._reduceSplits(csvReader, self.splits)
else:
excludedTaxa = 0
# print 'Excluded taxa: ', excludedTaxa
noSplitSets = len(self.splits)
# Found no splits, exiting
if len(self.splits) < 1:
print('No splits found')
return
# Found only one set of splits, as the set is from a single tree it is
# assumed to be compatible
elif len(self.splits) == 1:
print('Only one set of splits')
treeBuilder = TreeBuilderFromSplits(self.bitkeys, self.taxNames)
intersections = []
self._printSplitList(self.splits[0])
for split in self.splits[0]:
intersections.append(Intersection(split[0], split[1], 1, 1))
# print self.splits
# print self.splits[0]
mrc = treeBuilder.buildTreeFromInformativeList(intersections)
mrc.draw()
print(' ')
return
print('Creating compatible split set, collapsing where necessary')
set = []
for t in range(0, len(self.splits)):
for split in self.splits[t]:
set.append(split)
print('Initial splits')
self._printSplitList(set)
overlappingLists = self._buildOverlappingLists(set)
print('Overlapping lists: ', len(overlappingLists))
for list in overlappingLists:
self._printSplitList(list)
combinedLists = []
for list in overlappingLists:
combinedLists.append(self._combineList(list))
print('Combined overlapping lists')
for list in combinedLists:
self._printSplitList(list)
# print 'Creating compatible splits list'
compatibleSplits = []
for list in combinedLists:
for split in list:
compatibleSplits.append(split)
uniqueSet = {}
for split in compatibleSplits:
if (split[0], split[1]) not in uniqueSet:
uniqueSet[(split[0], split[1])] = 1
# self._isListCompatible(uniqueSet.keys())
compatibleSplits = []
for split in uniqueSet.keys():
compatibleSplits.append([split[0], split[1]])
compatibleSplits = self._sortCompatibleSplits(compatibleSplits)
print('Sorted compatible splits')
self._printSplitList(compatibleSplits)
if semiStrictOnly:
# To conform to the format used by treeBuilderFromSplits we need to convert splits into intersections
# We use a set to remove any duplicate splits created during the
# refinement
# print '1. Looking for incompatibility in fixed compatible splits'
compatibleIntersections = []
for s in compatibleSplits:
i = Intersection(s[0], s[1], 1, 1)
compatibleIntersections.append(i)
if False:
print('List of unique splits')
for s in compatibleSplits:
self._printSplitList([s])
print('Building tree')
treeBuilder = TreeBuilderFromSplits(self.bitkeys, self.taxNames)
mrc = treeBuilder.buildTreeFromInformativeList(
compatibleIntersections, excluded=excludedTaxa)
# mrc.dump()
mrc.draw()
print(' ')
self.treesBuilt = True
self.conTrees.append(mrc)
return mrc
def _compatible(self, split1, split2):
if split1[0] & split2[0] == 0:
# print '1'
# self._printSplitList([split1])
# self._printSplitList([split2])
return True
if split1[0] | split2[0] == split1[0] or split1[0] | split2[0] == split2[0]:
# print '2'
# self._printSplitList([split1])
# self._printSplitList([split2])
return True
if (split1[0] | split1[1]) | split2[0] == (split1[0] | split1[1]):
# print '3'
# self._printSplitList([split1])
# self._printSplitList([split2])
return True
if (split2[0] | split2[1]) | split1[0] == (split2[0] | split2[1]):
# print '4'
# self._printSplitList([split1])
# self._printSplitList([split2])
return True
if (split1[0] | split1[1]) | (split2[0] | split2[1]) == (split1[0] | split1[1]):
# print '5'
# self._printSplitList([split1])
# self._printSplitList([split2])
return True
if (split1[0] | split1[1]) | (split2[0] | split2[1]) == (split2[0] | split2[1]):
# print '6'
# self._printSplitList([split1])
# self._printSplitList([split2])
return True
if (split1[0] | (split2[0] | split2[1]) != (split1[0] | split1[1]) and split1[0] | (split2[0] | split2[1]) != (split2[0] | split2[1])):
# print '7'
# self._printSplitList([split1])
# self._printSplitList([split2])
return False
print('Defaults to true, not good')
print(split1)
print(split2)
print(' ')
return True
def _isListCompatible(self, list):
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
if not self._compatible(list[i], list[j]):
print('incompatible')
self._printSplitList([list[i]])
self._printSplitList([list[j]])
def _buildOverlappingLists(self, list):
overlappingList = [[list.pop()]]
firstTime = True
while len(list) > 0:
# print 'while 1'
i = 0
while i <= 1:
for oList in overlappingList:
remaining = []
for split in list:
add = False
for s in oList:
if s[0] & split[0] != 0:
add = True
break
if add:
oList.append(split)
else:
remaining.append(split)
list = remaining
i += 1
if len(list) > 0:
overlappingList.append([list.pop()])
#
# if len(list) == 1 and not firstTime:
# overlappingList.append([list.pop()])
# if len(list) == 1 and firstTime:
# firstTime = False
#
return overlappingList
def _combineList(self, list):
allOnes = 2 ** (len(self.bitkeys)) - 1
# zero will hold the unknown taxa off all splits in list
zero = 0
for split in list:
zero = zero | split[1]
# print 'Zero: '
self._printSplitList([[zero, 0]])
xor = zero ^ allOnes
# print 'Xor: '
self._printSplitList([[xor, 0]])
includeSplit = []
if len(list) > 2:
for split in list:
# self._printSplitList([split])
# split[1] = split[1] & xor
# self._printSplitList([split])
# print '*********************************'
includeSplit.append(False)
# print '*********************************'
else:
for split in list:
includeSplit.append(True)
includesCombined = {}
for i in range(0, len(list) - 1):
for j in range(i + 1, len(list)):
# self._printSplitList([list[i]])
# self._printSplitList([list[j]])
if list[i][0] & list[j][0] != 0 and list[i][0] | list[j][0] != list[i][0] and list[i][0] | list[j][0] != list[j][0]:
combined = list[i][0] | list[j][0]
if self.popcount(list[i][0] | list[i][1]) >= self.popcount(list[j][0] | list[j][1]):
# list[i][0] = combined
# list[i][1] = list[i][1] & xor
if list[i][0] in includesCombined:
includesCombined[list[i][0]] = -1
includesCombined[combined] = list[i][1] & xor
else:
# list[j][0] = combined
# list[j][1] = list[j][1] & xor
if list[j][0] in includesCombined:
includesCombined[list[j][0]] = -1
includesCombined[combined] = list[j][1] & xor
# self._printSplitList([list[i]])
# self._printSplitList([list[j]])
# print '------------------------------------'
else:
# if (list[i][0] | list[i][1]) | (list[j][0]) == list[j][0]
# or (list[j][0] | list[j][1]) | list[i][0] == list[i][0]:
includeSplit[i] = True
includeSplit[j] = True
# if not self._compatible(list[i], list[j]):
# self._printSplitList([list[i]])
# self._printSplitList([list[j]])
# print
# 'incompatible****************************************************************************'
newList = []
for i in range(0, len(includeSplit)):
if includeSplit[i]:
list[i][1] = list[i][1] & xor
newList.append(list[i])
# print includesCombined
for i in includesCombined.keys():
if includesCombined[i] >= 0:
newList.append([i, includesCombined[i]])
return newList
class CompatibleTreePairs(TreeHandler):
def __init__(self, inThing=None, skip=0, max=None):
TreeHandler.__init__(self, inThing, skip, max)
def slidingWindow(self, csvInfile, csvOutFile, startYear, stopYear, maxDistance):
csvReader = csv.reader(open(csvInfile), delimiter=',', quotechar='"')
csvWriter = csv.writer(open(csvOutFile, "wb"))
yearDict = {}
first = True
index = 0
added = 0
rows = []
for row in csvReader:
rows.append(row)
if first:
first = False
else:
index += 1
name = str(row[0])
for i in range(startYear, stopYear + 1):
if name.rfind(str(i)) >= 0 and name.rfind('SCC') < 0:
if i in yearDict:
added += 1
yearDict[i].append([index, row])
else:
added += 1
yearDict[i] = [[index, row]]
break
# print 'Counted rows: ', index
# print 'Added rows: ', added
# print ''
ranges = []
for index in range(startYear, stopYear - 4):
# print 'Range: %s, %s ' % (index, index+4)
studies = []
for i in range(index, index + 5):
if i in yearDict:
studies.extend(yearDict[i])
# for studie in studies:
# print studie[1][0]
# print studie
# print ''
# print '[%s, %s, %'
ranges.append([index, index + 4, studies])
# print ''
outputRows = []
for dataset in ranges:
row = [dataset[0], dataset[1]]
print('%s, %s' % (dataset[0], dataset[1]))
indecies = []
print(dataset)
for data in dataset[2]:
indecies.append(data[0])
print(indecies)
sum = 0.0
index = 0
for data in dataset[2]:
for data1 in dataset[2]:
if data[0] != data1[0]:
if data[1][data1[0]] != '*':
index += 1
# print data[1][data1[0]]
sum += float(data[1][data1[0]]) / maxDistance
if sum == 0:
row.append(0)
print('Mean: ', 0)
print('')
else:
row.append(sum / index)
print('Mean: ', sum / index)
print('')
outputRows.append(row)
csvWriter.writerows(outputRows)
def formatNexusFileFromCSV(self, csvInFile, nexusOutFile, maxDistance):
csvDataReader = csv.reader(
open(csvInFile), delimiter=',', quotechar='"')
data = []
nTax = -1
for row in csvDataReader:
nTax += 1
data.append(row)
writer = open(nexusOutFile, 'w')
writer.write('#NEXUS\n')
writer.write('[!user defined distances]\n')
writer.write('Begin distances;\n')
dim = 'Dimensions ntax=%s;\n' % (str(nTax))
writer.write(dim)
writer.write('format nodiagonal;\n')
writer.write('matrix\n')
first = True
for index in range(0, len(data)):
if first:
first = False
else:
print(data[index][0:index])
line = ''
for datapoint in data[index][0:index]:
if datapoint == '*':
line += '0.0'
else:
line += datapoint
line += ' '
if index == len(data) - 1:
line += ';'
line += '\n'
print(line)
writer.write(line)
writer.write('end;\n')
writer.write('[! nj with user defined distances]\n')
writer.write('Begin paup;\n')
writer.write('dset distance=user;\n')
writer.write('nj;\n')
writer.write('end;\n')
writer.close()
def calcStatsFromCSV(self, csvInFileIntraGroups, csvInFileInterGroups, csvInFileData, csvOutFile, maxDistance):
csvInterGroupReader = csv.reader(
open(csvInFileInterGroups), delimiter=',', quotechar='"')
csvIntraGroupReader = csv.reader(
open(csvInFileIntraGroups), delimiter=',', quotechar='"')
csvDataReader = csv.reader(
open(csvInFileData), delimiter=',', quotechar='"')
data = []
for row in csvDataReader:
if len(row) > 0:
data.append(row)
length = 0
if len(data) > 0:
length = len(data[0])
intraGroups = {}
for row in csvIntraGroupReader:
if len(row) > 0:
# print 'Adding:', row[0]
intraGroups[row[0]] = [length, 0]
# print intraGroups.keys()
interGroups = {}
for row in csvInterGroupReader:
if len(row) > 0:
interGroups[row[0]] = [0]
# print interGroups.keys()
for index in range(1, len(data)):
if data[index][0] in interGroups:
interGroups[data[index][0]] = index
# print 'Inter: %s, %s ' % (data[index][0], index)
else:
if data[index][0] in intraGroups:
# print 'data[index][0]:', data[index][0]
# print 'data[index][1]:', data[index][1]
# print 'intraGroups[data[index][0]]:',intraGroups[data[index][0]]
# print '%s : %s,%s' % (data[index][0], index, index)
intraGroups[data[index][0]] = [index, index]
# if index < intraGroups[data[index][0]]:
# intraGroups[data[index][0]] = [index,intraGroups[data[index][1]]]
# if index > intraGroups[data[index][1]]:
# intraGroups[data[index][0]] = [intraGroups[data[index][0],index]]
else:
start = data[index][0].find('mpt')
substring = ''
if start > 0:
substring = data[index][0].partition('mpt')[0]
else:
start = data[index][0].find('SCC')
if start > 0:
substring = data[index][0].partition('SCC')[0]
# print 'Created substring:', substring
if substring in intraGroups:
if index < intraGroups[substring][0]:
intraGroups[substring] = [index, intraGroups[substring][1]]
# print 'Intra1: %s, %s - %s ' % (substring, index,
# intraGroups[substring][1])
elif index > intraGroups[substring][1]:
intraGroups[substring] = [intraGroups[substring][0], index]
# print 'Intra2: %s, %s - %s ' % (substring, intraGroups[substring][0],
# index )
else:
print('Found nowhere to put: ', data[index][0])
writer = csv.writer(open(csvOutFile, "wb"))
outputRows = []
outputRows.append(
['Real trees intra group distances, normalized, mean distance'])
outputRows.append(['Study', 'Mean'])
# print 'Study: max, min, mean'
def compareStudyYear(x, y):
# print '%s, %s' % (x[0], x[0][len(x[0])-4:len(x[0])])
# print '%s, %s' % (y[0], y[0][len(y[0])-4:len(y[0])])
# print ''
return int(x[0][len(x[0]) - 4:len(x[0])]) - int(y[0][len(y[0]) - 4:len(y[0])])
items = intraGroups.items()
items.sort(cmp=compareStudyYear)
# print items
for k, v in items:
# max, min, mean = self.calcIntraMinMaxMean(data, v[0], v[1])
outputRows.append(
[k, self.calcIntraMinMaxMean(data, v[0], v[1], maxDistance)])
# print '%s: %s, %s, %s' % (k, max, min, mean)
outputRows.append([' '])
outputRows.append(
['Real trees inter group distances, normalized, mean distance'])
items = intraGroups.items()
items.sort(cmp=compareStudyYear)
# print 'Study: max, min, mean'
row = ['']
for i in items:
row.append(i[0])
outputRows.append(row)
for i in range(0, len(items)):
row = [items[i][0]]
for j in range(0, len(items)):
row.append(self.calcIntraInterMinMaxMean(data, items[i][1][0], items[
i][1][1], items[j][1][0], items[j][1][1], maxDistance))
# max, min, mean = self.calcIntraMinMaxMean(data, v[0], v[1])
outputRows.append(row)
# print '%s: %s, %s, %s' % (k, max, min, mean)
items = interGroups.items()
items.sort(cmp=compareStudyYear)
outputRows.append([' '])
outputRows.append(
['Expertograms intra tree distances, normalized, distance'])
row = ['']
for i in items:
row.append(i[0])
outputRows.append(row)
for i in range(0, len(items)):
row = [items[i][0]]
for j in range(0, len(items)):
# print items[i][1]
# print items[j][1]
if data[items[i][1]][items[j][1]] == '*':
row.append(data[items[i][1]][items[j][1]])
else:
row.append(
float(data[items[i][1]][items[j][1]]) / maxDistance)
outputRows.append(row)
intraItems = intraGroups.items()
intraItems.sort(cmp=compareStudyYear)
outputRows.append([' '])
outputRows.append(
['Expertograms to real trees inter group distances, normalized, mean distance'])
row = ['']
for i in intraItems:
row.append(i[0])
outputRows.append(row)
for key, value in items:
row = []
row.append(key)
for d, v in intraItems:
row.append(
self.calcInterMinMaxMean(data, value, v[0], v[1], maxDistance))
outputRows.append(row)
writer.writerows(outputRows)
def median(self, numbers):
# "Return the median of the list of numbers."
# Sort the list and take the middle element.
n = len(numbers)
copy = numbers[:] # So that "numbers" keeps its original order
copy.sort()
if n & 1: # There is an odd number of elements
return copy[n // 2]
else:
return (copy[n // 2 - 1] + copy[n // 2]) / 2
def calcInterMinMaxMean(self, data, inter, start, stop, maxDistance):
min = 99999999999
max = 0
sum = 0.0
index = 0
# print 'range: ', range(start, stop+1)
for i in range(start, stop + 1):
if data[inter][i] == '*':
# return '*/*/*'
return '*'
# print '%s,%s' % (inter,i)
dist = float(data[inter][i]) / maxDistance
if dist < min:
min = dist
if dist > max:
max = dist
sum += dist
index += 1
# return str(max) +'/' + str(min) +'/' + str(sum/index)
return sum / index
def calcIntraInterMinMaxMean(self, data, start1, stop1, start2, stop2, maxDistance):
min = 99999999999.0
max = 0.0
sum = 0.0
index = 0.0
for i in range(start1, stop1 + 1):
for j in range(start2, stop2 + 1):
if i != j:
# print '%s, %s' % (data[i][0], data[j][0])
if data[i][j] == '*':
# return '*','*','*'
return '*'
# print '%s,%s' % (i,j)
dist = float(data[i][j]) / maxDistance
if dist < min:
min = dist
if dist > max:
max = dist
sum += dist
index += 1
# return max, min, sum/index
if sum != 0:
return sum / index
return 0
def calcIntraMinMaxMean(self, data, start, stop, maxDistance):
min = 99999999999.0
max = 0.0
sum = 0.0
index = 0.0
# print range(start, stop+1)
for i in range(start, stop + 1):
for j in range(i + 1, stop + 1):
# print '%s, %s' % (data[i][0], data[j][0])
if data[i][j] == '*':
# return '*','*','*'
return '*'
# print '%s,%s' % (i,j)
dist = float(data[i][j]) / maxDistance
if dist < min:
#==========================================================
# print 'MIN: %s < %s' % (float(data[i][j]) ,min)
#==========================================================
min = dist
if dist > max:
#==========================================================
# print 'MAX: %s > %s' % (float(data[i][j]) ,max)
#==========================================================
max = dist
sum += dist
index += 1
# print 'Max: ', max
# print 'Min: ', min
# print 'Sum: ', sum
# print 'Index: ', index
# print 'Mean: ', sum/index
# return max, min, sum/index
if sum != 0:
return sum / index
return 0
def createCompatibleTreePairs(self, readCSV=True, csvInFile='clades.csv', writeCSV=False, csvOutFile='treeMatrix.csv', writeNexus=True, nexusFilename='compatiblePair', calcDistances=False, distanceMeasure='sd'):
from p4.trees import Trees
self.splits, treeNames = self.updateToCommonLeafSet(self.tfl)
self.bitkey2taxa = {}
self.taxa2bitkey = {}
for i in range(0, len(self.taxNames)):
self.taxa2bitkey[self.taxNames[i]] = self.bitkeys[i]
self.bitkey2taxa[self.bitkeys[i]] = self.taxNames[i]
if readCSV:
csvReader = csv.reader(
open(csvInFile), delimiter=',', quotechar='"')
self.replace = []
self.reduceRules = []
for row in csvReader:
self.reduceRules.append(row)
self.replace.append(self.taxa2bitkey[row[0]])
if writeCSV:
writer = csv.writer(open(csvOutFile, "wb"))
treeBuilder = TreeBuilderFromSplits(
self.bitkeys, self.taxNames, internalNames=False)
# bitkey2taxa = {}
# taxa2bitkey = {}
# for i in range(0,len(self.taxNames)):
# taxa2bitkey[self.taxNames[i]] = self.bitkeys[i]
# bitkey2taxa[self.bitkeys[i]] = self.taxNames[i]
index = 0
csvOut = []
if writeCSV:
names = []
names.append('')
for i in range(0, len(treeNames)):
names.append(treeNames[i])
csvOut.append(names)
if calcDistances:
max = 0.0
min = 99999999999999.0
sum = 0.0
distances = 0
if distanceMeasure not in ['sd', 'triplet', 'quartet']:
print('No such distance measure: ', distanceMeasure)
print('Defaulting to Symmetric Difference')
distanceMeasure = 'sd'
print('Distance measure: ', distanceMeasure)
dict = {}
nonUnique = {}
for i in range(0, len(self.splits)):
row = []
if writeCSV:
row.append(treeNames[i])
for j in range(0, len(self.splits)):
if j >= i:
if i != j:
set1, set2, taxnames, pairExcluded = self._commonTaxaFromTreePair(
self.splits[i], self.splits[j])
if len(set1) > 0 and len(set2) > 0:
uniqueSet = {}
for split in set1:
if (split[0], split[1]) not in uniqueSet:
uniqueSet[(split[0], split[1])] = 1
list = []
for s in uniqueSet.keys():
list.append([s[0], s[1]])
list = self._sortCompatibleSplits(list)
intersections1 = []
for s in list:
intersections1.append(
Intersection(s[0], s[1], 1, 1))
mrc = treeBuilder.buildTreeFromInformativeList(
intersections1, treeNames[i])
mrc.taxNames = taxnames
uniqueSet = {}
for split in set2:
if (split[0], split[1]) not in uniqueSet:
uniqueSet[(split[0], split[1])] = 1
list = []
for s in uniqueSet.keys():
list.append([s[0], s[1]])
list = self._sortCompatibleSplits(list)
intersections2 = []
for s in list:
intersections2.append(
Intersection(s[0], s[1], 1, 1))
mrc2 = treeBuilder.buildTreeFromInformativeList(
intersections2, treeNames[j])
mrc2.taxNames = taxnames
if calcDistances:
if mrc.root.getNChildren() <= 2:
mrc.removeRoot()
if mrc2.root.getNChildren() <= 2:
mrc2.removeRoot()
dist = mrc.topologyDistance(
mrc2, distanceMeasure)
# dist = float(dist)/setSize
dict[(j, i)] = float(dist)
if dist > max:
max = dist
if dist < min:
min = dist
sum += dist
# if dist > 0:
distances += 1
if dist == 0.0:
study = 'Clark2004'
if mrc.name.startswith(study) and mrc2.name.startswith(study):
print('%s, %s: %s' % (mrc.name, mrc2.name, dist))
print()
if writeCSV:
row.append(dist)
if dist == 0:
nonUnique[mrc2.name] = 1
sys.stdout.flush()
if writeNexus:
trees = Trees([mrc, mrc2], taxnames)
trees.writeNexus(fName='%s%d%s' % (
nexusFilename, index + 1, '.nex'), append=0, withTranslation=1, writeTaxaBlock=1, likeMcmc=0)
index += 1
else:
dict[(j, i)] = '*'
if writeCSV:
row.append('*')
else:
row.append('0')
else:
row.append(dict[(i, j)])
if writeCSV:
csvOut.append(row)
print(' ')
if calcDistances:
print('Max: ', max)
print('Min: ', min)
print('Mean: ', sum / distances)
if writeCSV:
writer.writerows(csvOut)
print('Wrote csv-data to file: ', csvOutFile)
if writeNexus:
print('Wrote tree data to nexus files on format %s ' % (nexusFilename + '#' + '.nex'))
nonUniquetrees = nonUnique.keys()
# print sorted(nonUniquetrees)
def _commonTaxaFromTreePair(self, splits1, splits2):
replace = False
for i in self.replace:
# for split in splits1:
if not splits1[0][1] & i:
replace = True
break
# if replace:
# break
replace2 = False
for i in self.replace:
# for split in splits2:
if not splits2[0][1] & i:
replace2 = True
break
# if replace2:
# break
if (replace and not replace2) or (not replace and replace2):
# print
# '=========================================================='
splits1, splits2 = self._reduceSplitsPair(splits1, splits2)
included1 = 2 ** (len(self.bitkeys)) - 1
included1 = included1 ^ splits1[0][1]
included2 = 2 ** (len(self.bitkeys)) - 1
included2 = included2 ^ splits2[0][1]
combined = included1 & included2
excluded = 2 ** (len(self.bitkeys)) - 1 ^ combined
set1 = []
set2 = []
taxnames = []
if self.popcount(combined) >= 3:
for i in self.bitkeys:
if i & combined:
taxnames.append(self.bitkey2taxa[i])
temp = {}
for split in splits1:
if self.popcount(split[0] & combined) > 1 and ((split[0] & combined) ^ (split[1] | excluded)) ^ (2 ** (len(self.bitkeys)) - 1) != 0:
temp[(split[0] & combined, split[1] | excluded)] = 1
for i in temp.keys():
set1.append([i[0], i[1]])
temp = {}
for split in splits2:
if self.popcount(split[0] & combined) > 1 and ((split[0] & combined) ^ (split[1] | excluded)) ^ (2 ** (len(self.bitkeys)) - 1) != 0:
temp[(split[0] & combined, split[1] | excluded)] = 1
for i in temp.keys():
set2.append([i[0], i[1]])
return set1, set2, taxnames, excluded
def _reduceSplitsPair(self, splits1, splits2):
newSplits1 = []
for i in splits1:
newSplits1.append([i[0], i[1]])
newSplits2 = []
for i in splits2:
newSplits2.append([i[0], i[1]])
splitList = []
splitList.append(newSplits1)
splitList.append(newSplits2)
excludedTaxa = []
for row in self.reduceRules:
present = []
for name in row:
if name in self.taxa2bitkey:
present.append(name)
# else:
# print 'Taxname not in any tree: ', name
if len(present) > 1:
reduce2 = self.taxa2bitkey[present[0]]
substitute = 0
for i in range(1, len(present)):
excludedTaxa.append(self.taxa2bitkey[present[i]])
substitute = substitute | self.taxa2bitkey[present[i]]
for i in range(len(splitList)):
for split in splitList[i]:
if split[0] & substitute != 0:
split[0] = (split[0] & substitute) ^ split[0]
split[0] = split[0] | reduce2
if split[1] & reduce2 != 0:
split[1] = split[1] ^ reduce2
split[1] = split[1] | substitute
excluded = 0
for bitkey in excludedTaxa:
excluded = excluded | bitkey
# return excluded
return splitList[0], splitList[1]
def _buildAndDisplayTree(self, set1):
treeBuilder = TreeBuilderFromSplits(
self.bitkeys, self.taxNames, internalNames=False)
uniqueSet = {}
for split in set1:
if (split[0], split[1]) not in uniqueSet:
uniqueSet[(split[0], split[1])] = 1
list = []
for s in uniqueSet.keys():
list.append([s[0], s[1]])
list = self._sortCompatibleSplits(list)
intersections1 = []
for s in list:
intersections1.append(Intersection(s[0], s[1], 1, 1))
mrc = treeBuilder.buildTreeFromInformativeList(intersections1, 'name')
mrc.draw()
[docs]
class Reduced(TreeHandler):
def __init__(self, inThing=None, skip=0, max=None):
self.minimumProportion = 0.5
self.rooted = True
self.minNoOfTaxa = 3
self.treatMultifurcatingRootsAsUnrooted = 0
self.minimumSplitsToBuildTree = 1
self.extendProfile = False
self.saveTreesToFile = False
self.saveSplitsToFile = False
self.extendProfile = False
self.saveExtendedTreesToFile = False
self.saveExtendedToFile = False
self.treeFileName = 'trees'
self.splitFileName = 'splits.txt'
self.extendedTreeFileName = 'extendedTrees'
self.extendedFileName = 'extended.txt'
self.sortByNoSplits = True
TreeHandler.__init__(self, inThing, skip, max)
"""
A container for
* reduced strict consensus
* reduced majority consensus, .
* reduced strict supertree
* reduced majority supertree
Start it up like this:
rd = Reduced(inThing)
where inThing can be a file name
If you are reading from a file (generally a bootstrap or mcmc
output), you can skip some trees at the beginning, and optionally
after that read only a maximum number of trees, like this::
rd = Reduced('myFile', skip=1000, max=500)
Then you can do::
reduced strict consensus
rd.reducedStrictConsensus()
majority rule consensus
rd.majorityRuleConsensus()
write trees to file
rd.writeNexus(filename)
The input trees do not need to have the same taxa.
These are the default settings::
rd.minimumProportion=0.5
rd.rooted=True
rd.minNoOfTaxa = 3
rd.treatMultifurcatingRootsAsUnrooted=0
rd.minimumSplitsToBuildTree=1
rd.extendProfile=False
rd.saveTreesToFile=False
rd.saveSplitsToFile=False
rd.extendProfile=False
rd.saveExtendedTreesToFile=False
rd.saveExtendedToFile=False
rd.treeFileName='trees'
rd.splitFileName='splits.txt'
rd.extendedTreeFileName='extendedTrees'
rd.extendedFileName='extended.txt'
self.sortByNoSplits=True
To change one of them do this::
rd = Reduced(inThing)
rd.rooted = False
"""
def _getTreeSplits(self):
for t in self.trees:
self._determineRooting(t)
for t in self.trees:
t.setPreAndPostOrder()
t._makeRCSplitKeys()
t.splits = []
for n in t.iterInternalsNoRoot():
t.splits.append([n.br.rc, t.missingTaxa])
# print t.splits
def _determineRooting(self, theTree):
gm = ['ReducedStrictConsensus._determineRooting()']
nRootChildren = theTree.root.getNChildren()
if not nRootChildren:
gm.append("Root has no children.")
raise P4Error(gm)
elif nRootChildren == 1:
gm.append("Tree is rooted on a terminal node. No workee.")
raise P4Error(gm)
elif nRootChildren == 2:
if self.isBiRoot == None:
print('Found bifurcating root')
self.isBiRoot = 1
else:
if self.isBiRoot == None:
print('Found multifurcating root')
self.isBiRoot = 0
elif self.isBiRoot == 0:
pass
else:
gm.append(
"Self.isBiRoot has been previously turned on, but now we have a non-biRoot tree.")
raise P4Error(gm)
[docs]
def writeIntersections(self, filename='rcIntersections'):
if len(self.intersections) == 0:
print('No splits created, can not write splits to file')
return
fileHandle = open(filename, 'w')
for name in self.taxNames:
fileHandle.write(name)
fileHandle.write(',')
fileHandle.write('\n')
list = []
for il in self.intersections:
list.extend(il)
# print 'list1:', list
for i in range(0, len(list) - 1):
for j in range(i, len(list)):
if list[i].frequency < list[j].frequency:
temp = list[i]
list[i] = list[j]
list[j] = temp
# print 'list2:',list
fileHandle.write('Taxnumber to taxname translation: \n')
for i in range(0, len(self.taxNames)):
fileHandle.write('%s,%s \n' % (i + 1, self.taxNames[i]))
# self.printVerticalNumbers(len(self.taxNames))
for i in list:
fileHandle.write(i.intersectionToString(self.bitkeys))
fileHandle.write('\n')
# fileHandle.write('\n')
fileHandle.close()
[docs]
def writeExtended(self, extendedFileName='rcExtended'):
if len(self.extendedIntersections) > 0:
list = []
for il in self.extendedIntersections:
list.extend(il)
# print 'list1:', list
for i in range(0, len(list) - 1):
for j in range(i, len(list)):
if list[i].frequency < list[i].frequency:
temp = list[i]
list[i] = list[j]
list[j] = temp
fileHandle = open(extendedFileName, 'w')
for name in self.taxNames:
fileHandle.write(name)
fileHandle.write(',')
fileHandle.write('\n')
for i in list:
fileHandle.write(i.intersectionToString(self.bitkeys))
fileHandle.write('\n')
fileHandle.close()
else:
print('No splits in the extended profile')
def _rule1(self, compatibleSplits):
highestbitkey = (self.bitkeys[-1] << 1) - 1
print('HighestBitkey: %s' % (str(highestbitkey)))
split2Parents = {}
parents = {}
newSplits = []
for i in range(0, len(compatibleSplits) - 1):
for j in range(i + 1, len(compatibleSplits)):
# Rule 1
if compatibleSplits[i][0] & compatibleSplits[j][0] != 0 and compatibleSplits[i][0] & (compatibleSplits[j][0] | compatibleSplits[j][1]) == 0 and compatibleSplits[j][0] & (compatibleSplits[i][0] | compatibleSplits[i][1]) == 0:
split1 = [compatibleSplits[i][0] & compatibleSplits[j][0], ((compatibleSplits[j][0] | compatibleSplits[
j][1]) | highestbitkey) | ((compatibleSplits[i][0] | compatibleSplits[i][1]) | highestbitkey)]
if (split1[0], split1[1]) in split2Parent:
split2Parent[(split1[0], split1[1])][i] = 1
split2Parent[(split1[0], split1[1])][j] = 1
else:
dict = {}
dict[i] = 1
split12Parent[(split1[0], split1[1])] = dict
newSplits.append(split1)
parents[(split1[0], split1[1])]
split2 = [compatibleSplits[i][0] | compatibleSplits[j][0], ((compatibleSplits[j][0] | compatibleSplits[
j][1]) | highestbitkey) & ((compatibleSplits[i][0] | compatibleSplits[i][1]) | highestbitkey)]
newSplits.append(split2)
def _printMajorityRuleSplits(self, list, noSplitSets):
for split in list:
# for split in list:
intersection = ''
for bk in self.bitkeys:
if bk & split[0]:
intersection = intersection + '1'
elif bk & split[1]:
intersection = intersection + '?'
else:
intersection = intersection + '*'
print('Split: %s, %s' % (intersection, str(float(split[2]) / noSplitSets)))
print(' ')
def _findSmallestSet(self, bk, compatibleSplits):
smallest = 0
minPopcount = len(compatibleSplits) + 1
for split in compatibleSplits:
if split[0] & bk:
popCount = self.popcount(split[0])
if popCount < minPopcount:
minPopcount = popCount
smallest = split[0]
elif popCount == minPopcount:
smallest = smallest | split[0]
return smallest
def _applySmallest(self, unknown, smallest, compatibleSplits):
# print 'Unknown'
# print str(unknown)
for split in compatibleSplits:
if split[0] & smallest != 0:
print('11')
self._printSplitList([split])
split[0] = split[0] | unknown
split[1] = split[1] ^ unknown
self._printSplitList([split])
if split[1] & unknown:
print('22')
self._printSplitList([split])
split[1] = split[1] ^ unknown
self._printSplitList([split])
# print ''
def _isOverlapping(self, split1, split2):
if split1[0] | split2[0] == split1[0] or split1[0] | split2[0] == split2[0]:
# print 'True 1'
return True
if (split1[0] | split1[1]) | (split2[0] | split2[1]) == (split1[0] | split1[1]):
# print 'True 2'
return True
if (split1[0] | split1[1]) | (split2[0] | split2[1]) == (split2[0] | split2[1]):
# print 'True 3'
return True
return False
def _combineSet(self, set):
newSet = set()
used = {}
for i in range(0, len(set) - 1):
for j in range(i, len(set)):
if set[i][0] & set[j][0] != 0:
and1 = set[i][0] & set[j][1]
and2 = set[j][0] & set[i][1]
or1 = and1 | and2
newSet.add(
(set[i][0] | or1, (set[i][1] & or1) ^ set[i][1]))
newSet.add(
(set[j][0] | or1, (set[j][1] & or1) ^ set[j][1]))
used[set[i]] = 1
used[set[j]] = 1
else:
if set[i] not in used:
newSet.add(set[i])
if set[j] not in used:
newSet.add(set[j])
## set[i][0] = set[i][0] | or1
# set[i][1] = (set[i][1] &or1) ^ set[i][1]
# set[j][0] = set[j][0] | or1
# set[j][1] = (set[j][1] &or1) ^ set[j][1]
#
# newSet.add((set[i][0],set[i][1]))
# newSet.add((set[j][0],set[j][1]))
return newSet
def _combineSplits(self, split1, split2):
if split1[0] & split2[0] != 0:
and1 = split1[0] & split2[1]
and2 = split2[0] & split1[1]
or1 = and1 | and2
split1[0] = split1[0] | or1
split1[1] = (split1[1] & or1) ^ split1[1]
split2[0] = split2[0] | or1
split2[1] = (split2[1] & or1) ^ split2[1]
def _makeCompatible(self, split1, split2):
if split1[0] & split2[0] == 0:
return
if split1[0] | split2[0] == split1[0] or split1[0] | split2[0] == split2[0]:
return
if (split1[0] | split1[1]) | split2[0] == (split1[0] | split1[1]):
return
if (split2[0] | split2[1]) | split1[0] == (split2[0] | split2[1]):
return
if (split1[0] | split1[1]) | (split2[0] | split2[1]) == (split1[0] | split1[1]):
return
if (split1[0] | split1[1]) | (split2[0] | split2[1]) == (split2[0] | split2[1]):
return
if (split1[0] | (split2[0] | split2[1]) != (split1[0] | split1[1]) and
split1[0] | (split2[0] | split2[1]) != (split2[0] | split2[1])):
# print 'incompatible'
# self._printSplitList([split1])
# self._printSplitList([split2])
xor1 = split1[0] ^ split2[0]
and1 = xor1 & split1[1]
xor2 = xor1 ^ and1
and2 = xor2 & split2[1]
xor3 = and2 ^ xor2
split1[0] = split1[0] ^ (split1[0] & xor3)
split1[1] = split1[1] | xor3
split2[0] = split2[0] ^ (split2[0] & xor3)
split2[1] = split2[1] | xor3
# print 'made compatible'
# self._printSplitList([split1])
# self._printSplitList([split2])
#
return
print('default, not good at all')
# self._printSplitList([split1])
# self._printSplitList([split2])
[docs]
def listOfSplits(self, treatMultifurcatingRootsAsUnrooted, minimumProportion=1.0):
"""
listOfSplits will return a list of splits (maybe zero lenght) that fulfill the minimumProportion criterion
"""
self.splits, treeNames = self.updateToCommonLeafSet(self.tfl)
print('Building intersectionlist')
print(time.ctime(time.time()))
intersectionList = BuildIntersections(
self.bitkeys, minimumProportion, self.weights, self.splits, not treatMultifurcatingRootsAsUnrooted)
splitData = intersectionList.buildIntersections()
print('Build intersectionlist')
print(time.ctime(time.time()))
intersections = []
for list in splitData:
tempList = []
for sd in list:
tempList.append(Intersection(sd[0], sd[1], sd[2]))
intersections.append(tempList)
for il in intersections:
print('')
for i in il:
i.printIntersection(self.bitkeys)
return intersections
[docs]
def reducedStrictConsensus(self):
self.minimumProportion = 1.0
trees, extendedTrees = self.reducedConsensusProfile()
if self.saveTreesToFile:
self.writeNexus(trees, self.treeFileName)
if self.saveSplitsToFile:
self.writeIntersections(self.splitFileName)
if self.saveExtendedTreesToFile:
self.writeNexus(extendedTrees, self.extendedTreeFileName)
if self.saveExtendedToFile:
self.writeExtended(self.extendedFileName)
[docs]
def majorityRuleConsensus(self):
self.minimumProportion = 0.5
trees, extendedTrees = self.reducedConsensusProfile()
if self.saveTreesToFile:
self.writeNexus(trees, self.treeFileName)
if self.saveSplitsToFile:
self.writeIntersections(self.splitFileName)
if self.saveExtendedTreesToFile:
self.writeNexus(extendedTrees, self.extendedTreeFileName)
if self.saveExtendedToFile:
self.writeExtended(self.extendedFileName)
[docs]
def reducedMemQuick(self):
from p4.LeafSupport import TreeSubsets
print('Running divide and conquer to speed things up')
sys.stdout.flush()
ts = TreeSubsets(self.trees)
self.trees = None
self.buildConsensusTreeHash(ts.getConsensusTree())
subtreeDicts, taxNames, taxonSets, cherries = ts.getSubTreesAndTaxonSetsFromInputTrees(
verbose=True)
ts = None
if len(subtreeDicts[0].keys()) == 1:
print('No subdivisions possible, solving whole tree')
else:
print('Got %s subproblems to solve' % (len(cherries) + len(subtreeDicts[0].keys())))
sys.stdout.flush()
allTaxa = self.taxNames[:]
allBitkeys = self.bitkeys[:]
taxa2bitkey = {}
for index in range(0, len(allTaxa)):
taxa2bitkey[allTaxa[index]] = allBitkeys[index]
self.verbose = 0
print('Building intersectionlists, divide and conquer style')
print(time.ctime(time.time()))
sys.stdout.flush()
weight = 0.0
splits = []
for i in range(0, len(taxNames)):
if len(taxonSets[i]) > 2:
trees = []
for dict in subtreeDicts:
trees.append(dict[taxNames[i]])
self.splits, treeNames = self.commonLeafSetTrees(trees)
if not weight:
weight = sum(self.weights)
names = []
bitkeys = []
for index in range(0, len(self.taxNames)):
n = self.taxNames[index].split(':')
for t in n:
names.append(t)
bitkeys.append(self.bitkeys[index])
left = 0
right = 0
for i in range(0, len(allTaxa)):
if allTaxa[i] in names:
left = left ^ allBitkeys[i]
i = Intersection(left, right, weight)
# i.printIntersection(allBitkeys)
splits.append(i)
taxon2bitkey = {}
for index2 in range(0, len(names)):
taxon2bitkey[names[index2]] = bitkeys[index2]
intersectionList = BuildIntersections(self.bitkeys, self.minimumProportion, self.minNoOfTaxa, self.weights,
self.splits, not self.treatMultifurcatingRootsAsUnrooted, self.sortByNoSplits, self.verbose)
intersectionList.buildIntersections()
for split in intersectionList.fullList:
left = 0
right = 0
for i in range(0, len(allTaxa)):
if allTaxa[i] in taxon2bitkey:
if taxon2bitkey[allTaxa[i]] & split.left:
left = left ^ allBitkeys[i]
elif taxon2bitkey[allTaxa[i]] & split.right:
right = right ^ allBitkeys[i]
i = Intersection(left, right, split.hits)
# i.printIntersection(allBitkeys)
splits.append(i)
for cherry in cherries:
names = {}
for set in cherry:
for name in set.split(':'):
names[name] = 1
left = 0
right = 0
for i in range(0, len(allTaxa)):
if allTaxa[i] in names:
left = left ^ allBitkeys[i]
i = Intersection(left, right, weight)
# i.printIntersection(allBitkeys)
splits.append(i)
print('Built intersectionlists')
print(time.ctime(time.time()))
sys.stdout.flush()
print('Taxnumber to taxname translation: ')
for i in range(0, len(allTaxa)):
print('%s,%s' % (i + 1, allTaxa[i]))
intersectionList.verbose = True
self.intersections = intersectionList.quickList(splits, allBitkeys)
self.taxNames = allTaxa
self.bitkeys = allBitkeys
self.extendedIntersections = []
if self.extendProfile:
extendedProfile = intersectionList.buildExtendedProfile(1)
for list in extendedProfile:
tempList = []
for sd in list:
tempList.append(Intersection(sd[0], sd[1], sd[2]))
self.extendedIntersections.append(tempList)
self.conTrees = self.buildTreesFromIntersections(
self.intersections, self.minimumSplitsToBuildTree, 'Reduced strict consensus trees')
if self.extendProfile and len(extendedProfile) > 0:
self.extendedProfileTrees = self.buildTreesFromIntersections(
self.extendedIntersections, self.minimumSplitsToBuildTree, 'Extended profile trees')
return self.conTrees, self.extendedProfileTrees
return self.conTrees, []
[docs]
def reducedQuick(self):
# list of taxanames in order from consensus tree, get from TreeSubsets
# build taxname2bitkey with order from consensus tree
# translate partial splits into full splits using taxonSet2taxonNames
# and taxname2bitkey
from p4.LeafSupport import TreeSubsets
print('Running divide and conquer to speed things up')
sys.stdout.flush()
ts = TreeSubsets(self.trees)
subtreeDicts, taxNames, taxonSets = ts.getSubTreesAndTaxonSetsFromInputTrees(
verbose=False)
self.buildConsensusTreeHash(ts.getConsensusTree())
ts = None
if len(subtreeDicts[0].keys()) == 1:
print('No subdivisions possible, solving whole tree')
else:
print('Got %s subproblems to solve' % (len(subtreeDicts[0].keys())))
sys.stdout.flush()
allTaxa = self.taxNames[:]
allBitkeys = self.bitkeys[:]
taxa2bitkey = {}
for index in range(0, len(allTaxa)):
taxa2bitkey[allTaxa[index]] = allBitkeys[index]
# print 'All taxa: ', allTaxa
# taxname2splitList i.e. A:B:C -> [split, split,split]
# taxname2taxonNames i.e. A:B:C -> ['A','B','C']
taxname2splitList = {}
taxname2taxonNames = {}
self.verbose = 0
print('Building intersectionlists, divide and conquer style')
print(time.ctime(time.time()))
sys.stdout.flush()
for i in range(0, len(taxNames)):
trees = []
for dict in subtreeDicts:
trees.append(dict[taxNames[i]])
self.splits, treeNames = self.commonLeafSetTrees(trees)
taxname2taxonNames[taxNames[i]] = [
self.taxNames[:], self.bitkeys[:]]
if len(self.taxNames) > 2:
intersectionList = BuildIntersections(self.bitkeys, self.minimumProportion, self.minNoOfTaxa, self.weights,
self.splits, not self.treatMultifurcatingRootsAsUnrooted, self.sortByNoSplits, self.verbose)
intersectionList.buildIntersections()
taxname2splitList[taxNames[i]] = intersectionList.fullList
weight = sum(self.weights)
print('Built intersectionlists')
print(time.ctime(time.time()))
sys.stdout.flush()
splits = []
for name in taxNames:
temp = taxname2taxonNames[name][0]
temp2 = taxname2taxonNames[name][1]
names = []
bitkeys = []
for index in range(0, len(temp)):
n = temp[index].split(':')
for t in n:
names.append(t)
bitkeys.append(temp2[index])
# print 'names:',names
taxon2bitkey = {}
for index2 in range(0, len(names)):
taxon2bitkey[names[index2]] = bitkeys[index2]
left = 0
right = 0
for i in range(0, len(allTaxa)):
if allTaxa[i] in taxon2bitkey:
left = left ^ allBitkeys[i]
i = Intersection(left, right, weight)
# i.printIntersection(allBitkeys)
splits.append(i)
if index <= 1:
pass
else:
for split in taxname2splitList[name]:
# split.printIntersection(bitkeys)
left = 0
right = 0
for i in range(0, len(allTaxa)):
# print allTaxa[i]
if allTaxa[i] in taxon2bitkey:
# print '1'
if taxon2bitkey[allTaxa[i]] & split.left:
# print '2'
left = left ^ allBitkeys[i]
elif taxon2bitkey[allTaxa[i]] & split.right:
# print '3'
right = right ^ allBitkeys[i]
# else:
# left = left ^ allBitkeys[i]
# right = right ^ allBitkeys[i]
# print 'Freq:', split.hits
i = Intersection(left, right, split.hits)
# i.printIntersection(allBitkeys)
splits.append(i)
intersectionList.verbose = True
self.intersections = intersectionList.quickList(splits, allBitkeys)
self.taxNames = allTaxa
self.bitkeys = allBitkeys
self.extendedIntersections = []
if self.extendProfile:
extendedProfile = intersectionList.buildExtendedProfile(1)
for list in extendedProfile:
tempList = []
for sd in list:
tempList.append(Intersection(sd[0], sd[1], sd[2]))
self.extendedIntersections.append(tempList)
self.conTrees = self.buildTreesFromIntersections(
self.intersections, self.minimumSplitsToBuildTree, 'Reduced strict consensus trees')
if self.extendProfile and len(extendedProfile) > 0:
self.extendedProfileTrees = self.buildTreesFromIntersections(
self.extendedIntersections, self.minimumSplitsToBuildTree, 'Extended profile trees')
return self.conTrees, self.extendedProfileTrees
return self.conTrees, []
[docs]
def reducedConsensusProfile(self):
"""
reducedConsensusProfile will depending on the minimumProportion and the input data produce 4 different outputs.
* Given a set of trees with identical leafsets and a minimumProportion=1.0 a reduced strict consensus tree
will be produced, i.e. a tree where all splits can be found in all input trees
* Given a set of trees with identical leafsets and a 1.0 < minimumProportion >= 0.5 a reduced majority profile
will be produced, i.e. a set of trees defined on differing leafsets with differing support, this profile will
include the reduced strict consensus tree if it exists
* Given a set of trees with differing leafsets and a minimumProportion= 1.0 a reduced strict super tree will
be produced, i.e. a tree where all splits can be found in all input trees
* Given a set of trees with differing leafsets and a 1.0 < minimimProportion >= 0.5 a reduced majority supertree
profile will be produced, i.e. a set of trees defined on differing leafsets with differing support, this profile will
include the reduced strict supertree tree if it exists
"""
self.splits, treeNames = self.commonLeafSetTrees(self.trees)
# self.splits, treeNames = self.updateToCommonLeafSet(self.tfl)
print('Building intersectionlist')
print(time.ctime(time.time()))
sys.stdout.flush()
intersectionList = BuildIntersections(self.bitkeys, self.minimumProportion, self.minNoOfTaxa,
self.weights, self.splits, not self.treatMultifurcatingRootsAsUnrooted, self.sortByNoSplits)
self.intersections = intersectionList.buildIntersections()
print('Done')
print(time.ctime(time.time()))
sys.stdout.flush()
self.extendedIntersections = []
if self.extendProfile:
extendedProfile = intersectionList.buildExtendedProfile(1)
for list in extendedProfile:
tempList = []
for sd in list:
tempList.append(Intersection(sd[0], sd[1], sd[2]))
self.extendedIntersections.append(tempList)
self.conTrees = self.buildTreesFromIntersections(
self.intersections, self.minimumSplitsToBuildTree, 'Reduced strict consensus trees')
if self.extendProfile and len(extendedProfile) > 0:
self.extendedProfileTrees = self.buildTreesFromIntersections(
self.extendedIntersections, self.minimumSplitsToBuildTree, 'Extended profile trees')
return self.conTrees, self.extendedProfileTrees
return self.conTrees, []
class TreeBuilderFromSplits(object):
def __init__(self, bitKeys, taxnames, internalNames=True):
self.taxNames = taxnames
self.bitkeys = bitKeys
self.internalNames = internalNames
def buildTreeFromInformativeList(self, list, treeName='conTreeName', excluded=0):
informativeBits = ((self.bitkeys[-1] << 1) - 1) ^ list[0].right
# print 'informative: ',informativeBits
# print 'Excluded: ', excluded
informativeBits = informativeBits ^ excluded
# print 'informative: '
# Intersection(informativeBits).printIntersection(self.bitkeys)
conTree = Tree()
conTree.name = treeName
conTree.root = Node()
conTree.root.nodeNum = 0
conTree.root.isLeaf = 0
conTree.nodes.append(conTree.root)
index = 0
leafNo = 1
first = True
rootBitkey = 0
n = Node()
previous = n
for bk in self.bitkeys:
if bk & informativeBits:
if first:
n.nodeNum = leafNo
conTree.nodes.append(n)
conTree.root.leftChild = n
n.parent = conTree.root
n.isLeaf = 1
n.name = self.taxNames[index]
n.br.splitKey = bk
rootBitkey = rootBitkey | bk
first = False
else:
n = Node()
n.nodeNum = leafNo
conTree.nodes.append(n)
previous.sibling = n
n.parent = conTree.root
n.isLeaf = 1
n.name = self.taxNames[index]
n.br.splitKey = bk
rootBitkey = rootBitkey | bk
previous = n
leafNo += 1
index += 1
conTree.root.br.splitKey = rootBitkey
for s in list:
self.applySplitToTree(s, conTree)
conTree.setPreAndPostOrder()
conTree._setTaxNamesFromLeaves()
return conTree
def popcount(self, n):
count = 0
for bk in self.bitkeys:
if n < bk:
return count
if bk & n:
count += 1
return count
def applySplitToTree(self, split, tree):
insertNode = self.findInsertNode(split, tree)
self.insertSplit(split, insertNode, tree)
return 1
def findInsertNode(self, split, tree):
node = tree.root
andScore = 0
nodeSplitScore = 9999999999
for n in tree.nodes:
if not n.isLeaf:
tScore = self.popcount(n.br.splitKey & split.left)
if tScore > andScore:
andScore = tScore
node = n
nodeSplitScore = self.popcount(node.br.splitKey)
elif tScore == andScore and self.popcount(n.br.splitKey) < nodeSplitScore:
nodeSplitScore = self.popcount(n.br.splitKey)
node = n
return node
def insertSplit(self, split, node, tree):
n = Node()
n.isLeaf = 0
n.parent = node
n.br.splitKey = split.left
n.nodeNum = len(tree.nodes)
tree.nodes.append(n)
n.sibling = node.leftChild
node.leftChild = n
if split.name == None:
if self.internalNames:
n.name = '%.0f' % (100. * split.frequency)
else:
n.name = split.name
addNode = self.getSiblingInSplit(node, split)
while addNode:
# print 'while addNode'
self.addNodeToNode(n, addNode)
addNode = self.getSiblingInSplit(node, split)
def getSiblingInSplit(self, node, split):
previous = node.leftChild
n = previous.sibling
if not n:
# print 'no n'
return None
while n.sibling:
if self.isNodeInSplit(n, split):
# print 'found node in split'
previous.sibling = n.sibling
n.sibling = None
return n
if not n.sibling:
# print 'reached end of siblings'
return None
previous = n
n = n.sibling
if self.isNodeInSplit(n, split):
# print 'found node in split outside loop'
previous.sibling = None
# n.wipe()
n.sibling = None
# n.parent = None
return n
# print 'reached end of loop'
return None
def addNodeToNode(self, node, addNode):
addNode.parent = node
if node.getNChildren() == 0:
# print 'node.getNChildren == 0'
node.leftChild = addNode
elif node.getNChildren() == 1:
# print 'node.getNChildren == 1'
leftChild = node.leftChild
leftChild.sibling = addNode
else:
# print 'node.getNChildren == @@@'
rightChild = node.rightmostChild()
rightChild.sibling = addNode
def isNodeInSplit(self, node, split):
if node.br.splitKey | split.left == split.left:
return True
else:
return False
class Intersection(object):
def __init__(self, bitKey, bright=0, bhits=0, bfrequency=0):
# print 'Frequency: ', bhits
self.left = bitKey
self.right = bright
self.hits = bhits
self.frequency = bhits
self.name = None
self.excluded = []
def __cmp__(self, other):
return cmp(other.frequency, self.frequency)
def __str__(self):
intersection = ''
for bk in bitkeys:
if bk & self.left:
intersection = intersection + '*'
elif bk & self.right:
intersection = intersection + '?'
else:
intersection = intersection + '.'
return intersection + ', ' + str(self.frequency * 100)
def copy(self):
return Intersection(self.left, self.right, self.frequency)
def firstHit(self, bitkeys):
hit = 0
for bk in bitkeys:
if bk & self.left:
return hit
hit += 1
return hit
def setExcluded(self, bitkeys):
self.excluded = []
index = 0
for bk in bitkeys:
if self.right & bk:
self.excluded.append(index)
index += 1
def popcount(self, bitkeys):
count = 0
for bk in bitkeys:
if self.left < bk:
return count
if bk & self.left:
count += 1
return count
def getInformativeBits(self, bitkeys):
highestNo = (bitkeys[-1] << 1) - 1
return highestNo ^ self.right
def popcountLR(self, bitkeys):
countie = self.left | self.right
count = 0
for bk in bitkeys:
if countie < bk:
return count
if bk & countie:
count += 1
return count
def popcountExcluded(self, bitkeys):
index = 0
for bk in bitkeys:
if self.right < bk:
return index
if self.right & bk:
index += 1
return index
def isIntersectionTrivial(self, bitkeyList):
if self.left == 0:
return 1
if self.popcount(bitkeyList) <= 1 or self.popcountLR(bitkeyList) == len(bitkeyList):
return 1
if self.left | self.right == 0:
return 1
return 0
def dump(self):
print("%s %s, %6s %s, %6s %s, %6s %s " % ('L:', self.left, 'R:', self.right, 'F:', self.frequency, 'H:', self.hits))
def printIntersection(self, bitkeys, weight=1.0):
intersection = ''
for bk in bitkeys:
if bk & self.left:
intersection = intersection + '*'
elif bk & self.right:
intersection = intersection + '?'
else:
intersection = intersection + '.'
if self.name != None:
print('%s %s' % (intersection, self.name))
else:
f = self.frequency / weight
if f <= 1:
f = f * 100
f = int(f)
print('%s %s' % (intersection, f))
def intersectionToString(self, bitkeys):
intersection = ''
for bk in bitkeys:
if bk & self.left:
intersection = intersection + '*'
elif bk & self.right:
intersection = intersection + '?'
else:
intersection = intersection + '.'
return intersection + ', ' + str(self.frequency * 100)
def addSplit(self, bitkey):
newSplit = Intersection(0)
newSplit.left = self.left & bitkey
newSplit.right = (self.left ^ bitkey) | self.right
# print 'Creating new intersection: '
# newSplit.dump()
return newSplit
def addReverseSplit(self, bitkey, highestBitkey):
newSplit = Intersection(0)
newSplit.left = (self.left ^ bitkey) & self.left
newSplit.right = (
self.left ^ (~bitkey & ((highestBitkey << 1) - 1))) | self.right
return newSplit
def intersection(self, other):
newSplit = Intersection(0)
newSplit.left = self.left & other.left
newSplit.right = (self.left ^ other.left) | self.right | other.right
return newSplit
def reverseIntersection(self, other, highestBitkey):
newSplit = Intersection(0)
newSplit.left = (self.left ^ other.left) & self.left
newSplit.right = (self.left ^ (
~other.left & ((highestBitkey << 1) - 1))) | self.right | other.right
return newSplit
def matchingSplit(self, split):
if self.left == split:
# if self.left == 768:
# print '%s i.l:%s i.r:%s == s:%s' % ('Found matching split',
# self.left, self.right, split)
return 1
if ((self.left ^ split) | self.right) == self.right:
# if self.left == 768:
# print '%s i.left:%s, i.right:%s, s:%s' % ('Found matching split
# XOR', self.left, self.right, split)
return 1
# if self.left == 768:
# print '%s i.l:%s i.r%s != s:%s' % ('Found NONmatching split',
# self.left, self.right, split)
return 0
def overlapping(self, other):
if self.right & other.right != 0:
return True
return False
def smaller(self, other):
if (self.right & other.right) == self.right:
# print 'True'
return True
return False
def definesNonOverlappingSets(self, other):
if self.right == 0 or other.right == 0:
return False
if self.right & other.right == 0:
return True
return False
def definesSmallerAndOverlappingSet(self, other, minimumProportion):
if self.right & other.right == self.right:
return True
return False
def isMoreInformative(self, list, minimumProportion):
for other in list:
# self.printIntersection()
# other.sprintIntersection()
if self.frequency > other.frequency and other.frequency < minimumProportion:
# print '1'
return 1
if other.frequency > self.frequency and self.frequency < minimumProportion:
# print '0'
return 0
selfORother = self.right | other.right
if selfORother == self.right:
# print '0'
return 0
if selfORother == other.right:
# print '1'
return 1
return 2
def isSameSet(self, other):
# if self.left | self.right == other.left | other.right:
# return True
if self.right == other.right:
return True
def isSameLooseSet(self, other):
if self.left | self.right == other.left | other.right:
return True
return False
def isOverlapping(self, other):
if self.right == 0 or other.right == 0:
return True
if self.right == other.right:
return True
# if self.left & other.right != 0 or other.left & self.right != 0:
# return 0
# if self.left & other.left != 0:
# print 'Is overlapping: %s:%s ' % (self.left, other.left)
# return 1
# print 'Is not overlapping: %s:%s ' % (self.left, other.left)
# if self.left & other.left == 0 and self.right ^ other.right == 0:
# return 1
return 0
def isOtherRedundant(self, other):
if self.left | self.right == other.left | other.right:
# if self.left | other.left == self.left:
if self.right & other.right == self.right:
# if self.left | self.right == other.left | other.right:
# print 'Redundant o.l: %s, o.r: %s, i.l: %s' % (other.left,
# other.right, self.left)
return 1
# print 'NonRedundant o.l: %s, o.r: %s, i.l: %s' % (other.left,
# other.right, self.left)
return 0
def isCorrectIntersection(self):
if self.left & self.right != 0:
return 0
return 1
def isIdentical(self, other):
if self.left == other.left and self.right == other.right:
return 1
return 0