Skip to content

Commit e0e447a

Browse files
authored
Probing and SAE integration (#12)
* expanded clustering * running scripts with apptainer * gpu label * wandb datasets * new lenses * sae training scripts * fixed depedencies * flexible SAE script * fixed CI checks
1 parent 4d5fa69 commit e0e447a

25 files changed

+1414
-70
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,6 @@ debug
134134
*.zip
135135
lc0
136136
!bin/lc0
137+
wandb
137138

138139
*secret*

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 23.11.0
3+
rev: 24.2.0
44
hooks:
55
- id: black
66
args: ["--config", "pyproject.toml"]
@@ -20,17 +20,17 @@ repos:
2020
hooks:
2121
- id: poetry-check
2222
- repo: https://github.com/pre-commit/mirrors-mypy
23-
rev: v1.7.1
23+
rev: v1.8.0
2424
hooks:
2525
- id: mypy
2626
additional_dependencies: ['types-requests', 'types-toml']
2727
- repo: https://github.com/pycqa/flake8
28-
rev: 6.1.0
28+
rev: 7.0.0
2929
hooks:
3030
- id: flake8
3131
args: ['--ignore=E203,W503', '--per-file-ignores=__init__.py:F401']
3232
- repo: https://github.com/pycqa/isort
33-
rev: 5.12.0
33+
rev: 5.13.2
3434
hooks:
3535
- id: isort
3636
args: ["--settings-path", "pyproject.toml"]

poetry.lock

Lines changed: 312 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ line-length = 79
44
[tool.isort]
55
profile = "black"
66
line_length = 79
7+
src_paths = ["src", "tests", "scripts", "docs", "demo"]
78

89
[tool.poetry]
910
name = "lczerolens"
@@ -28,12 +29,13 @@ python = "^3.9"
2829
python-chess = "^1.999"
2930
torch = ">=2"
3031
onnx2torch = "^1.5.13"
31-
tensordict = "^0.2.1"
32+
tensordict = "^0.3.0"
3233
gradio = {version = "^4.14.0", optional = true}
3334
zennit = "<=0.4.6"
3435
jsonlines = "^4.0.0"
3536
scikit-learn = "^1.4.0"
3637
zennit-crp = "^0.6.0"
38+
einops = "^0.7.0"
3739

3840
[tool.poetry.extras]
3941
demo = ["gradio"]
@@ -77,6 +79,7 @@ optional = true
7779
safetensors = "^0.4.2"
7880
pylatex = "^1.4.2"
7981
matplotlib = "^3.8.2"
82+
wandb = "^0.16.3"
8083

8184
[build]
8285
target-dir = "build/dist"

scripts/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
im_viz
22
results
33
clusters
4+
saes

scripts/create_figure.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Nice plotting of chessboard and heatmap with arrows.
22
"""
33

4-
54
import chess
65
from pylatex import Figure, NoEscape, SubFigure
76

scripts/make_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#######################################
2222
# HYPERPARAMETERS
2323
#######################################
24-
parser = argparse.ArgumentParser("leela")
24+
parser = argparse.ArgumentParser("make-datasets")
2525
parser.add_argument("--output-root", type=str, default=".")
2626
make_test_10 = False
2727
make_test_5000 = False

scripts/register_wandb_dataset.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Register a dataset in Weights & Biases.
2+
3+
Run with:
4+
```bash
5+
poetry run python -m scripts.register_wandb_dataset
6+
```
7+
"""
8+
9+
import argparse
10+
import os
11+
import random
12+
13+
import wandb
14+
15+
from lczerolens import BoardDataset
16+
17+
from .secret import WANDB_API_KEY
18+
19+
#######################################
20+
# HYPERPARAMETERS
21+
#######################################
22+
parser = argparse.ArgumentParser("make-datasets")
23+
parser.add_argument("--output-root", type=str, default=".")
24+
make_dataset = False
25+
seed = 42
26+
train_samples = 10_000
27+
val_samples = 1_000
28+
test_samples = 1_000
29+
log_dataset = False
30+
#######################################
31+
32+
ARGS = parser.parse_args()
33+
os.makedirs(f"{ARGS.output_root}/assets", exist_ok=True)
34+
35+
if make_dataset:
36+
dataset = BoardDataset("./assets/TCEC_game_collection_random_boards.jsonl")
37+
all_indices = list(range(len(dataset)))
38+
random.seed(seed)
39+
random.shuffle(all_indices)
40+
train_indices = all_indices[:train_samples]
41+
val_slice = train_samples + val_samples
42+
val_indices = all_indices[train_samples:val_slice]
43+
test_slice = val_slice + test_samples
44+
test_indices = all_indices[val_slice:test_slice]
45+
46+
dataset.save(
47+
f"{ARGS.output_root}/assets/"
48+
"TCEC_game_collection_random_boards_train.jsonl",
49+
indices=train_indices,
50+
)
51+
dataset.save(
52+
f"{ARGS.output_root}/assets/"
53+
"TCEC_game_collection_random_boards_val.jsonl",
54+
indices=val_indices,
55+
)
56+
dataset.save(
57+
f"{ARGS.output_root}/assets/"
58+
"TCEC_game_collection_random_boards_test.jsonl",
59+
indices=test_indices,
60+
)
61+
62+
# type: ignore
63+
if log_dataset:
64+
wandb.login(key=WANDB_API_KEY) # type: ignore
65+
with wandb.init( # type: ignore
66+
project="lczerolens-saes", job_type="make-datasets"
67+
) as run:
68+
artifact = wandb.Artifact("tcec_train", type="dataset") # type: ignore
69+
artifact.add_file(
70+
f"{ARGS.output_root}/assets/"
71+
"TCEC_game_collection_random_boards_train.jsonl"
72+
)
73+
run.log_artifact(artifact)
74+
artifact = wandb.Artifact("tcec_val", type="dataset") # type: ignore
75+
artifact.add_file(
76+
f"{ARGS.output_root}/assets/"
77+
"TCEC_game_collection_random_boards_val.jsonl"
78+
)
79+
run.log_artifact(artifact)
80+
artifact = wandb.Artifact("tcec_test", type="dataset") # type: ignore
81+
artifact.add_file(
82+
f"{ARGS.output_root}/assets/"
83+
"TCEC_game_collection_random_boards_test.jsonl"
84+
)
85+
run.log_artifact(artifact)

0 commit comments

Comments
 (0)