Skip to content

Commit 83a4d35

Browse files
committed
Merge branch 'master' of github.com:anovak10/DeepJetCore into update
Merge with local changes
2 parents 8247a55 + cde362b commit 83a4d35

File tree

3 files changed

+9
-20
lines changed

3 files changed

+9
-20
lines changed

DataCollection.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def createTestDataForDataCollection(
374374
import copy
375375
self.readFromFile(collectionfile)
376376
self.dataclass.remove=False
377-
self.dataclass.weight=False #True #False
377+
self.dataclass.weight=False
378378
self.readRootListFromFile(inputfile)
379379
self.createDataFromRoot(
380380
self.dataclass, outputDir, False,
@@ -930,6 +930,7 @@ def get(self):
930930
ystored=td.y
931931
dimy=len(ystored)
932932
wstored=td.w
933+
wstored = [wstored[0].reshape(wstored[0].shape[0])]
933934
dimw=len(wstored)
934935
if not self.useweights:
935936
dimw=0
@@ -972,6 +973,7 @@ def get(self):
972973
else:
973974

974975
#randomly^2 shuffle - not needed every time
976+
#print("RANDOM SHUFFLE")
975977
if psamples%2==0 and nepoch%2==1:
976978
for i in range(0,dimx):
977979
td.x[i]=shuffle(td.x[i], random_state=psamples)
@@ -981,7 +983,6 @@ def get(self):
981983
td.w[i]=shuffle(td.w[i], random_state=psamples)
982984
for i in range(0,dimz):
983985
td.z[i]=shuffle(td.z[i], random_state=psamples)
984-
985986
for i in range(0,dimx):
986987
if(xstored[i].ndim==1):
987988
xstored[i] = numpy.append(xstored[i],td.x[i])
@@ -999,7 +1000,7 @@ def get(self):
9991000
wstored[i] = numpy.append(wstored[i],td.w[i])
10001001
else:
10011002
wstored[i] = numpy.vstack((wstored[i],td.w[i]))
1002-
1003+
10031004
for i in range(0,dimz):
10041005
if(zstored[i].ndim==1):
10051006
zstored[i] = numpy.append(zstored[i],td.z[i])
@@ -1019,7 +1020,7 @@ def get(self):
10191020
hy[0][ii,NBINS] = td.y[0][ii,0]
10201021
hy[0][ii,NBINS+1] = td.y[0][ii,1]
10211022
hystored[0] = numpy.vstack((hystored[0],hy[0]))
1022-
1023+
10231024
if xstored[0].shape[0] >= self.__batchsize:
10241025
batchcomplete = True
10251026

@@ -1057,7 +1058,7 @@ def get(self):
10571058
splitted=numpy.split(hystored[i],[self.__batchsize])
10581059
hystored[i] = splitted[1]
10591060
hyout[i] = splitted[0]
1060-
1061+
10611062
for i in range(0,dimx):
10621063
if(xout[i].ndim==1):
10631064
xout[i]=(xout[i].reshape(xout[i].shape[0],1))
@@ -1070,12 +1071,6 @@ def get(self):
10701071
if not yout[i].shape[1] >0:
10711072
raise Exception('serious problem with the output shapes!!')
10721073

1073-
for i in range(0,dimw):
1074-
if(wout[i].ndim==1):
1075-
wout[i]=(wout[i].reshape(wout[i].shape[0],1))
1076-
if not wout[i].shape[1] >0:
1077-
raise Exception('serious problem with the output shapes!!')
1078-
10791074
for i in range(0,dimz):
10801075
if(zout[i].ndim==1):
10811076
zout[i]=(zout[i].reshape(zout[i].shape[0],1))
@@ -1087,7 +1082,6 @@ def get(self):
10871082
hyout[i]=(hyout[i].reshape(hyout[i].shape[0],1))
10881083
if not hyout[i].shape[1] >0:
10891084
raise Exception('serious problem with the output shapes!!')
1090-
10911085
processedbatches+=1
10921086

10931087

@@ -1101,13 +1095,12 @@ def get(self):
11011095
if self.useweights and decor:
11021096
yield (xout,hyout,wout)
11031097
elif self.useweights:
1104-
yield (xout,yout,wout)
1098+
yield (xout,yout,wout)
11051099
elif decor:
11061100
yield (xout,hyout)
11071101
else:
11081102
yield (xout,yout)
11091103

1110-
11111104

11121105

11131106

TrainData.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,6 @@ def readIn_async(self,fileprefix,read_async=True,shapesOnly=False,ramdiskpath=''
305305
print('\nTrainData::readIn_async: started new read before old was finished. Intended? Waiting for first to finish...\n')
306306
self.readIn_join()
307307

308-
#print('read')
309-
310308
import h5py
311309
import multiprocessing
312310

@@ -627,7 +625,6 @@ def reshape_fast(arr,shapeinfo):
627625
self.z_list=None
628626
self.readthread=None
629627

630-
631628
def readTreeFromRootToTuple(self, filenames, limit=None, branches=None):
632629
'''
633630
To be used to get the initial tupel for further processing in inherting classes

training/training_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,10 @@ def __init__(
124124
var = raw_input('Output dir exists. To recover a training, please type "yes"\n To overwrite a training, please type "continue"\n')
125125
if var == 'yes':
126126
isNewTraining=False
127-
elif var == 'continue':
128-
shutil.rmtree(self.outputDir)
129127
else:
130128
raise Exception('output directory must not exists yet')
131-
#isNewTraining=False
129+
#else:
130+
# isNewTraining=False
132131
else:
133132
os.mkdir(self.outputDir)
134133
self.outputDir = os.path.abspath(self.outputDir)

0 commit comments

Comments
 (0)