diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6541ac0..9dd6da1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,7 @@ repos: hooks: - id: mypy additional_dependencies: ['types-requests', 'types-toml'] + exclude: scripts - repo: https://github.com/pycqa/flake8 rev: 7.0.0 hooks: diff --git a/demo/board_interface.py b/demo/board_interface.py index 85c19b8..0948694 100644 --- a/demo/board_interface.py +++ b/demo/board_interface.py @@ -8,10 +8,7 @@ from demo import constants -def make_board_plot( - board_fen, - arrows, -): +def make_board_plot(board_fen, arrows, square): try: board = chess.Board(board_fen) except ValueError: @@ -35,10 +32,12 @@ def make_board_plot( chess_arrows = [] gr.Warning("Invalid arrows, using none.") + color_dict = {chess.parse_square(square): "#FF0000"} if square else {} svg_board = chess.svg.board( board, size=350, arrows=chess_arrows, + fill=color_dict, ) with open(f"{constants.FIGURE_DIRECTORY}/board.svg", "w") as f: f.write(svg_board) @@ -61,12 +60,20 @@ def make_board_plot( value="", placeholder="e2e4 e7e5", ) + square = gr.Textbox( + label="Square", + lines=1, + max_lines=1, + value="", + placeholder="e4", + ) with gr.Column(): image = gr.Image(label="Board", interactive=False) inputs = [ board_fen, arrows, + square, ] board_fen.submit(make_board_plot, inputs=inputs, outputs=image) arrows.submit(make_board_plot, inputs=inputs, outputs=image) diff --git a/scripts/cluster_latent_relevances.py b/scripts/cluster_latent_relevances.py index 5e72f2c..bbeb524 100644 --- a/scripts/cluster_latent_relevances.py +++ b/scripts/cluster_latent_relevances.py @@ -51,7 +51,7 @@ ####################################### -def legal_init_rel(board_list, board_tensor): +def legal_init_rel(board_list, out_tensor): legal_move_mask = torch.zeros((len(board_list), 1858)) for idx, board in enumerate(board_list): legal_moves = [ @@ -59,7 +59,7 @@ def legal_init_rel(board_list, board_tensor): for move in 
board.legal_moves ] legal_move_mask[idx, legal_moves] = 1 - return legal_move_mask * board_tensor + return legal_move_mask * out_tensor model = PolicyFlow.from_path(f"./assets/{model_name}") @@ -84,9 +84,10 @@ def legal_init_rel(board_list, board_tensor): _, board_tensor, labels = batch label_tensor = torch.tensor(labels) - def init_rel_fn(board_tensor): - rel = torch.zeros_like(board_tensor) - rel[:, label_tensor] = board_tensor[:, label_tensor] + def init_rel_fn(out_tensor): + rel = torch.zeros_like(out_tensor) + for i in range(rel.shape[0]): + rel[i, label_tensor[i]] = out_tensor[i, label_tensor[i]] return rel board_tensor.requires_grad = True @@ -203,9 +204,12 @@ def init_rel_fn(board_tensor): ) label_tensor = torch.tensor([label]) - def init_rel_fn(board_tensor): - rel = torch.zeros_like(board_tensor) - rel[:, label_tensor] = board_tensor[:, label_tensor] + def init_rel_fn(out_tensor): + rel = torch.zeros_like(out_tensor) + for i in range(rel.shape[0]): + rel[i, label_tensor[i]] = out_tensor[ + i, label_tensor[i] + ] return rel move = move_utils.decode_move( diff --git a/scripts/pixel_flipping.py b/scripts/pixel_flipping.py new file mode 100644 index 0000000..fb76c88 --- /dev/null +++ b/scripts/pixel_flipping.py @@ -0,0 +1,189 @@ +"""Script to perform pixel flipping. 
+ +Run with: +```bash +poetry run python -m scripts.pixel_flipping +``` +""" + +import matplotlib.pyplot as plt +import numpy as np +import torch +from crp.attribution import CondAttribution + +from lczerolens.game import PolicyFlow +from lczerolens.xai import ConceptDataset, LrpLens +from lczerolens.xai.hook import HookConfig, ModifyHook + +####################################### +# HYPERPARAMETERS +####################################### +debug = False +model_name = "64x6-2018_0627_1913_08_161.onnx" +dataset_name = "TCEC_game_collection_random_boards_bestlegal.jsonl" +n_samples = 10 +n_steps = 100 +viz_name = "pixel_flipping_tcec_bestlegal" +####################################### + + +model = PolicyFlow.from_path(f"./assets/{model_name}") +concept_dataset = ConceptDataset(f"./assets/{dataset_name}", first_n=n_samples) +print(f"[INFO] Board dataset len: {len(concept_dataset)}") + + +layer_names = ["model.inputconv", "model.block0/conv2/relu"] +print(layer_names) +dataloader = torch.utils.data.DataLoader( + concept_dataset, + batch_size=n_samples, + shuffle=False, + collate_fn=ConceptDataset.collate_fn_tensor, +) +indices, board_tensor, labels = next(iter(dataloader)) +rule_names = ["default", "no_onnx"] + +morf_logit_dict = { + rule_name: {layer_name: [] for layer_name in layer_names} + for rule_name in rule_names +} +lerf_logit_dict = { + rule_name: {layer_name: [] for layer_name in layer_names} + for rule_name in rule_names +} + + +def mask_fn(output, modify_data): + if modify_data is None: + return output + else: + return output * modify_data + + +for logit_dict, morf in zip([morf_logit_dict, lerf_logit_dict], [True, False]): + for rule_name in rule_names: + if rule_name == "default": + composite = LrpLens.make_default_composite() + replace_onnx2torch = True + elif rule_name == "no_onnx": + composite = LrpLens.make_default_composite() + replace_onnx2torch = False + else: + raise ValueError(f"Unknown rule: {rule_name}") + for layer_name in layer_names: + 
hook_config = HookConfig( + module_exp=rf"^{layer_name}$", + data={layer_name: None}, + data_fn=mask_fn, + ) + hook = ModifyHook(hook_config) + hook.register(model) + for i in range(n_steps): + label_tensor = torch.tensor(labels) + + def init_rel_fn(out_tensor): + rel = torch.zeros_like(out_tensor) + for i in range(rel.shape[0]): + rel[i, label_tensor[i]] = out_tensor[ + i, label_tensor[i] + ] + return rel + + board_tensor.requires_grad = True + with LrpLens.context( + model, + composite=composite, + replace_onnx2torch=replace_onnx2torch, + ) as modified_model: + attribution = CondAttribution(modified_model) + attr = attribution( + board_tensor, + [{"y": None}], + composite, + record_layer=layer_names, + init_rel=init_rel_fn, + ) + latent_rel = attr.relevances[layer_name] + if morf: + to_flip = latent_rel.view( + board_tensor.shape[0], -1 + ).argmax(dim=1) + else: + to_flip = latent_rel.view( + board_tensor.shape[0], -1 + ).argmin(dim=1) + if hook.config.data[layer_name] is None: + mask_flat = torch.ones_like(latent_rel).view( + board_tensor.shape[0], -1 + ) + for i in range(mask_flat.shape[0]): + mask_flat[i, to_flip[i]] = 0 + hook.config.data[layer_name] = mask_flat.view_as( + latent_rel + ) + else: + old_mask_flat = hook.config.data[layer_name].view( + board_tensor.shape[0], -1 + ) + for i in range(old_mask_flat.shape[0]): + old_mask_flat[i, to_flip[i]] = 0 + hook.config.data[layer_name] = old_mask_flat.view_as( + latent_rel + ) + if debug: + print(f"[INFO] Most relevant pixels: {to_flip}") + logit_dict[rule_name][layer_name].append( + attr.prediction.gather( + 1, label_tensor.view(-1, 1) + ).detach() + ) + print(f"[INFO] Layer: {layer_name} done") + if debug: + print( + "[INFO] Logits: " + f"{torch.cat(logit_dict[rule_name][layer_name], dim=1)}" + ) + hook.remove() + hook.clear() + print(f"[INFO] Rule: {rule_name} done") + +fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True) +for rule_name in rule_names: + for layer_name in layer_names: + 
morf_logits = torch.cat(morf_logit_dict[rule_name][layer_name], dim=1) + lerf_logits = torch.cat(lerf_logit_dict[rule_name][layer_name], dim=1) + diff = lerf_logits - morf_logits + means = diff.mean(dim=0) + stds = diff.std(dim=0) + ax[0].errorbar( + np.arange(means.shape[0]), + means, + yerr=stds, + label=f"{rule_name} {layer_name}", + ) + means = morf_logits.mean(dim=0) + stds = morf_logits.std(dim=0) + ax[1].errorbar( + np.arange(means.shape[0]), + means, + yerr=stds, + label=f"{rule_name} {layer_name}", + ) + means = lerf_logits.mean(dim=0) + stds = lerf_logits.std(dim=0) + ax[2].errorbar( + np.arange(means.shape[0]), + means, + yerr=stds, + label=f"{rule_name} {layer_name}", + ) +plt.sca(ax[0]) +plt.ylabel(f"Mean logit (n={n_samples})") +plt.legend() +plt.sca(ax[1]) +plt.xlabel("Pixels flipped", loc="center") +plt.legend() +plt.sca(ax[2]) +plt.legend() + +plt.savefig(f"./scripts/results/{viz_name}.png") diff --git a/scripts/register_wandb_datasets.py b/scripts/register_wandb_datasets.py index 0945fe2..50bca88 100644 --- a/scripts/register_wandb_datasets.py +++ b/scripts/register_wandb_datasets.py @@ -63,23 +63,23 @@ ) if ARGS.log_datasets: - wandb.login() # type: ignore - with wandb.init( # type: ignore + wandb.login() + with wandb.init( project="lczerolens-saes", job_type="make-datasets" ) as run: - artifact = wandb.Artifact("tcec_train", type="dataset") # type: ignore + artifact = wandb.Artifact("tcec_train", type="dataset") artifact.add_file( f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_train.jsonl" ) run.log_artifact(artifact) - artifact = wandb.Artifact("tcec_val", type="dataset") # type: ignore + artifact = wandb.Artifact("tcec_val", type="dataset") artifact.add_file( f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_val.jsonl" ) run.log_artifact(artifact) - artifact = wandb.Artifact("tcec_test", type="dataset") # type: ignore + artifact = wandb.Artifact("tcec_test", type="dataset") artifact.add_file( 
f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_test.jsonl" diff --git a/scripts/register_wandb_models.py b/scripts/register_wandb_models.py index 1debb0f..0e3a9ea 100644 --- a/scripts/register_wandb_models.py +++ b/scripts/register_wandb_models.py @@ -25,11 +25,9 @@ ARGS = parser.parse_args() if ARGS.log_models: - wandb.login() # type: ignore - with wandb.init( # type: ignore - project="lczerolens-saes", job_type="make-models" - ) as run: + wandb.login() + with wandb.init(project="lczerolens-saes", job_type="make-models") as run: for model_name, model_path in models.items(): - artifact = wandb.Artifact(model_name, type="model") # type: ignore + artifact = wandb.Artifact(model_name, type="model") artifact.add_file(f"./assets/{model_path}") run.log_artifact(artifact) diff --git a/src/lczerolens/xai/concept.py b/src/lczerolens/xai/concept.py index 8c6cbdc..1add36b 100644 --- a/src/lczerolens/xai/concept.py +++ b/src/lczerolens/xai/concept.py @@ -177,6 +177,7 @@ def __init__( game_ids: Optional[List[str]] = None, concept: Optional[Concept] = None, labels: Optional[List[Any]] = None, + first_n: Optional[int] = None, ): if file_name is None: super().__init__(file_name, boards, game_ids) @@ -193,6 +194,8 @@ def __init__( self.boards.append(board) self.game_ids.append(obj["gameid"]) self.labels.append(obj["label"]) + if first_n is not None and len(self.boards) >= first_n: + break self._concept = concept if concept is not None else NullConcept() if labels is not None: self.labels = labels diff --git a/src/lczerolens/xai/hook.py b/src/lczerolens/xai/hook.py index 68b7aa2..b6a77bf 100644 --- a/src/lczerolens/xai/hook.py +++ b/src/lczerolens/xai/hook.py @@ -139,13 +139,13 @@ class MeasureHook(Hook): """ def forward_factory(self, name: str): - if self.config.data is not None: - measure_data = self.config.data[name] - else: - measure_data = None if self.config.hook_mode is HookMode.INPUT: def hook(module, input, output): + if self.config.data is not None: + 
measure_data = self.config.data[name] + else: + measure_data = None self.storage[name] = self.config.data_fn( input.detach(), measure_data=measure_data ) @@ -153,6 +153,10 @@ def hook(module, input, output): elif self.config.hook_mode is HookMode.OUTPUT: def hook(module, input, output): + if self.config.data is not None: + measure_data = self.config.data[name] + else: + measure_data = None self.storage[name] = self.config.data_fn( output.detach(), measure_data=measure_data ) @@ -174,19 +178,18 @@ class ModifyHook(Hook): """ def forward_factory(self, name: str): - if self.config.data is not None: - modify_data = self.config.data[name] - else: - modify_data = None if self.config.hook_mode is HookMode.INPUT: - - def hook(module, input, output): - input = self.config.data_fn(input, modify_data=modify_data) - return input + raise NotImplementedError( + "Input hook not implemented for ModifyHook" + ) elif self.config.hook_mode is HookMode.OUTPUT: def hook(module, input, output): + if self.config.data is not None: + modify_data = self.config.data[name] + else: + modify_data = None output = self.config.data_fn(output, modify_data=modify_data) return output