
Commit 15378ab

Add categorisation to all wandb metrics (#136)
1 parent e681b29 · commit 15378ab

13 files changed: +25 -21 lines changed
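
The commit prefixes every metric key sent to Weights & Biases with a category such as "train/loss/" or "train/". W&B groups dashboard panels by the prefix before the "/" in a key, so the categorised names collect into shared sections instead of one flat list of charts. A minimal sketch of the effect, assuming an active W&B run (the project name is hypothetical and none of this code is taken from the commit):

import wandb

# Keys sharing a "train/..." prefix are grouped into one dashboard section.
run = wandb.init(project="sparse-autoencoder-demo")
run.log(
    {
        "train/loss/l2_reconstruction_loss": 5.5,
        "train/loss/learned_activations_l1_loss": 0.1,
        "train/learned_activations_l0_norm": 1.5,
    }
)
run.finish()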

sparse_autoencoder/loss/abstract_loss.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def batch_scalar_loss_with_log(
         )

         # Add in the current loss module's metric
-        log_name = self.log_name()
+        log_name = "train/loss/" + self.log_name()
         metrics[log_name] = current_module_loss.detach().cpu().item()

         return current_module_loss, metrics
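
The metrics dict built here maps the categorised log name to a detached Python float, so the logged payload carries no live tensors. A standalone sketch of that pattern, with a stand-in for self.log_name() (the real name comes from each AbstractLoss subclass):

import torch

# Stand-in value; each loss module supplies its own log_name().
log_name = "train/loss/" + "l2_reconstruction_loss"

current_module_loss = torch.tensor(5.5)
metrics: dict[str, float] = {}
# .detach().cpu().item() turns the loss tensor into a plain float for logging.
metrics[log_name] = current_module_loss.detach().cpu().item()
assert metrics == {"train/loss/l2_reconstruction_loss": 5.5}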

sparse_autoencoder/loss/decoded_activations_l2.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ class L2ReconstructionLoss(AbstractLoss):
     >>> unused_activations = torch.zeros_like(input_activations)
     >>> # Outputs both loss and metrics to log
     >>> loss(input_activations, unused_activations, output_activations)
-    (tensor(5.5000), {'l2_reconstruction_loss': 5.5})
+    (tensor(5.5000), {'train/loss/l2_reconstruction_loss': 5.5})
     """

     _reduction: LossReductionType

sparse_autoencoder/loss/learned_activations_l1.py

Lines changed: 2 additions & 2 deletions
@@ -124,8 +124,8 @@ def batch_scalar_loss_with_log(
         batch_scalar_loss_penalty = absolute_loss_penalty.sum().squeeze()

         metrics = {
-            "learned_activations_l1_loss": batch_scalar_loss.item(),
-            self.log_name(): batch_scalar_loss_penalty.item(),
+            "train/loss/" + "learned_activations_l1_loss": batch_scalar_loss.item(),
+            "train/loss/" + self.log_name(): batch_scalar_loss_penalty.item(),
         }

         return batch_scalar_loss_penalty, metrics
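
This module logs two values: the unscaled L1 loss under a fixed key and the penalty actually optimised under the module's own log name, both now in the "train/loss/" category. An illustrative sketch of the resulting payload, using made-up values and a hypothetical "_penalty" key in place of self.log_name() (how the penalty is derived from the L1 loss is defined by this module and is not reproduced here):

import torch

# Illustrative values only; the relationship between the two tensors is not
# taken from this module.
batch_scalar_loss = torch.tensor(3.5)            # unscaled L1 of learned activations
batch_scalar_loss_penalty = torch.tensor(0.035)  # penalty term used for training

metrics = {
    "train/loss/" + "learned_activations_l1_loss": batch_scalar_loss.item(),
    "train/loss/" + "learned_activations_l1_loss_penalty": batch_scalar_loss_penalty.item(),
}
assert all(key.startswith("train/loss/") for key in metrics)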

sparse_autoencoder/loss/tests/test_abstract_loss.py

Lines changed: 2 additions & 2 deletions
@@ -64,12 +64,12 @@ def test_batch_scalar_loss_with_log(dummy_loss: DummyLoss) -> None:
         source_activations, learned_activations, decoded_activations
     )
     expected = 2.0  # Mean of [1.0, 2.0, 3.0]
-    assert log["dummy"] == expected
+    assert log["train/loss/dummy"] == expected


 def test_call_method(dummy_loss: DummyLoss) -> None:
     """Test the call method."""
     source_activations = learned_activations = decoded_activations = torch.ones((1, 3))
     _loss, log = dummy_loss(source_activations, learned_activations, decoded_activations)
     expected = 2.0  # Mean of [1.0, 2.0, 3.0]
-    assert log["dummy"] == expected
+    assert log["train/loss/dummy"] == expected

sparse_autoencoder/metrics/train/capacity.py

Lines changed: 1 addition & 1 deletion
@@ -85,5 +85,5 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]:
         train_batch_capacities_histogram = self.wandb_capacities_histogram(train_batch_capacities)

         return {
-            "train_batch_capacities_histogram": train_batch_capacities_histogram,
+            "train/batch_capacities_histogram": train_batch_capacities_histogram,
         }
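
calculate() returns a plain dict, so its result can be passed straight to the W&B logger. A sketch of how the categorised histogram key ends up on the dashboard, assuming wandb_capacities_histogram wraps wandb.Histogram (an assumption about this class, not something shown in the diff):

import numpy as np
import wandb

# Fake capacities in [0, 1]; the real values come from TrainMetricData.
train_batch_capacities = np.random.rand(64)
train_batch_capacities_histogram = wandb.Histogram(train_batch_capacities)

run = wandb.init(project="sparse-autoencoder-demo")  # hypothetical project name
run.log({"train/batch_capacities_histogram": train_batch_capacities_histogram})
run.finish()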

sparse_autoencoder/metrics/train/feature_density.py

Lines changed: 1 addition & 1 deletion
@@ -111,5 +111,5 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]:
         )

         return {
-            "train_batch_feature_density_histogram": train_batch_feature_density_histogram,
+            "train/batch_feature_density_histogram": train_batch_feature_density_histogram,
         }

sparse_autoencoder/metrics/train/l0_norm_metric.py

Lines changed: 1 addition & 1 deletion
@@ -22,4 +22,4 @@ def calculate(self, data: TrainMetricData) -> dict[str, float]:
         batch_size = data.learned_activations.size(0)
         n_non_zero_activations = torch.count_nonzero(data.learned_activations)
         batch_average = n_non_zero_activations / batch_size
-        return {"learned_activations_l0_norm": batch_average.item()}
+        return {"train/learned_activations_l0_norm": batch_average.item()}

sparse_autoencoder/metrics/train/tests/test_capacities.py

Lines changed: 1 addition & 1 deletion
@@ -62,4 +62,4 @@ def test_calculate_returns_histogram() -> None:
             decoded_activations=activations,
         )
     )
-    assert "train_batch_capacities_histogram" in res
+    assert "train/batch_capacities_histogram" in res

sparse_autoencoder/metrics/train/tests/test_feature_density.py

Lines changed: 1 addition & 1 deletion
@@ -41,4 +41,4 @@ def test_calculate_aggregates() -> None:
     )

     # Check both metrics are in the result
-    assert "train_batch_feature_density_histogram" in res
+    assert "train/batch_feature_density_histogram" in res

sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ def test_l0_norm_metric() -> None:
     )
     log = l0_norm_metric.calculate(data)
     expected = 3 / 2
-    assert log["learned_activations_l0_norm"] == expected
+    assert log["train/learned_activations_l0_norm"] == expected
