Skip to content

Commit

Permalink
Flipping Latent Pixels (#15)
Browse files Browse the repository at this point in the history
* minor changes & pixel flipping script

* no type ignore

* square
  • Loading branch information
Xmaster6y authored Apr 24, 2024
1 parent 83742dd commit d6c5011
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 34 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ repos:
hooks:
- id: mypy
additional_dependencies: ['types-requests', 'types-toml']
exclude: scripts
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
Expand Down
15 changes: 11 additions & 4 deletions demo/board_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
from demo import constants


def make_board_plot(
board_fen,
arrows,
):
def make_board_plot(board_fen, arrows, square):
try:
board = chess.Board(board_fen)
except ValueError:
Expand All @@ -35,10 +32,12 @@ def make_board_plot(
chess_arrows = []
gr.Warning("Invalid arrows, using none.")

color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
svg_board = chess.svg.board(
board,
size=350,
arrows=chess_arrows,
fill=color_dict,
)
with open(f"{constants.FIGURE_DIRECTORY}/board.svg", "w") as f:
f.write(svg_board)
Expand All @@ -61,12 +60,20 @@ def make_board_plot(
value="",
placeholder="e2e4 e7e5",
)
square = gr.Textbox(
label="Square",
lines=1,
max_lines=1,
value="",
placeholder="e4",
)
with gr.Column():
image = gr.Image(label="Board", interactive=False)

inputs = [
board_fen,
arrows,
square,
]
board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
arrows.submit(make_board_plot, inputs=inputs, outputs=image)
Expand Down
20 changes: 12 additions & 8 deletions scripts/cluster_latent_relevances.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@
#######################################


def legal_init_rel(board_list, board_tensor):
def legal_init_rel(board_list, out_tensor):
legal_move_mask = torch.zeros((len(board_list), 1858))
for idx, board in enumerate(board_list):
legal_moves = [
move_utils.encode_move(move, (board.turn, not board.turn))
for move in board.legal_moves
]
legal_move_mask[idx, legal_moves] = 1
return legal_move_mask * board_tensor
return legal_move_mask * out_tensor


model = PolicyFlow.from_path(f"./assets/{model_name}")
Expand All @@ -84,9 +84,10 @@ def legal_init_rel(board_list, board_tensor):
_, board_tensor, labels = batch
label_tensor = torch.tensor(labels)

def init_rel_fn(board_tensor):
rel = torch.zeros_like(board_tensor)
rel[:, label_tensor] = board_tensor[:, label_tensor]
def init_rel_fn(out_tensor):
rel = torch.zeros_like(out_tensor)
for i in range(rel.shape[0]):
rel[i, label_tensor[i]] = out_tensor[i, label_tensor[i]]
return rel

board_tensor.requires_grad = True
Expand Down Expand Up @@ -203,9 +204,12 @@ def init_rel_fn(board_tensor):
)
label_tensor = torch.tensor([label])

def init_rel_fn(board_tensor):
rel = torch.zeros_like(board_tensor)
rel[:, label_tensor] = board_tensor[:, label_tensor]
def init_rel_fn(out_tensor):
rel = torch.zeros_like(out_tensor)
for i in range(rel.shape[0]):
rel[i, label_tensor[i]] = out_tensor[
i, label_tensor[i]
]
return rel

move = move_utils.decode_move(
Expand Down
189 changes: 189 additions & 0 deletions scripts/pixel_flipping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""Script to perform pixel flipping.
Run with:
```bash
poetry run python -m scripts.pixel_flipping
```
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from crp.attribution import CondAttribution

from lczerolens.game import PolicyFlow
from lczerolens.xai import ConceptDataset, LrpLens
from lczerolens.xai.hook import HookConfig, ModifyHook

#######################################
# HYPERPARAMETERS
#######################################
# When True, print the flipped pixel indices at every step and the
# accumulated logits once a layer is finished.
debug = False
# File name of the network, resolved relative to ./assets/.
model_name = "64x6-2018_0627_1913_08_161.onnx"
# File name of the labelled board dataset, resolved relative to ./assets/.
dataset_name = "TCEC_game_collection_random_boards_bestlegal.jsonl"
# Number of boards read from the dataset; also used as the batch size, so
# all samples are processed in a single batch.
n_samples = 10
# Number of flipping iterations run per (rule, layer) pair; one latent
# pixel per sample is flipped at each iteration.
n_steps = 100
# Base name (without extension) of the figure saved in ./scripts/results/.
viz_name = "pixel_flipping_tcec_bestlegal"
#######################################


# Load the network under study and the first ``n_samples`` labelled boards.
model = PolicyFlow.from_path(f"./assets/{model_name}")
concept_dataset = ConceptDataset(
    f"./assets/{dataset_name}",
    first_n=n_samples,
)
print(f"[INFO] Board dataset len: {len(concept_dataset)}")


# Layers whose latent relevances are recorded and progressively masked.
layer_names = ["model.inputconv", "model.block0/conv2/relu"]
print(layer_names)

# A single, unshuffled batch holding every sample.
dataloader = torch.utils.data.DataLoader(
    concept_dataset,
    batch_size=n_samples,
    shuffle=False,
    collate_fn=ConceptDataset.collate_fn_tensor,
)
indices, board_tensor, labels = next(iter(dataloader))
rule_names = ["default", "no_onnx"]


def _empty_logit_store():
    """Return a fresh ``{rule: {layer: []}}`` mapping of empty lists."""
    store = {}
    for rule in rule_names:
        store[rule] = {layer: [] for layer in layer_names}
    return store


# Per-step logits, collected separately for the two flipping orders
# (most relevant first / least relevant first).
morf_logit_dict = _empty_logit_store()
lerf_logit_dict = _empty_logit_store()


def mask_fn(output, modify_data):
    """Apply an optional multiplicative mask to a module output.

    Args:
        output: Value produced by the hooked module.
        modify_data: Mask multiplied with ``output``, or ``None`` when no
            mask has been set yet.

    Returns:
        ``output`` itself when ``modify_data`` is ``None``, otherwise
        ``output * modify_data``.
    """
    if modify_data is None:
        return output
    return output * modify_data


# Pixel flipping: for every flipping order (most / least relevant first),
# every LRP rule and every layer, repeatedly
#   1. attribute the labelled logit back to the layer,
#   2. pick one latent pixel per sample from the layer relevances,
#   3. zero that pixel in the mask applied to the layer output by the hook,
# and record the labelled logit obtained at each step.

# The attribution needs gradients with respect to the input boards.
board_tensor.requires_grad = True
label_tensor = torch.tensor(labels)


def init_rel_fn(out_tensor):
    """Keep only each sample's labelled logit as initial relevance."""
    rel = torch.zeros_like(out_tensor)
    for row in range(rel.shape[0]):
        rel[row, label_tensor[row]] = out_tensor[row, label_tensor[row]]
    return rel


# Whether the LRP context is asked to replace the onnx2torch modules.
replace_onnx2torch_by_rule = {"default": True, "no_onnx": False}

for logit_dict, morf in zip([morf_logit_dict, lerf_logit_dict], [True, False]):
    # Score given to already-flipped pixels so that argmax / argmin never
    # selects them a second time (which would make the step a no-op).
    excluded_score = float("-inf") if morf else float("inf")
    for rule_name in rule_names:
        if rule_name not in replace_onnx2torch_by_rule:
            raise ValueError(f"Unknown rule: {rule_name}")
        composite = LrpLens.make_default_composite()
        replace_onnx2torch = replace_onnx2torch_by_rule[rule_name]
        for layer_name in layer_names:
            hook_config = HookConfig(
                module_exp=rf"^{layer_name}$",
                data={layer_name: None},
                data_fn=mask_fn,
            )
            hook = ModifyHook(hook_config)
            hook.register(model)
            # Flat (batch, n_pixels) mask of this layer; 1 keeps a pixel,
            # 0 flips it. Created lazily once the latent shape is known.
            mask_flat = None
            try:
                for _ in range(n_steps):
                    with LrpLens.context(
                        model,
                        composite=composite,
                        replace_onnx2torch=replace_onnx2torch,
                    ) as modified_model:
                        attribution = CondAttribution(modified_model)
                        attr = attribution(
                            board_tensor,
                            [{"y": None}],
                            composite,
                            record_layer=layer_names,
                            init_rel=init_rel_fn,
                        )
                    latent_rel = attr.relevances[layer_name]
                    flat_rel = latent_rel.detach().reshape(
                        board_tensor.shape[0], -1
                    )
                    if mask_flat is None:
                        mask_flat = torch.ones_like(
                            flat_rel, memory_format=torch.contiguous_format
                        )
                    scores = flat_rel.masked_fill(
                        mask_flat == 0, excluded_score
                    )
                    if morf:
                        to_flip = scores.argmax(dim=1)
                    else:
                        to_flip = scores.argmin(dim=1)
                    sample_rows = torch.arange(
                        mask_flat.shape[0], device=to_flip.device
                    )
                    mask_flat[sample_rows, to_flip] = 0
                    hook.config.data[layer_name] = mask_flat.view_as(
                        latent_rel
                    )
                    if debug:
                        print(f"[INFO] Most relevant pixels: {to_flip}")
                    # Logit of the labelled move for the forward pass that
                    # produced this step's relevances.
                    logit_dict[rule_name][layer_name].append(
                        attr.prediction.gather(
                            1, label_tensor.view(-1, 1)
                        ).detach()
                    )
                print(f"[INFO] Layer: {layer_name} done")
                if debug:
                    print(
                        "[INFO] Logits: "
                        f"{torch.cat(logit_dict[rule_name][layer_name], dim=1)}"
                    )
            finally:
                # Always detach the hook so that a failure cannot leave a
                # stale mask applied to the model.
                hook.remove()
                hook.clear()
        print(f"[INFO] Rule: {rule_name} done")

# Three panels sharing both axes, from left to right:
# LeRF - MoRF logit gap, MoRF logits, LeRF logits.
fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
for rule_name in rule_names:
    for layer_name in layer_names:
        # Each stored entry is one step's column of logits; concatenating
        # along dim 1 yields one row per sample and one column per step.
        morf_logits = torch.cat(morf_logit_dict[rule_name][layer_name], dim=1)
        lerf_logits = torch.cat(lerf_logit_dict[rule_name][layer_name], dim=1)
        curve_label = f"{rule_name} {layer_name}"
        panel_curves = (lerf_logits - morf_logits, morf_logits, lerf_logits)
        for panel, curves in zip(ax, panel_curves):
            # Mean and spread over the samples, converted explicitly so
            # matplotlib receives plain arrays.
            means = curves.mean(dim=0).cpu().numpy()
            stds = curves.std(dim=0).cpu().numpy()
            panel.errorbar(
                np.arange(means.shape[0]),
                means,
                yerr=stds,
                label=curve_label,
            )
ax[0].set_ylabel(f"Mean logit (n={n_samples})")
ax[1].set_xlabel("Pixels flipped", loc="center")
for panel in ax:
    panel.legend()

# Create the output directory first so the figure is not lost at the very
# end of the run when the directory does not exist yet.
results_dir = "./scripts/results"
os.makedirs(results_dir, exist_ok=True)
fig.savefig(f"{results_dir}/{viz_name}.png")
10 changes: 5 additions & 5 deletions scripts/register_wandb_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,23 @@
)

if ARGS.log_datasets:
wandb.login() # type: ignore
with wandb.init( # type: ignore
wandb.login()
with wandb.init(
project="lczerolens-saes", job_type="make-datasets"
) as run:
artifact = wandb.Artifact("tcec_train", type="dataset") # type: ignore
artifact = wandb.Artifact("tcec_train", type="dataset")
artifact.add_file(
f"{ARGS.output_root}/assets/"
"TCEC_game_collection_random_boards_train.jsonl"
)
run.log_artifact(artifact)
artifact = wandb.Artifact("tcec_val", type="dataset") # type: ignore
artifact = wandb.Artifact("tcec_val", type="dataset")
artifact.add_file(
f"{ARGS.output_root}/assets/"
"TCEC_game_collection_random_boards_val.jsonl"
)
run.log_artifact(artifact)
artifact = wandb.Artifact("tcec_test", type="dataset") # type: ignore
artifact = wandb.Artifact("tcec_test", type="dataset")
artifact.add_file(
f"{ARGS.output_root}/assets/"
"TCEC_game_collection_random_boards_test.jsonl"
Expand Down
8 changes: 3 additions & 5 deletions scripts/register_wandb_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
ARGS = parser.parse_args()

if ARGS.log_models:
wandb.login() # type: ignore
with wandb.init( # type: ignore
project="lczerolens-saes", job_type="make-models"
) as run:
wandb.login()
with wandb.init(project="lczerolens-saes", job_type="make-models") as run:
for model_name, model_path in models.items():
artifact = wandb.Artifact(model_name, type="model") # type: ignore
artifact = wandb.Artifact(model_name, type="model")
artifact.add_file(f"./assets/{model_path}")
run.log_artifact(artifact)
3 changes: 3 additions & 0 deletions src/lczerolens/xai/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
game_ids: Optional[List[str]] = None,
concept: Optional[Concept] = None,
labels: Optional[List[Any]] = None,
first_n: Optional[int] = None,
):
if file_name is None:
super().__init__(file_name, boards, game_ids)
Expand All @@ -193,6 +194,8 @@ def __init__(
self.boards.append(board)
self.game_ids.append(obj["gameid"])
self.labels.append(obj["label"])
if first_n is not None and len(self.boards) >= first_n:
break
self._concept = concept if concept is not None else NullConcept()
if labels is not None:
self.labels = labels
Expand Down
Loading

0 comments on commit d6c5011

Please sign in to comment.