Skip to content

Commit 4d5fa69

Browse files
authored
Extended clustering script and apptainer (#11)
* expanded clustering * running scripts with apptainer * gpu label
1 parent 0fc5b6b commit 4d5fa69

File tree

12 files changed

+308
-82
lines changed

12 files changed

+308
-82
lines changed

apptainer/.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
*
2+
!.gitignore
3+
!base.def
4+
!script.def
5+
!make-datasets.sh

apptainer/base.def

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Bootstrap: docker
2+
From: python:3.9.18
3+
4+
%files
5+
./assets /opt/assets
6+
./src /opt/src
7+
./pyproject.toml /opt/pyproject.toml
8+
./poetry.lock /opt/poetry.lock
9+
./README.md /opt/README.md
10+
11+
%environment
12+
export "PATH=/opt/.venv/bin:$PATH"
13+
14+
%post
15+
python -m pip install poetry
16+
17+
cd /opt
18+
python -m poetry config virtualenvs.in-project true
19+
python -m poetry install

apptainer/make-datasets.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/bin/bash
2+
3+
#SBATCH --mail-type=ALL
4+
#SBATCH --mail-user=<[email protected]>
5+
#SBATCH --job-name=apptainer
6+
#SBATCH --output=%j_%x.out
7+
#SBATCH --nodes=1
8+
#SBATCH --ntasks=1
9+
#SBATCH --cpus-per-task=4
10+
#SBATCH --gpus=1
11+
#SBATCH --mem=32G
12+
#SBATCH --time=1:00:00
13+
14+
#####################################################################################
15+
16+
# This included file contains the definition for $LOCAL_JOB_DIR to be used locally on the node.
17+
source "/etc/slurm/local_job_dir.sh"
18+
19+
# Launch the apptainer image with --nv for nvidia support. A single bind mount is used:
20+
# - ${LOCAL_JOB_DIR} on the node is mounted at /opt/output inside the container, so that
21+
#   results (e.g. checkpoint data) written to /opt/output end up in $LOCAL_JOB_DIR.
22+
timeout 24h apptainer exec --nv --bind ${LOCAL_JOB_DIR}:/opt/output \
23+
./apptainer/script.sif python -m scripts.make_datasets \
24+
--output-root /opt/output
25+
26+
# Copy all results generated in $LOCAL_JOB_DIR back to the submit folder, into a subdirectory named after the job id.
27+
cp -r ${LOCAL_JOB_DIR} ${SLURM_SUBMIT_DIR}/${SLURM_JOB_ID}
28+
29+
echo "$PWD/${SLURM_JOB_ID}_stats.out" > $LOCAL_JOB_DIR/stats_file_loc_cfg

apptainer/script.def

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Bootstrap: localimage
2+
From: ./apptainer/base.sif
3+
4+
%files
5+
./scripts/*.py /opt/scripts/
6+
7+
%runscript
8+
cd /opt/
9+
echo "Running script"

scripts/cluster_latent_relevances.py

Lines changed: 108 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88

99
import os
1010

11-
import chess
1211
import matplotlib.pyplot as plt
1312
import numpy as np
1413
import torch
1514
from crp.attribution import CondAttribution
16-
from crp.concepts import ChannelConcept
1715
from pylatex import Document
1816
from pylatex.package import Package
1917
from safetensors import safe_open
@@ -30,17 +28,26 @@
3028
#######################################
3129
# HYPERPARAMETERS
3230
#######################################
33-
n_clusters = 10
31+
n_clusters = 15
3432
batch_size = 500
35-
save_files = True
36-
conv_sum_dims = (2, 3)
33+
save_files = False
3734
model_name = "64x6-2018_0627_1913_08_161.onnx"
3835
dataset_name = "TCEC_game_collection_random_boards_bestlegal_knight.jsonl"
3936
only_config_rel = True
4037
best_legal = True
4138
run_name = (
4239
f"bestres_tcec_bestlegal_knight_{'expbest' if best_legal else 'full'}"
4340
)
41+
n_samples = 1000
42+
conv_sum_dims = ()
43+
cosine_sim = False
44+
kmeans_on_tsne = True
45+
viz_latent = True
46+
viz_name = (
47+
f"{'latent' if viz_latent else 'input'}"
48+
f"_nosum_{'cosine' if cosine_sim else 'norm'}"
49+
f"_{'after' if kmeans_on_tsne else 'before'}-tsne"
50+
)
4451
#######################################
4552

4653

@@ -60,7 +67,6 @@ def legal_init_rel(board_list, board_tensor):
6067
print(f"[INFO] Board dataset len: {len(concept_dataset)}")
6168

6269
composite = LrpLens.make_default_composite()
63-
cc = ChannelConcept()
6470
layer_names = [f"model.block{b}/conv2/relu" for b in [0, 3, 5]]
6571
print(layer_names)
6672

@@ -96,9 +102,6 @@ def init_rel_fn(board_tensor):
96102

97103
for layer_name in layer_names:
98104
latent_rel = attr.relevances[layer_name]
99-
latent_rel = cc.attribute(latent_rel, abs_norm=True)
100-
if len(latent_rel.shape) == 4:
101-
latent_rel = latent_rel.sum(conv_sum_dims)
102105
if layer_name not in all_relevances:
103106
all_relevances[layer_name] = latent_rel.detach().cpu()
104107
else:
@@ -131,23 +134,34 @@ def init_rel_fn(board_tensor):
131134
#######################################
132135

133136
print("############ Clustering ...")
134-
os.makedirs(f"scripts/results/{run_name}", exist_ok=True)
137+
os.makedirs(f"scripts/results/{run_name}/{viz_name}", exist_ok=True)
135138

136139
for layer_name, relevances in all_relevances.items():
137-
kmeans = KMeans(n_clusters=n_clusters, init="k-means++")
138-
kmeans.fit(relevances)
140+
relevances = relevances[:n_samples]
141+
if conv_sum_dims:
142+
relevances = relevances.sum(dim=conv_sum_dims).view(
143+
relevances.shape[0], -1
144+
)
145+
else:
146+
relevances = relevances.view(relevances.shape[0], -1)
139147

140148
# Perform t-SNE dimensionality reduction
141149
tsne = TSNE(n_components=2)
142150
latent_rel_tsne = tsne.fit_transform(relevances)
143151

152+
if kmeans_on_tsne:
153+
relevances = latent_rel_tsne
154+
kmeans = KMeans(n_clusters=n_clusters, init="k-means++")
155+
kmeans.fit(relevances)
156+
144157
# Plot the clustered data
145158
plt.scatter(latent_rel_tsne[:, 0], latent_rel_tsne[:, 1], c=kmeans.labels_)
146159
plt.title("Clustered Latent Relevances")
147160
plt.xlabel("Dimension 1")
148161
plt.ylabel("Dimension 2")
149162
plt.savefig(
150-
f"scripts/results/{run_name}/{layer_name.replace('/','.')}_t-sne.png"
163+
f"scripts/results/{run_name}/{viz_name}/"
164+
f"{layer_name.replace('/','.')}_t-sne.png"
151165
)
152166
plt.close()
153167

@@ -160,10 +174,25 @@ def init_rel_fn(board_tensor):
160174
attribution = CondAttribution(modifed_model)
161175
for idx_cluster in tqdm(range(n_clusters)):
162176
cluster_center = kmeans.cluster_centers_[idx_cluster]
163-
distances = np.linalg.norm(relevances - cluster_center, axis=1)
164-
nearest_neighbors = np.argsort(distances)[:8]
177+
if cosine_sim:
178+
dot_prod = relevances @ cluster_center.T
179+
similarities = dot_prod / (
180+
np.linalg.norm(relevances, axis=1)
181+
* np.linalg.norm(cluster_center)
182+
)
183+
nearest_neighbors = np.argsort(similarities)[-8:]
184+
else:
185+
distances = np.linalg.norm(relevances - cluster_center, axis=1)
186+
nearest_neighbors = np.argsort(distances)[:8]
165187

166-
doc = Document() # create a new document
188+
doc = Document(
189+
geometry_options={
190+
"lmargin": "3cm",
191+
"tmargin": "0.5cm",
192+
"bmargin": "1.5cm",
193+
"rmargin": "3cm",
194+
}
195+
)
167196
doc.packages.append(Package("xskak"))
168197

169198
# compute heatmap for each nearest neighbor
@@ -179,37 +208,83 @@ def init_rel_fn(board_tensor):
179208
rel[:, label_tensor] = board_tensor[:, label_tensor]
180209
return rel
181210

182-
board_tensor.requires_grad = True
183-
attr = attribution(
184-
board_tensor,
185-
[{"y": None}],
186-
composite,
187-
init_rel=init_rel_fn if best_legal else None,
188-
)
189-
if only_config_rel:
190-
heatmap = board_tensor.grad[0, :12].sum(dim=0).view(64)
191-
else:
192-
heatmap = board_tensor.grad[0].sum(dim=0).view(64)
193-
if board.turn == chess.BLACK:
194-
heatmap = heatmap.view(8, 8).flip(0).view(64)
195211
move = move_utils.decode_move(
196212
label, (board.turn, not board.turn), board
197213
)
198214
uci_move = move.uci()
199-
heatmap = heatmap / heatmap.abs().max()
200-
heatmap_str = create_heatmap_string(heatmap)
215+
216+
if viz_latent:
217+
latent_rel = all_relevances[layer_name][idx_sample]
218+
if not board.turn:
219+
latent_rel = latent_rel.flip(1)
220+
latent_rel = latent_rel.view(-1, 64)
221+
channel_rels = latent_rel.abs().sum(dim=1)
222+
c1, c2 = torch.topk(channel_rels, 2).indices
223+
heatmap_str_list = [
224+
create_heatmap_string(latent_rel.sum(0), abs_max=True),
225+
create_heatmap_string(latent_rel[c1], abs_max=True),
226+
create_heatmap_string(latent_rel[c2], abs_max=True),
227+
]
228+
heatmap_caption_list = [
229+
"Total relevance",
230+
"Best channel",
231+
"Second best channel",
232+
]
233+
add_caption = "latent"
234+
else:
235+
board_tensor.requires_grad = True
236+
attr = attribution(
237+
board_tensor,
238+
[{"y": None}],
239+
composite,
240+
init_rel=init_rel_fn if best_legal else None,
241+
)
242+
input_relevances = board_tensor.grad
243+
if not board.turn:
244+
input_relevances = (
245+
input_relevances.view(112, 8, 8)
246+
.flip(1)
247+
.view(112, 64)
248+
)
249+
input_relevances = input_relevances.view(112, 64)
250+
heatmap_str_list = [
251+
create_heatmap_string(
252+
input_relevances.sum(dim=0), abs_max=True
253+
),
254+
create_heatmap_string(
255+
input_relevances[:13].sum(dim=0), abs_max=True
256+
),
257+
create_heatmap_string(
258+
input_relevances[104:].sum(dim=0), abs_max=True
259+
),
260+
]
261+
heatmap_caption_list = [
262+
"Total relevance",
263+
"Current config relevance",
264+
"Meta relevance",
265+
]
266+
h0 = input_relevances[:13].abs().sum()
267+
hist = input_relevances[13:104].abs().sum()
268+
meta = input_relevances[104:].abs().sum()
269+
total = (h0 + hist + meta) / 100
270+
add_caption = (
271+
f"{h0/total:.0f}%|{hist/total:.0f}%|{meta/total:.0f}%"
272+
)
201273

202274
doc = add_plot(
203275
doc,
204276
board.fen(),
205-
heatmap_str,
277+
heatmap_str_list,
206278
current_piece_pos=uci_move[:2],
207279
next_move=uci_move[2:4],
280+
caption=f"Sample {idx_sample} - {add_caption}",
281+
heatmap_caption_list=heatmap_caption_list,
208282
)
209283

210284
# Generate pdf
211285
doc.generate_pdf(
212286
f"scripts/results/{run_name}"
213-
f"/{layer_name.replace('/','.')}_cluster_{idx_cluster}",
287+
f"/{viz_name}/{layer_name.replace('/','.')}"
288+
f"_cluster_{idx_cluster}",
214289
clean_tex=True,
215290
)

scripts/create_figure.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,54 +3,70 @@
33

44

55
import chess
6-
from pylatex import Figure, MiniPage, NoEscape
6+
from pylatex import Figure, NoEscape, SubFigure
77

88

99
def add_plot(
1010
doc,
1111
label,
12-
heatmap_str,
12+
heatmap_str_list,
1313
current_piece_pos=None,
1414
next_move=None,
1515
caption=None,
16+
heatmap_caption_list=None,
1617
):
1718
# Put some data inside the Figure environment
1819
with doc.create(Figure()) as fig:
20+
doc.append(NoEscape(r"\centering"))
1921
if caption is not None:
2022
fig.add_caption(caption)
2123
verbatim = NoEscape(
2224
r"\storechessboardstyle{8x8}{maxfield=h8,showmover=true}"
2325
)
2426
doc.append(verbatim)
2527

26-
with doc.create(MiniPage(width=r"0.45\textwidth")):
28+
with doc.create(
29+
SubFigure(
30+
width=NoEscape(r"0.45\textwidth"),
31+
)
32+
) as subfig:
33+
subfig.add_caption("Board")
34+
doc.append(NoEscape(r"\chessboard[style=8x8,"))
2735
if current_piece_pos is not None:
2836
markmove = current_piece_pos + "-" + next_move
2937
markfields = (
3038
"{{" + current_piece_pos + "},{" + next_move + "}}"
3139
)
3240
chessboard_fen = NoEscape(
33-
rf"\chessboard[style=8x8,setfen={label},showmover=true,"
41+
rf"setfen={label},showmover=true,"
3442
rf"color=green,pgfstyle=straightmove,markmove={markmove},"
35-
rf"pgfstyle=border,color=red,markfields={markfields},] "
43+
rf"pgfstyle=border,color=red,markfields={markfields},]"
3644
)
3745
else:
3846
chessboard_fen = NoEscape(
3947
rf"\chessboard[style=8x8,setfen={label},"
40-
"showmover=true,pgfstyle=straightmove,color=green,] "
48+
"showmover=true,pgfstyle=straightmove,color=green,]"
4149
)
4250
doc.append(chessboard_fen)
43-
doc.append(NoEscape("\hfill")) # noqa
44-
with doc.create(MiniPage(width=r"0.45\textwidth")):
45-
heatmap_begin = NoEscape(r"\chessboard[style=8x8,showmover=false,")
46-
doc.append(heatmap_begin)
51+
for i, heatmap_str in enumerate(heatmap_str_list):
52+
doc.append(NoEscape(r"\hfill"))
53+
with doc.create(
54+
SubFigure(width=NoEscape(r"0.45\textwidth"))
55+
) as subfig:
56+
subfig.add_caption(heatmap_caption_list[i])
57+
heatmap_begin = NoEscape(
58+
r"\chessboard[style=8x8,showmover=false,"
59+
)
60+
doc.append(heatmap_begin)
4761

48-
heatmap_end = NoEscape(heatmap_str) + NoEscape(r"]")
49-
doc.append(heatmap_end)
62+
heatmap_end = NoEscape(heatmap_str) + NoEscape(r"]")
63+
doc.append(heatmap_end)
5064
return doc
5165

5266

53-
def create_heatmap_string(heatmap):
67+
def create_heatmap_string(heatmap, abs_max=True):
68+
if abs_max:
69+
heatmap = heatmap / heatmap.abs().max()
5470
heatmap_str = ""
5571
for idx, name in enumerate(chess.SQUARE_NAMES):
5672
colorcode = heatmap[idx]

0 commit comments

Comments
 (0)