14
14
import torch
15
15
from crp .attribution import CondAttribution
16
16
from crp .concepts import ChannelConcept
17
- from crp .helper import get_layer_names
18
17
from pylatex import Document
19
18
from pylatex .package import Package
20
19
from safetensors import safe_open
23
22
from sklearn .manifold import TSNE
24
23
from tqdm import tqdm
25
24
26
- from lczerolens import GameDataset , move_utils
27
- from lczerolens .adapt import PolicyFlow
28
- from lczerolens .xai import LrpLens , UniqueConceptDataset
29
- from lczerolens .xai .concepts import BestLegalMoveConcept
25
+ from lczerolens import move_utils
26
+ from lczerolens .game import PolicyFlow
27
+ from lczerolens .xai import ConceptDataset , LrpLens
30
28
from scripts .create_figure import add_plot , create_heatmap_string
31
29
32
30
#######################################
# HYPERPARAMETERS
#######################################
# Number of k-means clusters per probed layer.
n_clusters = 10
# Boards per attribution batch.
batch_size = 500
# Recompute relevances (True) or load the cached safetensors file (False).
save_files = True
# Spatial (H, W) axes summed out of 4-D conv relevance maps.
conv_sum_dims = (2, 3)
model_name = "64x6-2018_0627_1913_08_161.onnx"
dataset_name = "TCEC_game_collection_random_boards_bestlegal_knight.jsonl"
# Restrict heatmaps to the board-configuration input planes.
only_config_rel = True
# Initialise relevance from the labelled (best legal) move logit only.
best_legal = True
run_name = "bestres_tcec_bestlegal_knight_" + (
    "expbest" if best_legal else "full"
)
43
44
#######################################
44
45
45
46
46
def legal_init_rel(board_list, board_tensor):
    """Mask a policy-space tensor down to each board's legal moves.

    Builds a 0/1 mask of shape ``(len(board_list), 1858)`` over the policy
    index space — 1 at the encoded index of every legal move of the
    corresponding board — and multiplies it into ``board_tensor``.

    NOTE(review): appears unused in the visible part of this script —
    confirm against the full file before removing.
    """
    mask = torch.zeros((len(board_list), 1858))
    for row, board in enumerate(board_list):
        perspective = (board.turn, not board.turn)
        legal_indices = [
            move_utils.encode_move(move, perspective)
            for move in board.legal_moves
        ]
        mask[row, legal_indices] = 1
    return mask * board_tensor
50
56
51
57
52
# Load the network's policy head and the concept-labelled board dataset.
model = PolicyFlow.from_path(f"./assets/{model_name}")
concept_dataset = ConceptDataset(f"./assets/{dataset_name}")
print(f"[INFO] Board dataset len: {len(concept_dataset)}")

composite = LrpLens.make_default_composite()
cc = ChannelConcept()
# Probe the relu after the second conv of residual blocks 0, 3 and 5.
layer_names = [f"model.block{block}/conv2/relu" for block in (0, 3, 5)]
print(layer_names)

dataloader = torch.utils.data.DataLoader(
    concept_dataset,
    batch_size=batch_size,
    shuffle=False,  # keep dataset order so relevance rows align with indices
    collate_fn=ConceptDataset.collate_fn_tensor,
)
74
73
75
74
if save_files:
    print("############ Collecting Relevances")
    all_relevances = {}
    for batch in tqdm(dataloader):
        _, board_tensor, labels = batch
        label_tensor = torch.tensor(labels)

        def init_rel_fn(board_tensor):
            # Keep only each sample's own labelled logit as the initial
            # relevance.
            # BUGFIX: the previous `rel[:, label_tensor] = ...` used a 1-D
            # label tensor as a column index, which selects the full
            # cross-product — every sample in the batch was initialised at
            # *all* batch labels, not its own. Pairing row indices with
            # labels restricts initialisation to the per-board best-move
            # logit (matches the single-sample plotting code below).
            rel = torch.zeros_like(board_tensor)
            rows = torch.arange(board_tensor.shape[0])
            rel[rows, label_tensor] = board_tensor[rows, label_tensor]
            return rel

        board_tensor.requires_grad = True
        with LrpLens.context(model) as modified_model:
            attribution = CondAttribution(modified_model)
            attr = attribution(
                board_tensor,
                [{"y": None}],
                composite,
                record_layer=layer_names,
                init_rel=init_rel_fn if best_legal else None,
            )

            for layer_name in layer_names:
                latent_rel = attr.relevances[layer_name]
                latent_rel = cc.attribute(latent_rel, abs_norm=True)
                # Conv relevances are (N, C, H, W): sum out the spatial
                # dims so every layer yields an (N, C) matrix.
                if len(latent_rel.shape) == 4:
                    latent_rel = latent_rel.sum(conv_sum_dims)
                if layer_name not in all_relevances:
                    all_relevances[layer_name] = latent_rel.detach().cpu()
                else:
                    all_relevances[layer_name] = torch.cat(
                        [
                            all_relevances[layer_name],
                            latent_rel.detach().cpu(),
                        ],
                        dim=0,
                    )

    os.makedirs(f"scripts/clusters/{run_name}", exist_ok=True)
    save_file(
        all_relevances,
        f"scripts/clusters/{run_name}/relevances.safetensors",
    )
103
118
104
119
else :
105
120
all_relevances = {}
106
121
with safe_open (
107
- f"scripts/clusters/{ model_name } - { dataset_name } /relevances.safetensors" ,
122
+ f"scripts/clusters/{ run_name } /relevances.safetensors" ,
108
123
framework = "pt" ,
109
124
device = "cpu" ,
110
125
) as f :
@@ -116,7 +131,7 @@ def forward(self, x):
116
131
#######################################
117
132
118
133
print ("############ Clustering ..." )
119
- os .makedirs (f"scripts/results/{ model_name } - { dataset_name } " , exist_ok = True )
134
+ os .makedirs (f"scripts/results/{ run_name } " , exist_ok = True )
120
135
121
136
for layer_name , relevances in all_relevances .items ():
122
137
kmeans = KMeans (n_clusters = n_clusters , init = "k-means++" )
@@ -132,7 +147,7 @@ def forward(self, x):
132
147
plt .xlabel ("Dimension 1" )
133
148
plt .ylabel ("Dimension 2" )
134
149
plt .savefig (
135
- f"scripts/results/{ model_name } - { dataset_name } /{ layer_name } _t-sne.png"
150
+ f"scripts/results/{ run_name } /{ layer_name . replace ( '/' , '.' ) } _t-sne.png"
136
151
)
137
152
plt .close ()
138
153
@@ -141,47 +156,60 @@ def forward(self, x):
141
156
#######################################

print("############ Plotting chessboards for each cluster")
# NOTE(review): `kmeans`, `relevances` and `layer_name` are the values left
# over from the last iteration of the clustering loop above (part of that
# loop is outside this view) — confirm the intended scope against the full
# file.
with LrpLens.context(model) as modified_model:
    attribution = CondAttribution(modified_model)
    for idx_cluster in tqdm(range(n_clusters)):
        cluster_center = kmeans.cluster_centers_[idx_cluster]
        distances = np.linalg.norm(relevances - cluster_center, axis=1)
        nearest_neighbors = np.argsort(distances)[:8]

        doc = Document()  # create a new document
        doc.packages.append(Package("xskak"))

        # compute heatmap for each nearest neighbor
        for idx_sample in nearest_neighbors:
            _, board, label = concept_dataset[idx_sample]
            _, board_tensor, _ = ConceptDataset.collate_fn_tensor(
                [concept_dataset[idx_sample]]
            )
            label_tensor = torch.tensor([label])

            def init_rel_fn(board_tensor):
                # Single-sample batch: keep only the labelled logit.
                rel = torch.zeros_like(board_tensor)
                rel[:, label_tensor] = board_tensor[:, label_tensor]
                return rel

            board_tensor.requires_grad = True
            attr = attribution(
                board_tensor,
                [{"y": None}],
                composite,
                init_rel=init_rel_fn if best_legal else None,
            )
            if only_config_rel:
                # First 12 input planes — presumably the piece
                # configuration; TODO confirm against the encoding.
                heatmap = board_tensor.grad[0, :12].sum(dim=0).view(64)
            else:
                heatmap = board_tensor.grad[0].sum(dim=0).view(64)
            if board.turn == chess.BLACK:
                # Flip ranks so the heatmap matches the rendered board.
                heatmap = heatmap.view(8, 8).flip(0).view(64)
            move = move_utils.decode_move(
                label, (board.turn, not board.turn), board
            )
            uci_move = move.uci()
            heatmap = heatmap / heatmap.abs().max()
            heatmap_str = create_heatmap_string(heatmap)

            doc = add_plot(
                doc,
                board.fen(),
                heatmap_str,
                current_piece_pos=uci_move[:2],
                next_move=uci_move[2:4],
            )

        # Generate pdf
        doc.generate_pdf(
            f"scripts/results/{run_name}"
            f"/{layer_name.replace('/', '.')}_cluster_{idx_cluster}",
            clean_tex=True,
        )
0 commit comments