@@ -940,7 +940,8 @@ def get(self):
940
940
dimz = 0
941
941
942
942
binWidth = (MMAX - MMIN )/ NBINS
943
- hystored = [numpy .zeros ((ystored [0 ].shape [0 ],NBINS + 2 ))]
943
+ nClasses = ystored [0 ].shape [1 ]
944
+ hystored = [numpy .zeros ((ystored [0 ].shape [0 ],NBINS + nClasses ))]
944
945
for ii in xrange (0 ,ystored [0 ].shape [0 ]):
945
946
# bin of mass (for kldiv loss)
946
947
if zstored [0 ][ii ,0 ,2 ]< MMIN :
@@ -949,8 +950,8 @@ def get(self):
949
950
hystored [0 ][ii ,NBINS - 1 ]= 1
950
951
else :
951
952
hystored [0 ][ii ,int ((zstored [0 ][ii ,0 ,2 ]- MMIN )/ binWidth )]= 1
952
- hystored [ 0 ][ ii , NBINS ] = ystored [ 0 ][ ii , 0 ]
953
- hystored [0 ][ii ,NBINS + 1 ] = ystored [0 ][ii ,1 ]
953
+ for jj in range ( nClasses ):
954
+ hystored [0 ][ii ,NBINS + jj ] = ystored [0 ][ii ,jj ]
954
955
dimhy = len (hystored )
955
956
if not decor :
956
957
dimhy = 0
@@ -1008,7 +1009,8 @@ def get(self):
1008
1009
zstored [i ] = numpy .vstack ((zstored [i ],td .z [i ]))
1009
1010
1010
1011
binWidth = float (MMAX - MMIN )/ NBINS
1011
- hy = [numpy .zeros ((td .y [0 ].shape [0 ],NBINS + 2 ))]
1012
+ nClasses = td .y [0 ].shape [1 ]
1013
+ hy = [numpy .zeros ((td .y [0 ].shape [0 ],NBINS + nClasses ))]
1012
1014
for ii in xrange (0 ,td .y [0 ].shape [0 ]):
1013
1015
# bin of mass (for kldiv loss)
1014
1016
if td .z [0 ][ii ,0 ,2 ]< MMIN :
@@ -1017,8 +1019,8 @@ def get(self):
1017
1019
hy [0 ][ii ,NBINS - 1 ]= 1
1018
1020
else :
1019
1021
hy [0 ][ii ,int ((td .z [0 ][ii ,0 ,2 ]- MMIN )/ binWidth )]= 1
1020
- hy [ 0 ][ ii , NBINS ] = td . y [ 0 ][ ii , 0 ]
1021
- hy [0 ][ii ,NBINS + 1 ] = td .y [0 ][ii ,1 ]
1022
+ for jj in range ( nClasses ):
1023
+ hy [0 ][ii ,NBINS + jj ] = td .y [0 ][ii ,jj ]
1022
1024
hystored [0 ] = numpy .vstack ((hystored [0 ],hy [0 ]))
1023
1025
1024
1026
if xstored [0 ].shape [0 ] >= self .__batchsize :
0 commit comments