@@ -458,44 +458,40 @@ def _set_result_units(self, table, unit):
458458 def compute_histograms (
459459 self , data , masked_elements_of_sample
460460 ) -> ChunkHistogramsContainer :
461- first_dim = data .shape [0 ]
461+ # Build the histograms over the event dimension (axis=0) for each element of the data dimensions
462+ event_dim = data .shape [0 ]
462463 spatial_shape = data .shape [1 :]
463- last_dim = int (np .prod (spatial_shape ))
464-
464+ n_elements = int (np .prod (spatial_shape ))
465465 # Broadcast mask to full shape
466466 if masked_elements_of_sample is not None :
467467 mask = np .broadcast_to (masked_elements_of_sample , data .shape )
468468 else :
469469 mask = np .zeros_like (data , dtype = bool )
470-
471470 # Mask invalid values (NaN, inf)
472471 invalid = ~ np .isfinite (data )
473472 mask = mask | invalid
474- flat_data = data . reshape ( first_dim , last_dim )
475- flat_mask = mask .reshape (first_dim , last_dim )
476-
477- # Build histogram object
473+ # The histogram is computed for each element of the data dimensions, so we need to flatten
474+ flat_data = data .reshape (event_dim , n_elements )
475+ flat_mask = mask . reshape ( event_dim , n_elements )
476+ # Build histogram object over the event dimension for each element of the data dimensions
478477 hist_object = Hist (
479478 self .hist_axis ,
480- hist .axis .Integer (0 , last_dim , name = "last_dimension " ),
479+ hist .axis .Integer (0 , n_elements , name = "element " ),
481480 storage = hist .storage .Int64 (),
482481 )
483-
484- # Fill histogram (loop over the last dimension, but fast backend)
485- for i in range (last_dim ):
486- valid = ~ flat_mask [:, i ]
487- if not np .any (valid ):
488- continue
489-
490- values = flat_data [valid , i ]
491- hist_object .fill (value = values , last_dimension = i )
492-
493- # Extract histogram counts
482+ # Vectorized filling - all valid values and their dimension indices at once
483+ valid_mask = ~ flat_mask
484+ values = flat_data [valid_mask ]
485+ dimension_indices = np .where (valid_mask )[1 ] # column indices (which dimension)
486+ if len (values ) > 0 :
487+ hist_object .fill (value = values , element = dimension_indices )
488+ # Extract histogram counts and reshape to original data dimensions (with bin dimension first)
494489 n_bins = hist_object .axes [0 ].size
495- hist_counts = hist_object .values () # shape: (bins, n_pixels )
490+ hist_counts = hist_object .values () # shape: (bins, n_elements )
496491 hist_counts = hist_counts .reshape ((n_bins ,) + spatial_shape )
497492 # Count valid entries per element (excludes masked and invalid values)
498493 n_events_valid = np .sum (~ flat_mask , axis = 0 ).reshape (spatial_shape )
494+ # Build and return the ChunkHistogramsContainer
499495 return ChunkHistogramsContainer (
500496 n_events = n_events_valid ,
501497 histogram = hist_counts ,
0 commit comments