"""
desitarget.myRF
===============
This module computes Random Forest probabilities and stores the
forest using its own persistence format.
"""
import numpy as np
import sys
class myRF(object):
    """Class for I/O operations and probability calculation for Random Forest.

    Parameters
    ----------
    data : :class:`numpy.ndarray`
        Feature array indexed as ``data[object, feature]``. A copy is kept.
    modelDir : :class:`str`
        Prefix prepended verbatim to the ``bdt.pkl_NN.npy`` file names read
        by :meth:`saveForest` (so a directory must end with a separator).
    numberOfTrees : :class:`int`, optional
        Number of trees in the forest.
    version : :class:`int`, optional
        Persistency version, 1 or 2. With version 1 the leaf probability is
        derived from a separate per-node answer array; with version 2 it is
        read from the fifth field of the node record.

    Notes
    -----
    An unsupported `version` prints a message and calls ``sys.exit()``.
    """

    def __init__(self, data, modelDir, numberOfTrees=200, version=2):
        # Load the data once and initialise the working arrays.
        self.data = data.copy()
        self.proba = np.zeros(len(data))
        self.bdtOutput = np.zeros(len(data))
        self.modelDir = modelDir
        self.version = version
        self.nTrees = numberOfTrees
        if self.version in [1, 2]:
            # for models-decals-dr3, (was 5 for models-decals)
            self.filesPerTree = 4
        else:
            self._unsupportedVersion()

    def _unsupportedVersion(self):
        """Report an unsupported persistency version and exit."""
        print("unsupported version=", self.version)
        sys.exit()

    def loadTree(self, treeFile, answerFile):
        """Load one tree from a pair of ``.npy`` files.

        Parameters
        ----------
        treeFile : :class:`str`
            File holding the node records; stored in ``self.treeInfo``.
        answerFile : :class:`str`
            File holding the per-node answers; stored in ``self.treeAnswer``.

        Notes
        -----
        The interpreter recursion limit is no longer modified here because
        :meth:`searchNodes` walks the tree without recursion.
        """
        self.treeInfo = np.load(treeFile)
        self.treeAnswer = np.load(answerFile)

    def unloadTree(self):
        """Drop the currently loaded tree, if any, to release its memory.

        Both ``treeInfo`` and ``treeAnswer`` are removed when present, so
        this is safe to call whatever loader was used and even when no tree
        is loaded.
        """
        for name in ("treeInfo", "treeAnswer"):
            if hasattr(self, name):
                delattr(self, name)

    def searchNodes(self, indices, nodeId=0):
        """Fill ``self.proba`` with the response of the current tree.

        Parameters
        ----------
        indices : array-like of int
            Row indices of ``self.data`` to send down the tree.
        nodeId : :class:`int`, optional
            Node at which the descent starts (0 is the root).

        Notes
        -----
        A node whose first field is -1 is a leaf. Otherwise the first four
        fields are read as (left child, right child, feature, threshold)
        and rows with ``data[row, feature] <= threshold`` go left.

        The walk uses an explicit stack rather than recursion, so it does
        not depend on the interpreter recursion limit. Branches receiving
        no rows are skipped, which does not change the result because
        assigning to an empty selection is a no-op.
        """
        pending = [(np.asarray(indices), nodeId)]
        while pending:
            rows, current = pending.pop()
            if rows.size == 0:
                continue
            nodeInfo = self.treeInfo[current]
            if nodeInfo[0] == -1:
                if self.version == 1:
                    # Probability = second count over the sum of both counts.
                    counts = self.treeAnswer[current, 0]
                    self.proba[rows] = counts[1]*1./(counts[0]+counts[1])
                else:
                    self.proba[rows] = nodeInfo[4]
                continue
            goesLeft = self.data[rows, nodeInfo[2]] <= nodeInfo[3]
            pending.append((rows[~goesLeft], nodeInfo[1]))
            pending.append((rows[goesLeft], nodeInfo[0]))

    def predict_proba(self):
        """Return the forest response, i.e. the mean of the tree responses.

        Requires ``self.forest`` to be set, e.g. by :meth:`loadForest`.

        Returns
        -------
        :class:`numpy.ndarray`
            One value per row of ``self.data``; also kept in
            ``self.bdtOutput``.

        Notes
        -----
        The accumulator is re-created on every call so that repeated calls
        return the same answer instead of adding to the previous result.
        """
        nRows = len(self.data)
        self.bdtOutput = np.zeros(nRows)
        allRows = np.arange(nRows)
        for iTree in range(self.nTrees):
            self.loadTreeFromForest(iTree)
            self.searchNodes(allRows)
            self.bdtOutput += self.proba
        self.bdtOutput /= self.nTrees
        return self.bdtOutput

    def loadForest(self, forestFileName):
        """Load a forest written by :meth:`saveForest` into ``self.forest``.

        Parameters
        ----------
        forestFileName : :class:`str`
            Name of the ``.npz`` file to read.

        Notes
        -----
        The forest is stored as an object array, which NumPy only reads
        with ``allow_pickle=True``. Unpickling can execute arbitrary code:
        only load forest files from a trusted source.
        """
        t = np.load(forestFileName, allow_pickle=True, encoding='bytes')
        self.forest = t['arr_0']

    def loadTreeFromForest(self, iTree):
        """Make tree number `iTree` of ``self.forest`` the current tree.

        Parameters
        ----------
        iTree : :class:`int`
            Index of the tree. With version 1 the forest interleaves node
            records and answers, so entries ``2*iTree`` and ``2*iTree+1``
            are used; with version 2 entry ``iTree`` is used.
        """
        if self.version == 1:
            self.treeInfo = self.forest[iTree*2]
            self.treeAnswer = self.forest[iTree*2+1]
        elif self.version == 2:
            self.treeInfo = self.forest[iTree]
        else:
            self._unsupportedVersion()

    def saveForest(self, forestFileName):
        """Read the per-tree files and store the useful part in one file.

        For tree ``i`` the node records are read from
        ``modelDir + "bdt.pkl_NN.npy"`` with ``NN = i*filesPerTree + 4`` and
        the answers from the file numbered ``NN - 1`` (numbers below 10 are
        zero-padded to two digits).
        NOTE(review): presumably these are the arrays of a joblib-dumped
        scikit-learn forest -- confirm against the code producing them.

        Parameters
        ----------
        forestFileName : :class:`str`
            Output name handed to :func:`numpy.savez_compressed`.

        Notes
        -----
        Only the first four fields of each node record are kept; version 2
        adds the leaf probability computed from the answers. Child ids are
        stored as ``int16`` and widened to ``int32`` only for trees too
        large for ``int16``, which would otherwise wrap silently.
        """
        if self.version not in (1, 2):
            self._unsupportedVersion()

        int16Max = np.iinfo(np.int16).max
        forest = []
        for iTree in range(self.nTrees):
            if iTree % 10 == 0:
                print("tree=", iTree)
            fileNumber = iTree*self.filesPerTree + 4
            treeFile = (self.modelDir + "bdt.pkl_"
                        + str(fileNumber).zfill(2) + ".npy")
            answerFile = (self.modelDir + "bdt.pkl_"
                          + str(fileNumber - 1).zfill(2) + ".npy")
            nodes = np.load(treeFile)
            answers = np.load(answerFile)

            # Node ids run from 0 to len(nodes)-1 (plus -1 for "no child").
            idType = 'int16' if len(nodes) - 1 <= int16Max else 'int32'
            recordType = '%s, %s, int8, float32' % (idType, idType)
            if self.version == 2:
                recordType += ', float32'

            # Store only useful information.
            packed = np.zeros(len(nodes), dtype=recordType)
            for iNode, node in enumerate(nodes):
                record = (node[0], node[1], node[2], node[3])
                if self.version == 2:
                    counts = answers[iNode]
                    proba = counts[0, 1]/(counts[0, 0]+counts[0, 1])
                    record += (proba,)
                packed[iNode] = record

            forest.append(packed)
            if self.version == 1:
                forest.append(answers)

        # Trees differ in size, so pack them one by one into an object
        # array; recent NumPy refuses to build one from a ragged list.
        packedForest = np.empty(len(forest), dtype=object)
        for slot, entry in enumerate(forest):
            packedForest[slot] = entry
        np.savez_compressed(forestFileName, packedForest)