diff --git a/connectx/nns/models.py b/connectx/nns/models.py index 48028f8..7a3b1cf 100644 --- a/connectx/nns/models.py +++ b/connectx/nns/models.py @@ -192,7 +192,8 @@ def forward( ) -> Dict[str, Any]: logging.info(f"Beginning forward pass") x, available_actions_mask, subtask_embeddings = self.dict_input_layer(x) - logging.info(f"Getting base_model outputs {x.shape}") + for key,val in x.items(): + logging.info(f"Getting base_model outputs {key}:{val.shape}") base_out = self.base_model(x) logging.info(f"Ignoring subtasks") if subtask_embeddings is not None: