Commit 83742dd

Diverse fixes (#14)

Squashed commit messages:

* fix on device training
* fixed val metrics
* fixed parse bool
* new training parameters
* copy typo
* best param and new eval
* resampling error
* ressampling typo
* ressampling typo
* bias typo
* device typo
* table format
* np hist
* bug hist
* new defaults
* debug
* fixed hist
* tests activated features
* bug test
* debug
* bigger datasets
* no reload
* to cpu
* new asset link
1 parent e0e447a commit 83742dd

File tree: 13 files changed, +522 −317 lines


.vscode/launch.json

Lines changed: 18 additions & 0 deletions
@@ -51,5 +51,23 @@
             "console": "integratedTerminal",
             "justMyCode": false
         }
+        ,
+        {
+            "name": "Script simple sae",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "scripts.simple_sae",
+            "console": "integratedTerminal",
+            "justMyCode": false
+        }
+        ,
+        {
+            "name": "Debug",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "ignored.debug",
+            "console": "integratedTerminal",
+            "justMyCode": false
+        }
     ]
 }

assets/resolve-test-assets.sh

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@ poetry run gdown 1Ssl4JanqzQn3p-RoHRDk_aApykl-SukE -O assets/tinygyal-8.pb.gz
 poetry run gdown 1WzBQV_zn5NnfsG0K8kOion0pvWxXhgKM -O assets/384x30-2022_0108_1903_17_608.pb.gz
 poetry run gdown 1erxB3tULDURjpPhiPWVGr6X986Q8uE6U -O assets/maia-1100.pb.gz
 poetry run gdown 1YqqANK-wuZIOmMweuK_oCU7vfPN7G_Z6 -O assets/t1-smolgen-512x15x8h-distilled-swa-3395000.pb.gz
-poetry run gdown 1Gw40ENElDrpqx9p0uuAnO0jieX1Clw89 -O assets/test_stockfish_10.jsonl
+poetry run gdown 15-eGN7Hz2NM6aEMRaQrbW3ScxxQpAqa5 -O assets/test_stockfish_10.jsonl

demo/attention_interface.py

Lines changed: 87 additions & 112 deletions
@@ -9,10 +9,6 @@

 from demo import constants, utils, visualisation

-cache = None
-boards = None
-board_index = 0
-

 def list_models():
     """
@@ -38,89 +34,94 @@ def compute_cache(
     attention_layer,
     attention_head,
     square,
-    quantity,
-    func,
-    trick,
-    aggregate,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global cache
-    global boards
     if model_name == "":
         gr.Warning("No model selected.")
-        return None, None, None
+        return None, None, None, state_boards, state_cache

     try:
         board = chess.Board(board_fen)
     except ValueError:
         board = chess.Board()
         gr.Warning("Invalid FEN, using starting position.")
-    boards = [board.copy()]
+    state_boards = [board.copy()]
     if action_seq:
         try:
             if action_seq.startswith("1."):
                 for action in action_seq.split():
                     if action.endswith("."):
                         continue
                     board.push_san(action)
-                    boards.append(board.copy())
+                    state_boards.append(board.copy())
             else:
                 for action in action_seq.split():
                     board.push_uci(action)
-                    boards.append(board.copy())
+                    state_boards.append(board.copy())
         except ValueError:
             gr.Warning(f"Invalid action {action} stopping before it.")
     try:
         wrapper, lens = utils.get_wrapper_lens_from_state(
-            model_name, "attention"
+            model_name,
+            "activation",
+            lens_name="attention",
+            module_exp=r"encoder\d+/mha/QK/softmax",
         )
     except ValueError:
         gr.Warning("Could not load model.")
-        return None, None, None
-    cache = []
-    for board in boards:
-        attention_cache = copy.deepcopy(lens.compute_heatmap(board, wrapper))
-        cache.append(attention_cache)
-    return make_plot(
-        attention_layer,
-        attention_head,
-        square,
-        quantity,
-        func,
-        trick,
-        aggregate,
+        return None, None, None, state_boards, state_cache
+    state_cache = []
+    for board in state_boards:
+        attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper))
+        state_cache.append(attention_cache)
+    return (
+        *make_plot(
+            attention_layer,
+            attention_head,
+            square,
+            state_board_index,
+            state_boards,
+            state_cache,
+        ),
+        state_boards,
+        state_cache,
     )


 def make_plot(
-    attention_layer, attention_head, square, quantity, func, trick, aggregate
+    attention_layer,
+    attention_head,
+    square,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global cache
-    global boards
-    global board_index

-    if cache is None:
-        gr.Warning("Cache not computed!")
-        return None, None
+    if state_cache == []:
+        gr.Warning("No cache available.")
+        return None, None, None

-    board = boards[board_index]
-    num_attention_layers = len(cache[board_index])
+    board = state_boards[state_board_index]
+    num_attention_layers = len(state_cache[state_board_index])
     if attention_layer > num_attention_layers:
         gr.Warning(
             f"Attention layer {attention_layer} does not exist, "
             f"using layer {num_attention_layers} instead."
         )
         attention_layer = num_attention_layers

-    key = f"{attention_layer-1}-{quantity}-{func}"
+    key = f"encoder{attention_layer-1}/mha/QK/softmax"
     try:
-        attention_tensor = cache[board_index][key]
+        attention_tensor = state_cache[state_board_index][key]
     except KeyError:
         gr.Warning(f"Combination {key} does not exist.")
         return None, None, None
     if attention_head > attention_tensor.shape[1]:
         gr.Warning(
             f"Attention head {attention_head} does not exist, "
-            f"using head {attention_tensor.shape[1]} instead."
+            f"using head {attention_tensor.shape[1]+1} instead."
         )
         attention_head = attention_tensor.shape[1]
     try:
@@ -132,15 +133,7 @@ def make_plot(
     if board.turn == chess.BLACK:
         square_index = chess.square_mirror(square_index)

-    if trick == "revert":
-        square_index = 63 - square_index
-
-    if aggregate == "Row":
-        heatmap = attention_tensor[0, attention_head - 1, square_index, :]
-    elif aggregate == "Column":
-        heatmap = attention_tensor[0, attention_head - 1, :, square_index]
-    else:
-        heatmap = attention_tensor[0, attention_head - 1]
+    heatmap = attention_tensor[0, attention_head - 1, square_index]
     if board.turn == chess.BLACK:
         heatmap = heatmap.view(8, 8).flip(0).view(64)
     svg_board, fig = visualisation.render_heatmap(
@@ -155,37 +148,49 @@ def previous_board(
     attention_layer,
     attention_head,
     square,
-    from_to,
-    color_flip,
-    trick,
-    aggregate,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global board_index
-    board_index -= 1
-    if board_index < 0:
+    state_board_index -= 1
+    if state_board_index < 0:
         gr.Warning("Already at first board.")
-        board_index = 0
-    return make_plot(
-        attention_layer, attention_head, square, from_to, color_flip
+        state_board_index = 0
+    return (
+        *make_plot(
+            attention_layer,
+            attention_head,
+            square,
+            state_board_index,
+            state_boards,
+            state_cache,
+        ),
+        state_board_index,
     )


 def next_board(
     attention_layer,
     attention_head,
     square,
-    from_to,
-    color_flip,
-    trick,
-    aggregate,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global board_index
-    board_index += 1
-    if board_index >= len(boards):
+    state_board_index += 1
+    if state_board_index >= len(state_boards):
         gr.Warning("Already at last board.")
-        board_index = len(boards) - 1
-    return make_plot(
-        attention_layer, attention_head, square, from_to, color_flip
+        state_board_index = len(state_boards) - 1
+    return (
+        *make_plot(
+            attention_layer,
+            attention_head,
+            square,
+            state_board_index,
+            state_boards,
+            state_cache,
+        ),
+        state_board_index,
     )


@@ -254,38 +259,6 @@ def next_board(
                 value="a1",
                 scale=1,
             )
-            quantity = gr.Dropdown(
-                label="Quantity",
-                choices=["QK", "Q", "K", "out", "QKV"],
-                value="QK",
-                scale=2,
-            )
-            aggregate = gr.Dropdown(
-                label="Aggregate",
-                choices=["Row", "Column", "None"],
-                value="Row",
-                scale=2,
-            )
-            func = gr.Dropdown(
-                label="Function",
-                choices=[
-                    "softmax",
-                    "transpose",
-                    "matmul",
-                    "scale",
-                ],
-                value="softmax",
-                scale=2,
-            )
-            trick = gr.Dropdown(
-                label="Trick",
-                choices=[
-                    "none",
-                    "revert",
-                ],
-                value="none",
-                scale=2,
-            )
         with gr.Row():
             previous_board_button = gr.Button("Previous board")
             next_board_button = gr.Button("Next board")
@@ -298,32 +271,34 @@ def next_board(
         with gr.Column():
             image = gr.Image(label="Board")

+    state_board_index = gr.State(0)
+    state_boards = gr.State([])
+    state_cache = gr.State([])
     base_inputs = [
         attention_layer,
         attention_head,
         square,
-        quantity,
-        func,
-        trick,
-        aggregate,
+        state_board_index,
+        state_boards,
+        state_cache,
     ]
     outputs = [image, current_board_fen, colorbar]

     compute_cache_button.click(
         compute_cache,
         inputs=[board_fen, action_seq, model_name] + base_inputs,
-        outputs=outputs,
+        outputs=outputs + [state_boards, state_cache],
     )

     previous_board_button.click(
-        previous_board, inputs=base_inputs, outputs=outputs
+        previous_board,
+        inputs=base_inputs,
+        outputs=outputs + [state_board_index],
+    )
+    next_board_button.click(
+        next_board, inputs=base_inputs, outputs=outputs + [state_board_index]
     )
-    next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)

     attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
     attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
     square.submit(make_plot, inputs=base_inputs, outputs=outputs)
-    quantity.change(make_plot, inputs=base_inputs, outputs=outputs)
-    func.change(make_plot, inputs=base_inputs, outputs=outputs)
-    trick.change(make_plot, inputs=base_inputs, outputs=outputs)
-    aggregate.change(make_plot, inputs=base_inputs, outputs=outputs)
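
The diff above replaces the module-level globals (cache, boards, board_index) with Gradio session state, so concurrent users of the demo no longer share mutable data. A minimal sketch of the pattern, assuming only Gradio's public gr.State API (the component names below are illustrative, not the demo's own):

import gradio as gr

def next_item(state_index, state_items):
    # State values arrive as ordinary Python arguments; the handler returns
    # the updated values, which Gradio stores per browser session.
    state_index = min(state_index + 1, len(state_items) - 1)
    return state_items[state_index], state_index

with gr.Blocks() as demo:
    current = gr.Textbox(label="Current item")
    state_index = gr.State(0)                # replaces a module-level index
    state_items = gr.State(["a", "b", "c"])  # replaces a module-level list
    gr.Button("Next").click(
        next_item,
        inputs=[state_index, state_items],
        outputs=[current, state_index],
    )

demo.launch()

Any state a handler mutates must also appear in its outputs, which is why compute_cache now returns state_boards and state_cache and the navigation handlers return state_board_index.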

demo/state.py

Lines changed: 2 additions & 1 deletion
@@ -9,9 +9,10 @@
 wrappers: Dict[str, ModelWrapper] = {}

 lenses: Dict[str, Dict[str, Lens]] = {
-    "attention": {},
+    "activation": {},
     "lrp": {},
     "crp": {},
     "policy": {},
     "probing": {},
+    "patching": {},
 }
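
The lenses registry gains an "activation" family (replacing "attention") and a new "patching" entry. A two-level registry like this is typically consumed by keying on the lens family first and then caching one instance per model; the helper below is a hypothetical sketch, not part of lczerolens:

from typing import Callable, Dict, TypeVar

LensT = TypeVar("LensT")

def get_lens(
    registry: Dict[str, Dict[str, LensT]],
    lens_type: str,
    model_name: str,
    factory: Callable[[], LensT],
) -> LensT:
    # First level keys the lens family ("activation", "lrp", "patching", ...);
    # the second level caches one instance per model name.
    if lens_type not in registry:
        raise KeyError(f"Unknown lens type: {lens_type!r}")
    per_model = registry[lens_type]
    if model_name not in per_model:
        per_model[model_name] = factory()  # build once, reuse afterwards
    return per_model[model_name]

# Hypothetical usage against a registry shaped like demo/state.py's `lenses`:
#     lens = get_lens(lenses, "activation", "maia-1100", my_lens_factory)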

demo/statistics_interface.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 dataset = GameDataset("assets/test_stockfish_10.jsonl")
 check_concept = HasThreatConcept("K", relative=True)
 unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
-unique_check_dataset.concept = check_concept
+unique_check_dataset.set_concept(check_concept)


 def list_models():
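
Replacing the bare attribute assignment with set_concept() lets the dataset react when its concept changes. A plausible motivation, sketched with hypothetical internals (the class body and compute_label are illustrative, not lczerolens's actual ConceptDataset):

class ConceptDatasetSketch:
    """Hypothetical stand-in for ConceptDataset, illustrating why a
    set_concept() method can beat a bare attribute assignment."""

    def __init__(self, boards):
        self.boards = boards
        self.concept = None
        self.labels = []

    def set_concept(self, concept):
        self.concept = concept
        # Derived data stays in sync with the active concept; a plain
        # `dataset.concept = ...` assignment would leave stale labels.
        # (`compute_label` is an assumed interface, for illustration.)
        self.labels = [concept.compute_label(board) for board in self.boards]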

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ src_paths = ["src", "tests", "scripts", "docs", "demo"]

 [tool.poetry]
 name = "lczerolens"
-version = "0.1.2"
+version = "0.1.3"
 description = "Interpretability for LeelaChessZero networks."
 readme = "README.md"
 license = "MIT"
