@@ -8,12 +8,10 @@
 
 import os
 
-import chess
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from crp.attribution import CondAttribution
-from crp.concepts import ChannelConcept
 from pylatex import Document
 from pylatex.package import Package
 from safetensors import safe_open
@@ -30,17 +28,26 @@
 #######################################
 # HYPERPARAMETERS
 #######################################
-n_clusters = 10
+n_clusters = 15
 batch_size = 500
-save_files = True
-conv_sum_dims = (2, 3)
+save_files = False
 model_name = "64x6-2018_0627_1913_08_161.onnx"
 dataset_name = "TCEC_game_collection_random_boards_bestlegal_knight.jsonl"
 only_config_rel = True
 best_legal = True
 run_name = (
     f"bestres_tcec_bestlegal_knight_{'expbest' if best_legal else 'full'}"
 )
+n_samples = 1000
+conv_sum_dims = ()
+cosine_sim = False
+kmeans_on_tsne = True
+viz_latent = True
+viz_name = (
+    f"{'latent' if viz_latent else 'input'}"
+    f"_nosum_{'cosine' if cosine_sim else 'norm'}"
+    f"_{'after' if kmeans_on_tsne else 'before'}-tsne"
+)
 #######################################
 
 
@@ -60,7 +67,6 @@ def legal_init_rel(board_list, board_tensor):
 print(f"[INFO] Board dataset len: {len(concept_dataset)}")
 
 composite = LrpLens.make_default_composite()
-cc = ChannelConcept()
 layer_names = [f"model.block{b}/conv2/relu" for b in [0, 3, 5]]
 print(layer_names)
 
@@ -96,9 +102,6 @@ def init_rel_fn(board_tensor):
 
     for layer_name in layer_names:
         latent_rel = attr.relevances[layer_name]
-        latent_rel = cc.attribute(latent_rel, abs_norm=True)
-        if len(latent_rel.shape) == 4:
-            latent_rel = latent_rel.sum(conv_sum_dims)
         if layer_name not in all_relevances:
             all_relevances[layer_name] = latent_rel.detach().cpu()
         else:
@@ -131,23 +134,34 @@ def init_rel_fn(board_tensor):
 #######################################
 
 print("############ Clustering ...")
-os.makedirs(f"scripts/results/{run_name}", exist_ok=True)
+os.makedirs(f"scripts/results/{run_name}/{viz_name}", exist_ok=True)
 
 for layer_name, relevances in all_relevances.items():
-    kmeans = KMeans(n_clusters=n_clusters, init="k-means++")
-    kmeans.fit(relevances)
+    relevances = relevances[:n_samples]
+    if conv_sum_dims:
+        relevances = relevances.sum(dim=conv_sum_dims).view(
+            relevances.shape[0], -1
+        )
+    else:
+        relevances = relevances.view(relevances.shape[0], -1)
 
     # Perform t-SNE dimensionality reduction
     tsne = TSNE(n_components=2)
     latent_rel_tsne = tsne.fit_transform(relevances)
 
+    if kmeans_on_tsne:
+        relevances = latent_rel_tsne
+    kmeans = KMeans(n_clusters=n_clusters, init="k-means++")
+    kmeans.fit(relevances)
+
     # Plot the clustered data
     plt.scatter(latent_rel_tsne[:, 0], latent_rel_tsne[:, 1], c=kmeans.labels_)
     plt.title("Clustered Latent Relevances")
     plt.xlabel("Dimension 1")
     plt.ylabel("Dimension 2")
     plt.savefig(
-        f"scripts/results/{run_name}/{layer_name.replace('/', '.')}_t-sne.png"
+        f"scripts/results/{run_name}/{viz_name}/"
+        f"{layer_name.replace('/', '.')}_t-sne.png"
     )
     plt.close()
 
@@ -160,10 +174,25 @@ def init_rel_fn(board_tensor):
     attribution = CondAttribution(modifed_model)
     for idx_cluster in tqdm(range(n_clusters)):
         cluster_center = kmeans.cluster_centers_[idx_cluster]
-        distances = np.linalg.norm(relevances - cluster_center, axis=1)
-        nearest_neighbors = np.argsort(distances)[:8]
+        if cosine_sim:
+            dot_prod = relevances @ cluster_center.T
+            similarities = dot_prod / (
+                np.linalg.norm(relevances, axis=1)
+                * np.linalg.norm(cluster_center)
+            )
+            nearest_neighbors = np.argsort(similarities)[-8:]
+        else:
+            distances = np.linalg.norm(relevances - cluster_center, axis=1)
+            nearest_neighbors = np.argsort(distances)[:8]
 
-        doc = Document()  # create a new document
+        doc = Document(
+            geometry_options={
+                "lmargin": "3cm",
+                "tmargin": "0.5cm",
+                "bmargin": "1.5cm",
+                "rmargin": "3cm",
+            }
+        )
         doc.packages.append(Package("xskak"))
 
         # compute heatmap for each nearest neighbor
@@ -179,37 +208,83 @@ def init_rel_fn(board_tensor):
                 rel[:, label_tensor] = board_tensor[:, label_tensor]
                 return rel
 
-            board_tensor.requires_grad = True
-            attr = attribution(
-                board_tensor,
-                [{"y": None}],
-                composite,
-                init_rel=init_rel_fn if best_legal else None,
-            )
-            if only_config_rel:
-                heatmap = board_tensor.grad[0, :12].sum(dim=0).view(64)
-            else:
-                heatmap = board_tensor.grad[0].sum(dim=0).view(64)
-            if board.turn == chess.BLACK:
-                heatmap = heatmap.view(8, 8).flip(0).view(64)
             move = move_utils.decode_move(
                 label, (board.turn, not board.turn), board
             )
             uci_move = move.uci()
-            heatmap = heatmap / heatmap.abs().max()
-            heatmap_str = create_heatmap_string(heatmap)
+
+            if viz_latent:
+                latent_rel = all_relevances[layer_name][idx_sample]
+                if not board.turn:
+                    latent_rel = latent_rel.flip(1)
+                latent_rel = latent_rel.view(-1, 64)
+                channel_rels = latent_rel.abs().sum(dim=1)
+                c1, c2 = torch.topk(channel_rels, 2).indices
+                heatmap_str_list = [
+                    create_heatmap_string(latent_rel.sum(0), abs_max=True),
+                    create_heatmap_string(latent_rel[c1], abs_max=True),
+                    create_heatmap_string(latent_rel[c2], abs_max=True),
+                ]
+                heatmap_caption_list = [
+                    "Total relevance",
+                    "Best channel",
+                    "Second best channel",
+                ]
+                add_caption = "latent"
+            else:
+                board_tensor.requires_grad = True
+                attr = attribution(
+                    board_tensor,
+                    [{"y": None}],
+                    composite,
+                    init_rel=init_rel_fn if best_legal else None,
+                )
+                input_relevances = board_tensor.grad
+                if not board.turn:
+                    input_relevances = (
+                        input_relevances.view(112, 8, 8)
+                        .flip(1)
+                        .view(112, 64)
+                    )
+                input_relevances = input_relevances.view(112, 64)
+                heatmap_str_list = [
+                    create_heatmap_string(
+                        input_relevances.sum(dim=0), abs_max=True
+                    ),
+                    create_heatmap_string(
+                        input_relevances[:13].sum(dim=0), abs_max=True
+                    ),
+                    create_heatmap_string(
+                        input_relevances[104:].sum(dim=0), abs_max=True
+                    ),
+                ]
+                heatmap_caption_list = [
+                    "Total relevance",
+                    "Current config relevance",
+                    "Meta relevance",
+                ]
+                h0 = input_relevances[:13].abs().sum()
+                hist = input_relevances[13:104].abs().sum()
+                meta = input_relevances[104:].abs().sum()
+                total = (h0 + hist + meta) / 100
+                add_caption = (
+                    f"{h0 / total:.0f}%|{hist / total:.0f}%|{meta / total:.0f}%"
+                )
 
             doc = add_plot(
                 doc,
                 board.fen(),
-                heatmap_str,
+                heatmap_str_list,
                 current_piece_pos=uci_move[:2],
                 next_move=uci_move[2:4],
+                caption=f"Sample {idx_sample} - {add_caption}",
+                heatmap_caption_list=heatmap_caption_list,
             )
 
         # Generate pdf
         doc.generate_pdf(
             f"scripts/results/{run_name}"
-            f"/{layer_name.replace('/', '.')}_cluster_{idx_cluster}",
+            f"/{viz_name}/{layer_name.replace('/', '.')}"
+            f"_cluster_{idx_cluster}",
             clean_tex=True,
         )
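
Note: the following is a minimal, self-contained sketch of the clustering and neighbour-selection flow this change introduces, not the script itself. Variable names follow the diff, the relevance matrix is random stand-in data, and `n_init=10` is passed here only to pin scikit-learn's behaviour (the original call leaves it at the default).

```python
# Sketch only: mirrors the new clustering / neighbour-selection path with
# random stand-in data instead of the collected latent relevances.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

n_clusters = 15
n_samples = 1000
cosine_sim = False
kmeans_on_tsne = True

# Stand-in for one layer's relevances, already flattened to (samples, features).
relevances = np.random.randn(n_samples, 32 * 64)

# t-SNE is always computed for the scatter plot ...
latent_rel_tsne = TSNE(n_components=2).fit_transform(relevances)

# ... and optionally also used as the clustering space.
features = latent_rel_tsne if kmeans_on_tsne else relevances
kmeans = KMeans(n_clusters=n_clusters, init="k-means++", n_init=10).fit(features)

for idx_cluster in range(n_clusters):
    cluster_center = kmeans.cluster_centers_[idx_cluster]
    if cosine_sim:
        similarities = (features @ cluster_center) / (
            np.linalg.norm(features, axis=1) * np.linalg.norm(cluster_center)
        )
        # argsort is ascending, so the last eight are the most similar samples
        nearest_neighbors = np.argsort(similarities)[-8:]
    else:
        distances = np.linalg.norm(features - cluster_center, axis=1)
        nearest_neighbors = np.argsort(distances)[:8]
```

When `kmeans_on_tsne` is set, both the cluster centres and the neighbour search live in the 2-D embedding, so the selected samples correspond directly to the plotted clusters; otherwise they are picked in the original relevance space.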
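Likewise, a small sketch of how the new latent view picks its two "best channel" heatmaps, assuming a random tensor in place of one sample's conv-layer relevance; `create_heatmap_string` from the script is not reproduced here.

```python
# Sketch only: channel selection for the latent heatmaps.
import torch

latent_rel = torch.randn(64, 8, 8)            # stand-in: (channels, 8, 8) relevance
latent_rel = latent_rel.view(-1, 64)          # one row of 64 squares per channel

channel_rels = latent_rel.abs().sum(dim=1)    # total absolute relevance per channel
c1, c2 = torch.topk(channel_rels, 2).indices  # the two most relevant channels

total_map = latent_rel.sum(0)                 # summed over all channels -> 64 squares
best_map = latent_rel[c1]                     # relevance of the strongest channel
second_map = latent_rel[c2]                   # relevance of the runner-up
```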