Skip to content

Commit 1e8d97b

Browse files
committed
Used 3 class decorrelation
1 parent 83a4d35 commit 1e8d97b

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

DataCollection.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,8 @@ def get(self):
940940
dimz=0
941941

942942
binWidth = (MMAX-MMIN)/NBINS
943-
hystored=[numpy.zeros((ystored[0].shape[0],NBINS+2))]
943+
nClasses = ystored[0].shape[1]
944+
hystored=[numpy.zeros((ystored[0].shape[0],NBINS+nClasses))]
944945
for ii in xrange(0,ystored[0].shape[0]):
945946
# bin of mass (for kldiv loss)
946947
if zstored[0][ii,0,2]<MMIN:
@@ -949,8 +950,8 @@ def get(self):
949950
hystored[0][ii,NBINS-1]=1
950951
else:
951952
hystored[0][ii,int((zstored[0][ii,0,2]-MMIN)/binWidth)]=1
952-
hystored[0][ii,NBINS] = ystored[0][ii,0]
953-
hystored[0][ii,NBINS+1] = ystored[0][ii,1]
953+
for jj in range(nClasses):
954+
hystored[0][ii,NBINS+jj] = ystored[0][ii,jj]
954955
dimhy=len(hystored)
955956
if not decor:
956957
dimhy=0
@@ -1008,7 +1009,8 @@ def get(self):
10081009
zstored[i] = numpy.vstack((zstored[i],td.z[i]))
10091010

10101011
binWidth = float(MMAX-MMIN)/NBINS
1011-
hy=[numpy.zeros((td.y[0].shape[0],NBINS+2))]
1012+
nClasses = td.y[0].shape[1]
1013+
hy=[numpy.zeros((td.y[0].shape[0],NBINS+nClasses))]
10121014
for ii in xrange(0,td.y[0].shape[0]):
10131015
# bin of mass (for kldiv loss)
10141016
if td.z[0][ii,0,2]<MMIN:
@@ -1017,8 +1019,8 @@ def get(self):
10171019
hy[0][ii,NBINS-1]=1
10181020
else:
10191021
hy[0][ii,int((td.z[0][ii,0,2]-MMIN)/binWidth)]=1
1020-
hy[0][ii,NBINS] = td.y[0][ii,0]
1021-
hy[0][ii,NBINS+1] = td.y[0][ii,1]
1022+
for jj in range(nClasses):
1023+
hy[0][ii,NBINS+jj] = td.y[0][ii,jj]
10221024
hystored[0] = numpy.vstack((hystored[0],hy[0]))
10231025

10241026
if xstored[0].shape[0] >= self.__batchsize:

0 commit comments

Comments
 (0)