@@ -9,10 +9,6 @@
 
 from demo import constants, utils, visualisation
 
-cache = None
-boards = None
-board_index = 0
-
 
 def list_models():
     """
@@ -38,89 +34,94 @@ def compute_cache(
     attention_layer,
     attention_head,
     square,
-    quantity,
-    func,
-    trick,
-    aggregate,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global cache
-    global boards
     if model_name == "":
         gr.Warning("No model selected.")
-        return None, None, None
+        return None, None, None, state_boards, state_cache
 
     try:
         board = chess.Board(board_fen)
     except ValueError:
         board = chess.Board()
         gr.Warning("Invalid FEN, using starting position.")
-    boards = [board.copy()]
+    state_boards = [board.copy()]
     if action_seq:
         try:
             if action_seq.startswith("1."):
                 for action in action_seq.split():
                     if action.endswith("."):
                         continue
                     board.push_san(action)
-                    boards.append(board.copy())
+                    state_boards.append(board.copy())
             else:
                 for action in action_seq.split():
                     board.push_uci(action)
-                    boards.append(board.copy())
+                    state_boards.append(board.copy())
         except ValueError:
             gr.Warning(f"Invalid action {action} stopping before it.")
     try:
         wrapper, lens = utils.get_wrapper_lens_from_state(
-            model_name, "attention"
+            model_name,
+            "activation",
+            lens_name="attention",
+            module_exp=r"encoder\d+/mha/QK/softmax",
         )
     except ValueError:
         gr.Warning("Could not load model.")
-        return None, None, None
-    cache = []
-    for board in boards:
-        attention_cache = copy.deepcopy(lens.compute_heatmap(board, wrapper))
-        cache.append(attention_cache)
-    return make_plot(
-        attention_layer,
-        attention_head,
-        square,
-        quantity,
-        func,
-        trick,
-        aggregate,
+        return None, None, None, state_boards, state_cache
+    state_cache = []
+    for board in state_boards:
+        attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper))
+        state_cache.append(attention_cache)
+    return (
+        *make_plot(
+            attention_layer,
+            attention_head,
+            square,
+            state_board_index,
+            state_boards,
+            state_cache,
+        ),
+        state_boards,
+        state_cache,
     )
 
 
 def make_plot(
-    attention_layer, attention_head, square, quantity, func, trick, aggregate
+    attention_layer,
+    attention_head,
+    square,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global cache
-    global boards
-    global board_index
 
-    if cache is None:
-        gr.Warning("Cache not computed!")
-        return None, None
+    if state_cache == []:
+        gr.Warning("No cache available.")
+        return None, None, None
 
-    board = boards[board_index]
-    num_attention_layers = len(cache[board_index])
+    board = state_boards[state_board_index]
+    num_attention_layers = len(state_cache[state_board_index])
     if attention_layer > num_attention_layers:
         gr.Warning(
             f"Attention layer {attention_layer} does not exist, "
             f"using layer {num_attention_layers} instead."
         )
         attention_layer = num_attention_layers
 
-    key = f"{attention_layer - 1}-{quantity}-{func}"
+    key = f"encoder{attention_layer - 1}/mha/QK/softmax"
     try:
-        attention_tensor = cache[board_index][key]
+        attention_tensor = state_cache[state_board_index][key]
     except KeyError:
         gr.Warning(f"Combination {key} does not exist.")
         return None, None, None
     if attention_head > attention_tensor.shape[1]:
         gr.Warning(
             f"Attention head {attention_head} does not exist, "
-            f"using head {attention_tensor.shape[1]} instead."
+            f"using head {attention_tensor.shape[1] + 1} instead."
         )
         attention_head = attention_tensor.shape[1]
     try:
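Because the handlers now return updated state alongside the displayable outputs, `compute_cache` splats `make_plot`'s result and appends `state_boards` and `state_cache`; the tuple arity has to match the `outputs` list wired up at the bottom of the file. A self-contained illustration of that shape (all names here are made up for the example):

def make_plot(x):
    # Stands in for the real make_plot: returns only the displayable
    # outputs, e.g. (image, fen, colorbar) in the demo.
    return x * 2, x * 3

def compute(x, state_items):
    state_items = state_items + [x]  # rebuild the per-session state
    # Splat the plot outputs, then append the state value; Gradio maps
    # this positionally onto outputs=[...] + [state_items].
    return (*make_plot(x), state_items)

print(compute(2, []))  # (4, 6, [2])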
@@ -132,15 +133,7 @@ def make_plot(
     if board.turn == chess.BLACK:
         square_index = chess.square_mirror(square_index)
 
-    if trick == "revert":
-        square_index = 63 - square_index
-
-    if aggregate == "Row":
-        heatmap = attention_tensor[0, attention_head - 1, square_index, :]
-    elif aggregate == "Column":
-        heatmap = attention_tensor[0, attention_head - 1, :, square_index]
-    else:
-        heatmap = attention_tensor[0, attention_head - 1]
+    heatmap = attention_tensor[0, attention_head - 1, square_index]
     if board.turn == chess.BLACK:
         heatmap = heatmap.view(8, 8).flip(0).view(64)
     svg_board, fig = visualisation.render_heatmap(
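The simplified indexing above assumes the cached softmax tensor is laid out as (batch, heads, query square, key square), so `[0, attention_head - 1, square_index]` selects the chosen square's attention distribution over all 64 squares. A small sketch of that assumption, with a dummy tensor standing in for the real cache entry:

import torch

# Dummy stand-in for one cached "encoder{i}/mha/QK/softmax" entry.
attention_tensor = torch.rand(1, 8, 64, 64).softmax(dim=-1)

attention_head, square_index = 1, 12  # 1-indexed head; square 12 is e2
heatmap = attention_tensor[0, attention_head - 1, square_index]

assert heatmap.shape == (64,)  # one weight per key square
assert torch.isclose(heatmap.sum(), torch.tensor(1.0))  # a softmax row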
@@ -155,37 +148,49 @@ def previous_board(
     attention_layer,
     attention_head,
     square,
-    from_to,
-    color_flip,
-    trick,
-    aggregate,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global board_index
-    board_index -= 1
-    if board_index < 0:
+    state_board_index -= 1
+    if state_board_index < 0:
         gr.Warning("Already at first board.")
-        board_index = 0
-    return make_plot(
-        attention_layer, attention_head, square, from_to, color_flip
+        state_board_index = 0
+    return (
+        *make_plot(
+            attention_layer,
+            attention_head,
+            square,
+            state_board_index,
+            state_boards,
+            state_cache,
+        ),
+        state_board_index,
     )
 
 
 def next_board(
     attention_layer,
     attention_head,
     square,
-    from_to,
-    color_flip,
-    trick,
-    aggregate,
+    state_board_index,
+    state_boards,
+    state_cache,
 ):
-    global board_index
-    board_index += 1
-    if board_index >= len(boards):
+    state_board_index += 1
+    if state_board_index >= len(state_boards):
         gr.Warning("Already at last board.")
-        board_index = len(boards) - 1
-    return make_plot(
-        attention_layer, attention_head, square, from_to, color_flip
+        state_board_index = len(state_boards) - 1
+    return (
+        *make_plot(
+            attention_layer,
+            attention_head,
+            square,
+            state_board_index,
+            state_boards,
+            state_cache,
+        ),
+        state_board_index,
    )
 
 
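`previous_board` and `next_board` now clamp the per-session index and return it as the final output, so the `gr.State` is updated even when a boundary is hit. The clamping logic in isolation (hypothetical names, not the demo's API):

def next_index(index, items):
    # Move forward but never past the last item; returning the clamped
    # index keeps the session state inside the valid range.
    index += 1
    if index >= len(items):
        index = len(items) - 1
    return index

assert next_index(0, ["a", "b"]) == 1
assert next_index(1, ["a", "b"]) == 1  # already at the last item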
@@ -254,38 +259,6 @@ def next_board(
                     value="a1",
                     scale=1,
                 )
-                quantity = gr.Dropdown(
-                    label="Quantity",
-                    choices=["QK", "Q", "K", "out", "QKV"],
-                    value="QK",
-                    scale=2,
-                )
-                aggregate = gr.Dropdown(
-                    label="Aggregate",
-                    choices=["Row", "Column", "None"],
-                    value="Row",
-                    scale=2,
-                )
-                func = gr.Dropdown(
-                    label="Function",
-                    choices=[
-                        "softmax",
-                        "transpose",
-                        "matmul",
-                        "scale",
-                    ],
-                    value="softmax",
-                    scale=2,
-                )
-                trick = gr.Dropdown(
-                    label="Trick",
-                    choices=[
-                        "none",
-                        "revert",
-                    ],
-                    value="none",
-                    scale=2,
-                )
             with gr.Row():
                 previous_board_button = gr.Button("Previous board")
                 next_board_button = gr.Button("Next board")
@@ -298,32 +271,34 @@ def next_board(
         with gr.Column():
             image = gr.Image(label="Board")
 
+    state_board_index = gr.State(0)
+    state_boards = gr.State([])
+    state_cache = gr.State([])
     base_inputs = [
         attention_layer,
         attention_head,
         square,
-        quantity,
-        func,
-        trick,
-        aggregate,
+        state_board_index,
+        state_boards,
+        state_cache,
     ]
     outputs = [image, current_board_fen, colorbar]
 
     compute_cache_button.click(
         compute_cache,
         inputs=[board_fen, action_seq, model_name] + base_inputs,
-        outputs=outputs,
+        outputs=outputs + [state_boards, state_cache],
     )
 
     previous_board_button.click(
-        previous_board, inputs=base_inputs, outputs=outputs
+        previous_board,
+        inputs=base_inputs,
+        outputs=outputs + [state_board_index],
+    )
+    next_board_button.click(
+        next_board, inputs=base_inputs, outputs=outputs + [state_board_index]
    )
-    next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)
 
    attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
    attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
    square.submit(make_plot, inputs=base_inputs, outputs=outputs)
-    quantity.change(make_plot, inputs=base_inputs, outputs=outputs)
-    func.change(make_plot, inputs=base_inputs, outputs=outputs)
-    trick.change(make_plot, inputs=base_inputs, outputs=outputs)
-    aggregate.change(make_plot, inputs=base_inputs, outputs=outputs)