Skip to content

Commit 0fc5b6b

Browse files
authored
Compatibility with resnet architecture (#8)
* old resnet archi * new onnx canonisation * increased compatibility * make datasets * fixed tests * sample exploration
1 parent 1e09c55 commit 0fc5b6b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1470
-1735
lines changed

.vscode/launch.json

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,37 @@
1919
"justMyCode": false
2020
},
2121
{
22-
"name": "Script CRP Concepts",
22+
"name": "Script CRP concepts",
2323
"type": "debugpy",
2424
"request": "launch",
2525
"module": "scripts.find_concepts",
2626
"console": "integratedTerminal",
2727
"justMyCode": false
2828
},
2929
{
30-
"name": "Script CRP Clusters",
30+
"name": "Script CRP clusters",
3131
"type": "debugpy",
3232
"request": "launch",
3333
"module": "scripts.cluster_latent_relevances",
3434
"console": "integratedTerminal",
3535
"justMyCode": false
36+
},
37+
{
38+
"name": "Script make datasets",
39+
"type": "debugpy",
40+
"request": "launch",
41+
"module": "scripts.make_datasets",
42+
"console": "integratedTerminal",
43+
"justMyCode": false
44+
}
45+
,
46+
{
47+
"name": "Script sample exploration",
48+
"type": "debugpy",
49+
"request": "launch",
50+
"module": "scripts.sample_exploration",
51+
"console": "integratedTerminal",
52+
"justMyCode": false
3653
}
3754
]
3855
}

demo/policy_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def compute_policy(
6767
gr.Warning("Invalid action sequence.")
6868
return (None, None, "", None)
6969
wrapper = utils.get_wrapper_from_state(model_name)
70-
output = wrapper.predict(board)
70+
(output,) = wrapper.predict(board)
7171
current_raw_policy = output["policy"][0]
7272
policy = torch.softmax(output["policy"][0], dim=-1)
7373

demo/statistics_interface.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@
66

77
from demo import utils, visualisation
88
from lczerolens import GameDataset
9-
from lczerolens.xai import HasThreatConcept, UniqueConceptDataset
9+
from lczerolens.xai import ConceptDataset, HasThreatConcept
1010

1111
current_policy_statistics = None
1212
current_lrp_statistics = None
1313
current_probing_statistics = None
1414
dataset = GameDataset("assets/test_stockfish_10.jsonl")
1515
check_concept = HasThreatConcept("K", relative=True)
16-
unique_check_dataset = UniqueConceptDataset.from_game_dataset(
17-
dataset, check_concept
18-
)
16+
unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
17+
unique_check_dataset.concept = check_concept
1918

2019

2120
def list_models():
@@ -47,7 +46,7 @@ def compute_policy_statistics(
4746
)
4847
return None
4948
wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
50-
current_policy_statistics = lens.compute_statistics(dataset, wrapper, 10)
49+
current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
5150
return make_policy_plot()
5251

5352

demo/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import subprocess
88

99
from demo import constants, state
10-
from lczerolens import AutoLens, ModelWrapper
10+
from lczerolens import Lens, ModelWrapper
1111
from lczerolens.utils import lczero as lczero_utils
1212

1313

@@ -122,7 +122,7 @@ def get_wrapper_lens_from_state(
122122
if lens_name in state.lenses[lens_type]:
123123
lens = state.lenses[lens_type][lens_name]
124124
else:
125-
lens = AutoLens.from_type(lens_type, **kwargs)
125+
lens = Lens.from_name(lens_type, **kwargs)
126126
if not lens.is_compatible(wrapper):
127127
raise ValueError(
128128
f"Lens of type {lens_type} not compatible with model."

demo/visualisation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def render_architecture(model, name: str = "model", directory: str = ""):
102102
def render_policy_distribution(
103103
policy,
104104
legal_moves,
105-
n_bins=10,
105+
n_bins=20,
106106
):
107107
"""
108108
Render the policy distribution histogram.
@@ -112,11 +112,12 @@ def render_policy_distribution(
112112
).bool()
113113
fig = plt.figure(figsize=(6, 6))
114114
ax = plt.gca()
115-
_, bins, _ = ax.hist(
115+
_, bins = np.histogram(policy, bins=n_bins)
116+
ax.hist(
116117
policy[~legal_mask],
117-
bins=n_bins,
118-
density=True,
118+
bins=bins,
119119
alpha=0.5,
120+
density=True,
120121
label="Illegal moves",
121122
)
122123
ax.hist(

scripts/cluster_latent_relevances.py

Lines changed: 118 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import torch
1515
from crp.attribution import CondAttribution
1616
from crp.concepts import ChannelConcept
17-
from crp.helper import get_layer_names
1817
from pylatex import Document
1918
from pylatex.package import Package
2019
from safetensors import safe_open
@@ -23,88 +22,104 @@
2322
from sklearn.manifold import TSNE
2423
from tqdm import tqdm
2524

26-
from lczerolens import GameDataset, move_utils
27-
from lczerolens.adapt import PolicyFlow
28-
from lczerolens.xai import LrpLens, UniqueConceptDataset
29-
from lczerolens.xai.concepts import BestLegalMoveConcept
25+
from lczerolens import move_utils
26+
from lczerolens.game import PolicyFlow
27+
from lczerolens.xai import ConceptDataset, LrpLens
3028
from scripts.create_figure import add_plot, create_heatmap_string
3129

3230
#######################################
3331
# HYPERPARAMETERS
3432
#######################################
3533
n_clusters = 10
36-
layer_index = -1
3734
batch_size = 500
38-
save_files = False
35+
save_files = True
3936
conv_sum_dims = (2, 3)
40-
model_name = "tinygyal-8.onnx"
41-
dataset_name = "test_stockfish_10.jsonl"
37+
model_name = "64x6-2018_0627_1913_08_161.onnx"
38+
dataset_name = "TCEC_game_collection_random_boards_bestlegal_knight.jsonl"
4239
only_config_rel = True
40+
best_legal = True
41+
run_name = (
42+
f"bestres_tcec_bestlegal_knight_{'expbest' if best_legal else 'full'}"
43+
)
4344
#######################################
4445

4546

46-
class MaxLogitFlow(PolicyFlow):
47-
def forward(self, x):
48-
policy = super().forward(x)
49-
return policy.max(dim=1, keepdim=True).values
47+
def legal_init_rel(board_list, board_tensor):
48+
legal_move_mask = torch.zeros((len(board_list), 1858))
49+
for idx, board in enumerate(board_list):
50+
legal_moves = [
51+
move_utils.encode_move(move, (board.turn, not board.turn))
52+
for move in board.legal_moves
53+
]
54+
legal_move_mask[idx, legal_moves] = 1
55+
return legal_move_mask * board_tensor
5056

5157

52-
model = MaxLogitFlow.from_path(f"./assets/{model_name}")
53-
dataset = GameDataset(f"./assets/{dataset_name}")
54-
concept = BestLegalMoveConcept(model)
55-
unique_dataset = UniqueConceptDataset.from_game_dataset(dataset, concept)
56-
print(f"[INFO] Board dataset len: {len(unique_dataset)}")
58+
model = PolicyFlow.from_path(f"./assets/{model_name}")
59+
concept_dataset = ConceptDataset(f"./assets/{dataset_name}")
60+
print(f"[INFO] Board dataset len: {len(concept_dataset)}")
5761

5862
composite = LrpLens.make_default_composite()
59-
attribution = CondAttribution(model)
6063
cc = ChannelConcept()
61-
62-
layer_names = get_layer_names(model, [torch.nn.ReLU])
63-
layer_names = [
64-
layer_name for layer_name in layer_names if "block" in layer_name
65-
]
64+
layer_names = [f"model.block{b}/conv2/relu" for b in [0, 3, 5]]
6665
print(layer_names)
6766

6867
dataloader = torch.utils.data.DataLoader(
69-
unique_dataset,
68+
concept_dataset,
7069
batch_size=batch_size,
7170
shuffle=False,
72-
collate_fn=UniqueConceptDataset.collate_fn_tensor,
71+
collate_fn=ConceptDataset.collate_fn_tensor,
7372
)
7473

7574
if save_files:
7675
print("############ Collecting Relevances")
7776
all_relevances = {}
7877
for batch in tqdm(dataloader):
79-
_, board_tensor, _ = batch
78+
_, board_tensor, labels = batch
79+
label_tensor = torch.tensor(labels)
80+
81+
def init_rel_fn(board_tensor):
82+
rel = torch.zeros_like(board_tensor)
83+
rel[:, label_tensor] = board_tensor[:, label_tensor]
84+
return rel
85+
8086
board_tensor.requires_grad = True
81-
attr = attribution(
82-
board_tensor, [{"y": 0}], composite, record_layer=layer_names
83-
)
84-
85-
for layer_name in layer_names:
86-
latent_rel = attr.relevances[layer_name]
87-
latent_rel = cc.attribute(latent_rel, abs_norm=True)
88-
if len(latent_rel.shape) == 4:
89-
latent_rel = latent_rel.sum(conv_sum_dims)
90-
if layer_name not in all_relevances:
91-
all_relevances[layer_name] = latent_rel.detach().cpu()
92-
else:
93-
all_relevances[layer_name] = torch.cat(
94-
[all_relevances[layer_name], latent_rel.detach().cpu()],
95-
dim=0,
96-
)
87+
with LrpLens.context(model) as modifed_model:
88+
attribution = CondAttribution(modifed_model)
89+
attr = attribution(
90+
board_tensor,
91+
[{"y": None}],
92+
composite,
93+
record_layer=layer_names,
94+
init_rel=init_rel_fn if best_legal else None,
95+
)
9796

98-
os.makedirs(f"scripts/clusters/{model_name}-{dataset_name}", exist_ok=True)
97+
for layer_name in layer_names:
98+
latent_rel = attr.relevances[layer_name]
99+
latent_rel = cc.attribute(latent_rel, abs_norm=True)
100+
if len(latent_rel.shape) == 4:
101+
latent_rel = latent_rel.sum(conv_sum_dims)
102+
if layer_name not in all_relevances:
103+
all_relevances[layer_name] = latent_rel.detach().cpu()
104+
else:
105+
all_relevances[layer_name] = torch.cat(
106+
[
107+
all_relevances[layer_name],
108+
latent_rel.detach().cpu(),
109+
],
110+
dim=0,
111+
)
112+
113+
os.makedirs(f"scripts/clusters/{run_name}", exist_ok=True)
99114
save_file(
100115
all_relevances,
101-
f"scripts/clusters/{model_name}-{dataset_name}/relevances.safetensors",
116+
f"scripts/clusters/{run_name}/relevances.safetensors",
102117
)
103118

104119
else:
105120
all_relevances = {}
106121
with safe_open(
107-
f"scripts/clusters/{model_name}-{dataset_name}/relevances.safetensors",
122+
f"scripts/clusters/{run_name}/relevances.safetensors",
108123
framework="pt",
109124
device="cpu",
110125
) as f:
@@ -116,7 +131,7 @@ def forward(self, x):
116131
#######################################
117132

118133
print("############ Clustering ...")
119-
os.makedirs(f"scripts/results/{model_name}-{dataset_name}", exist_ok=True)
134+
os.makedirs(f"scripts/results/{run_name}", exist_ok=True)
120135

121136
for layer_name, relevances in all_relevances.items():
122137
kmeans = KMeans(n_clusters=n_clusters, init="k-means++")
@@ -132,7 +147,7 @@ def forward(self, x):
132147
plt.xlabel("Dimension 1")
133148
plt.ylabel("Dimension 2")
134149
plt.savefig(
135-
f"scripts/results/{model_name}-{dataset_name}/{layer_name}_t-sne.png"
150+
f"scripts/results/{run_name}/{layer_name.replace('/','.')}_t-sne.png"
136151
)
137152
plt.close()
138153

@@ -141,47 +156,60 @@ def forward(self, x):
141156
#######################################
142157

143158
print("############ Plotting chessboards for each cluster")
159+
with LrpLens.context(model) as modifed_model:
160+
attribution = CondAttribution(modifed_model)
161+
for idx_cluster in tqdm(range(n_clusters)):
162+
cluster_center = kmeans.cluster_centers_[idx_cluster]
163+
distances = np.linalg.norm(relevances - cluster_center, axis=1)
164+
nearest_neighbors = np.argsort(distances)[:8]
165+
166+
doc = Document() # create a new document
167+
doc.packages.append(Package("xskak"))
168+
169+
# compute heatmap for each nearest neighbor
170+
for idx_sample in nearest_neighbors:
171+
_, board, label = concept_dataset[idx_sample]
172+
_, board_tensor, _ = ConceptDataset.collate_fn_tensor(
173+
[concept_dataset[idx_sample]]
174+
)
175+
label_tensor = torch.tensor([label])
176+
177+
def init_rel_fn(board_tensor):
178+
rel = torch.zeros_like(board_tensor)
179+
rel[:, label_tensor] = board_tensor[:, label_tensor]
180+
return rel
181+
182+
board_tensor.requires_grad = True
183+
attr = attribution(
184+
board_tensor,
185+
[{"y": None}],
186+
composite,
187+
init_rel=init_rel_fn if best_legal else None,
188+
)
189+
if only_config_rel:
190+
heatmap = board_tensor.grad[0, :12].sum(dim=0).view(64)
191+
else:
192+
heatmap = board_tensor.grad[0].sum(dim=0).view(64)
193+
if board.turn == chess.BLACK:
194+
heatmap = heatmap.view(8, 8).flip(0).view(64)
195+
move = move_utils.decode_move(
196+
label, (board.turn, not board.turn), board
197+
)
198+
uci_move = move.uci()
199+
heatmap = heatmap / heatmap.abs().max()
200+
heatmap_str = create_heatmap_string(heatmap)
201+
202+
doc = add_plot(
203+
doc,
204+
board.fen(),
205+
heatmap_str,
206+
current_piece_pos=uci_move[:2],
207+
next_move=uci_move[2:4],
208+
)
144209

145-
for idx_cluster in tqdm(range(n_clusters)):
146-
cluster_center = kmeans.cluster_centers_[idx_cluster]
147-
distances = np.linalg.norm(relevances - cluster_center, axis=1)
148-
nearest_neighbors = np.argsort(distances)[:10]
149-
150-
doc = Document() # create a new document
151-
doc.packages.append(Package("xskak"))
152-
153-
# compute heatmap for each nearest neighbor
154-
for idx_sample in nearest_neighbors:
155-
_, board, label = unique_dataset[idx_sample]
156-
_, board_tensor, _ = UniqueConceptDataset.collate_fn_tensor(
157-
[unique_dataset[idx_sample]]
210+
# Generate pdf
211+
doc.generate_pdf(
212+
f"scripts/results/{run_name}"
213+
f"/{layer_name.replace('/','.')}_cluster_{idx_cluster}",
214+
clean_tex=True,
158215
)
159-
board_tensor.requires_grad = True
160-
attr = attribution(board_tensor, [{"y": 0}], composite)
161-
if only_config_rel:
162-
heatmap = board_tensor.grad[0, :12].sum(dim=0).view(64)
163-
else:
164-
heatmap = board_tensor.grad[0].sum(dim=0).view(64)
165-
if board.turn == chess.BLACK:
166-
heatmap = heatmap.view(8, 8).flip(0).view(64)
167-
move = move_utils.decode_move(
168-
label, (board.turn, not board.turn), board
169-
)
170-
uci_move = move.uci()
171-
heatmap = heatmap / heatmap.abs().max()
172-
heatmap_str = create_heatmap_string(heatmap)
173-
174-
doc = add_plot(
175-
doc,
176-
board.fen(),
177-
heatmap_str,
178-
current_piece_pos=uci_move[:2],
179-
next_move=uci_move[2:4],
180-
)
181-
182-
# Generate pdf
183-
doc.generate_pdf(
184-
f"scripts/results/{model_name}-{dataset_name}"
185-
f"/{layer_name}_cluster_{idx_cluster}",
186-
clean_tex=True,
187-
)

0 commit comments

Comments
 (0)