diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9dd6da1..9b4cb97 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,4 @@ repos: -- repo: https://github.com/psf/black - rev: 24.2.0 - hooks: - - id: black - args: ["--config", "pyproject.toml"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: @@ -19,20 +14,9 @@ repos: rev: 1.7.0 hooks: - id: poetry-check -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 - hooks: - - id: mypy - additional_dependencies: ['types-requests', 'types-toml'] - exclude: scripts -- repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - args: ['--ignore=E203,W503', '--per-file-ignores=__init__.py:F401'] -- repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - args: ["--settings-path", "pyproject.toml"] - name: isort (python) +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.2 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format diff --git a/LICENSE b/LICENSE index 88c5637..dba5260 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Yoann Poupart +Copyright (c) 2024 Yoann Poupart Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index 0368ae6..400efd8 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -# CI .PHONY: checks checks: poetry run pre-commit run --all-files @@ -19,28 +18,6 @@ tests: docs: cd docs && poetry run make html -# API .PHONY: demo demo: poetry run python -m demo.main - -# Docker -.PHONY: docker-build -docker-build: - docker compose -f docker/docker-compose.yml build - -.PHONY: docker-start -docker-start: - docker compose -f docker/docker-compose.yml up - -.PHONY: docker-start-bg -docker-start-bg: - docker compose -f docker/docker-compose.yml up -d --build - -.PHONY: docker-stop -docker-stop: - docker compose -f docker/docker-compose.yml down - -.PHONY: docker-tty -docker-tty: - docker compose -f docker/docker-compose.yml exec fastapi bash diff --git a/apptainer/.gitignore b/apptainer/.gitignore deleted file mode 100644 index 1b26f6e..0000000 --- a/apptainer/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -* -!.gitignore -!base.def -!script.def -!make-datasets.sh diff --git a/apptainer/base.def b/apptainer/base.def deleted file mode 100644 index 288ee28..0000000 --- a/apptainer/base.def +++ /dev/null @@ -1,19 +0,0 @@ -Bootstrap: docker -From: python:3.9.18 - -%files - ./assets /opt/assets - ./src /opt/src - ./pyproject.toml /opt/pyproject.toml - ./poetry.lock /opt/poetry.lock - ./README.md /opt/README.md - -%environment - export "PATH=/opt/.venv/bin:$PATH" - -%post - python -m pip install poetry - - cd /opt - python -m poetry config virtualenvs.in-project true - python -m poetry install diff --git a/apptainer/make-datasets.sh b/apptainer/make-datasets.sh deleted file mode 100644 index 64a2697..0000000 --- a/apptainer/make-datasets.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -#SBATCH --mail-type=ALL -#SBATCH --mail-user= -#SBATCH --job-name=apptainer -#SBATCH --output=%j_%x.out -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=4 -#SBATCH --gpus=1 -#SBATCH --mem=32G -#SBATCH --time=1:00:00 - -##################################################################################### - -# This included file contains the definition for $LOCAL_JOB_DIR to be used locally on the node. 
-source "/etc/slurm/local_job_dir.sh" - -# Launch the apptainer image with --nv for nvidia support. Two bind mounts are used: -# - One for the ImageNet dataset and -# - One for the results (e.g. checkpoint data that you may store in $LOCAL_JOB_DIR on the node -timeout 24h apptainer exec --nv --bind ${LOCAL_JOB_DIR}:/opt/output \ - ./apptainer/script.sif python -m scripts.make_datasets \ - --output-root /opt/output - -# This command copies all results generated in $LOCAL_JOB_DIR back to the submit folder regarding the job id. -cp -r ${LOCAL_JOB_DIR} ${SLURM_SUBMIT_DIR}/${SLURM_JOB_ID} - -echo "$PWD/${SLURM_JOB_ID}_stats.out" > $LOCAL_JOB_DIR/stats_file_loc_cfg diff --git a/apptainer/script.def b/apptainer/script.def deleted file mode 100644 index 5e5fa8e..0000000 --- a/apptainer/script.def +++ /dev/null @@ -1,9 +0,0 @@ -Bootstrap: localimage -From: ./apptainer/base.sif - -%files - ./scripts/*.py /opt/scripts/ - -%runscript - cd /opt/ - echo "Running script" diff --git a/demo/attention_interface.py b/demo/attention_interface.py index d324fc8..2107393 100644 --- a/demo/attention_interface.py +++ b/demo/attention_interface.py @@ -98,7 +98,6 @@ def make_plot( state_boards, state_cache, ): - if state_cache == []: gr.Warning("No cache available.") return None, None, None @@ -107,8 +106,7 @@ def make_plot( num_attention_layers = len(state_cache[state_board_index]) if attention_layer > num_attention_layers: gr.Warning( - f"Attention layer {attention_layer} does not exist, " - f"using layer {num_attention_layers} instead." + f"Attention layer {attention_layer} does not exist, " f"using layer {num_attention_layers} instead." ) attention_layer = num_attention_layers @@ -120,8 +118,7 @@ def make_plot( return None, None, None if attention_head > attention_tensor.shape[1]: gr.Warning( - f"Attention head {attention_head} does not exist, " - f"using head {attention_tensor.shape[1]+1} instead." + f"Attention head {attention_head} does not exist, " f"using head {attention_tensor.shape[1]+1} instead." 
) attention_head = attention_tensor.shape[1] try: @@ -136,9 +133,7 @@ def make_plot( heatmap = attention_tensor[0, attention_head - 1, square_index] if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) - svg_board, fig = visualisation.render_heatmap( - board, heatmap, square=square - ) + svg_board, fig = visualisation.render_heatmap(board, heatmap, square=square) with open(f"{constants.FIGURE_DIRECTORY}/attention.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/attention.svg", board.fen(), fig @@ -206,9 +201,7 @@ def next_board( ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, @@ -228,10 +221,7 @@ def next_board( label="Action sequence", lines=1, max_lines=1, - value=( - "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " - "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6" - ), + value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) compute_cache_button = gr.Button("Compute cache") @@ -295,9 +285,7 @@ def next_board( inputs=base_inputs, outputs=outputs + [state_board_index], ) - next_board_button.click( - next_board, inputs=base_inputs, outputs=outputs + [state_board_index] - ) + next_board_button.click(next_board, inputs=base_inputs, outputs=outputs + [state_board_index]) attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs) attention_head.change(make_plot, inputs=base_inputs, outputs=outputs) diff --git a/demo/backend_interface.py b/demo/backend_interface.py index b1e1bde..5bb5893 100644 --- a/demo/backend_interface.py +++ b/demo/backend_interface.py @@ -9,8 +9,8 @@ from lczero.backends import Backend, GameState, Weights from demo import constants, utils, visualisation -from lczerolens import move_utils -from lczerolens.utils import lczero as lczero_utils +from lczerolens import move_encodings +from lczerolens.model import lczero as lczero_utils from lczerolens.xai import PolicyLens @@ -74,9 +74,7 @@ def make_policy_plot( only_legal=only_legal, illegal_value=0, ) - pickup_agg, dropoff_agg = PolicyLens.aggregate_policy( - policy, int(aggregate_topk) - ) + pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(policy, int(aggregate_topk)) if view == "from": if board.turn == chess.WHITE: @@ -90,9 +88,7 @@ def make_policy_plot( heatmap = dropoff_agg.view(8, 8).flip(0).view(64) us_them = (board.turn, not board.turn) if only_legal: - legal_moves = [ - move_utils.encode_move(move, us_them) for move in board.legal_moves - ] + legal_moves = [move_encodings.encode_move(move, us_them) for move in board.legal_moves] filtered_policy = torch.zeros(1858) filtered_policy[legal_moves] = policy[legal_moves] if (filtered_policy < 0).any(): @@ -102,11 +98,9 @@ def make_policy_plot( topk_moves = torch.topk(policy, render_bestk) arrows = [] for move_index in topk_moves.indices: - move = move_utils.decode_move(move_index, us_them) + move = move_encodings.decode_move(move_index, us_them) arrows.append((move.from_square, move.to_square)) - svg_board, fig = visualisation.render_heatmap( - board, heatmap, arrows=arrows - ) + svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows) with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f: f.write(svg_board) raw_policy, _ = lczero_utils.prediction_from_backend( @@ -118,7 +112,7 @@ def make_policy_plot( ) fig_dist = visualisation.render_policy_distribution( raw_policy, - 
[move_utils.encode_move(move, us_them) for move in board.legal_moves], + [move_encodings.encode_move(move, us_them) for move in board.legal_moves], ) return ( f"{constants.FIGURE_DIRECTORY}/policy.svg", @@ -140,9 +134,7 @@ def make_policy_plot( ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, @@ -161,10 +153,7 @@ def make_policy_plot( label="Action sequence", lines=1, max_lines=1, - value=( - "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " - "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6" - ), + value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) with gr.Group(): with gr.Row(): @@ -194,15 +183,11 @@ def make_policy_plot( value=5, scale=3, ) - only_legal = gr.Checkbox( - label="Only legal", value=True, scale=1 - ) + only_legal = gr.Checkbox(label="Only legal", value=True, scale=1) policy_button = gr.Button("Plot policy") colorbar = gr.Plot(label="Colorbar") - game_info = gr.Textbox( - label="Game info", lines=1, max_lines=1, value="" - ) + game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="") with gr.Column(): image = gr.Image(label="Board") density_plot = gr.Plot(label="Density") @@ -219,6 +204,4 @@ def make_policy_plot( only_legal, ] policy_outputs = [image, colorbar, game_info, density_plot] - policy_button.click( - make_policy_plot, inputs=policy_inputs, outputs=policy_outputs - ) + policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs) diff --git a/demo/convert_interface.py b/demo/convert_interface.py index dc41f3a..fc073e1 100644 --- a/demo/convert_interface.py +++ b/demo/convert_interface.py @@ -8,7 +8,7 @@ import gradio as gr from demo import constants, utils -from lczerolens.utils import lczero as lczero_utils +from lczerolens.model import lczero as lczero_utils def list_models(): @@ -147,9 +147,7 @@ def get_model_path( ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) conversion_status = gr.Textbox( label="Conversion status", lines=1, diff --git a/demo/crp_interface.py b/demo/crp_interface.py index f460c95..06266f3 100644 --- a/demo/crp_interface.py +++ b/demo/crp_interface.py @@ -99,9 +99,7 @@ def make_plot( heatmap = relevance_tensor[plane_index - 1].view(64) if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) - svg_board, fig = visualisation.render_heatmap( - board, heatmap, vmin=vmin, vmax=vmax - ) + svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax) with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig @@ -125,20 +123,14 @@ def make_history_plot( relevance_tensor = relevance_tensor / a_max vmin = -1 vmax = 1 - heatmap = ( - relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1] - .sum(dim=0) - .view(64) - ) + heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64) if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) if board_index - history_index + 1 < 0: history_board = chess.Board(fen=None) else: history_board = boards[board_index - history_index + 1] - svg_board, fig = visualisation.render_heatmap( - history_board, heatmap, vmin=vmin, 
vmax=vmax - ) + svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax) with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig @@ -194,9 +186,7 @@ def next_board( ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, @@ -216,10 +206,7 @@ def next_board( label="Action sequence", lines=1, max_lines=1, - value=( - "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " - "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6" - ), + value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) compute_cache_button = gr.Button("Compute heatmaps") @@ -277,9 +264,7 @@ def next_board( outputs=outputs, ) - previous_board_button.click( - previous_board, inputs=base_inputs, outputs=outputs - ) + previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs) next_board_button.click(next_board, inputs=base_inputs, outputs=outputs) plane_index.change( diff --git a/demo/encoding_interface.py b/demo/encoding_interface.py index a3c8691..07be8f5 100644 --- a/demo/encoding_interface.py +++ b/demo/encoding_interface.py @@ -6,7 +6,7 @@ import gradio as gr from demo import constants, visualisation -from lczerolens import board_utils +from lczerolens import board_encodings def make_encoding_plot( @@ -27,13 +27,11 @@ def make_encoding_plot( except ValueError: gr.Warning("Invalid action sequence, using starting position.") board = chess.Board() - board_tensor = board_utils.board_to_input_tensor(board) + board_tensor = board_encodings.board_to_input_tensor(board) heatmap = board_tensor[plane_index] if color_flip and board.turn == chess.BLACK: heatmap = heatmap.flip(0) - svg_board, fig = visualisation.render_heatmap( - board, heatmap.view(64), vmin=0.0, vmax=1.0 - ) + svg_board, fig = visualisation.render_heatmap(board, heatmap.view(64), vmin=0.0, vmax=1.0) with open(f"{constants.FIGURE_DIRECTORY}/encoding.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/encoding.svg", fig @@ -52,10 +50,7 @@ def make_encoding_plot( label="Action sequence", lines=1, max_lines=1, - value=( - "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " - "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6" - ), + value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) with gr.Group(): with gr.Row(): @@ -67,9 +62,7 @@ def make_encoding_plot( value=0, scale=3, ) - color_flip = gr.Checkbox( - label="Color flip", value=True, scale=1 - ) + color_flip = gr.Checkbox(label="Color flip", value=True, scale=1) colorbar = gr.Plot(label="Colorbar") with gr.Column(): @@ -82,18 +75,8 @@ def make_encoding_plot( color_flip, ] policy_outputs = [image, colorbar] - board_fen.submit( - make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs - ) - action_seq.submit( - make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs - ) - plane_index.change( - make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs - ) - color_flip.change( - make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs - ) - interface.load( - make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs - ) + board_fen.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs) + action_seq.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs) + plane_index.change(make_encoding_plot, inputs=policy_inputs, 
outputs=policy_outputs) + color_flip.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs) + interface.load(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs) diff --git a/demo/lrp_interface.py b/demo/lrp_interface.py index e778df3..7c4acbd 100644 --- a/demo/lrp_interface.py +++ b/demo/lrp_interface.py @@ -99,9 +99,7 @@ def make_plot( heatmap = relevance_tensor[plane_index - 1].view(64) if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) - svg_board, fig = visualisation.render_heatmap( - board, heatmap, vmin=vmin, vmax=vmax - ) + svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax) with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig @@ -125,20 +123,14 @@ def make_history_plot( relevance_tensor = relevance_tensor / a_max vmin = -1 vmax = 1 - heatmap = ( - relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1] - .sum(dim=0) - .view(64) - ) + heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64) if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) if board_index - history_index + 1 < 0: history_board = chess.Board(fen=None) else: history_board = boards[board_index - history_index + 1] - svg_board, fig = visualisation.render_heatmap( - history_board, heatmap, vmin=vmin, vmax=vmax - ) + svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax) with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig @@ -194,9 +186,7 @@ def next_board( ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, @@ -216,10 +206,7 @@ def next_board( label="Action sequence", lines=1, max_lines=1, - value=( - "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " - "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6" - ), + value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) compute_cache_button = gr.Button("Compute heatmaps") @@ -277,9 +264,7 @@ def next_board( outputs=outputs, ) - previous_board_button.click( - previous_board, inputs=base_inputs, outputs=outputs - ) + previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs) next_board_button.click(next_board, inputs=base_inputs, outputs=outputs) plane_index.change( diff --git a/demo/policy_interface.py b/demo/policy_interface.py index 880c572..998d342 100644 --- a/demo/policy_interface.py +++ b/demo/policy_interface.py @@ -8,7 +8,7 @@ import torch from demo import constants, utils, visualisation -from lczerolens import move_utils +from lczerolens import move_encodings from lczerolens.xai import PolicyLens current_board = None @@ -72,10 +72,7 @@ def compute_policy( policy = torch.softmax(output["policy"][0], dim=-1) filtered_policy = torch.full((1858,), 0.0) - legal_moves = [ - move_utils.encode_move(move, (board.turn, not board.turn)) - for move in board.legal_moves - ] + legal_moves = [move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves] filtered_policy[legal_moves] = policy[legal_moves] policy = filtered_policy @@ -100,9 +97,7 @@ def make_plot( gr.Warning("Please compute a policy first.") return (None, None, "", None) - pickup_agg, 
dropoff_agg = PolicyLens.aggregate_policy( - current_policy, int(aggregate_topk) - ) + pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(current_policy, int(aggregate_topk)) if view == "from": if current_board.turn == chess.WHITE: @@ -116,21 +111,14 @@ def make_plot( heatmap = dropoff_agg.view(8, 8).flip(0).view(64) us_them = (current_board.turn, not current_board.turn) topk_moves = torch.topk(current_policy, 50) - move = move_utils.decode_move( - topk_moves.indices[move_to_play - 1], us_them - ) + move = move_encodings.decode_move(topk_moves.indices[move_to_play - 1], us_them) arrows = [(move.from_square, move.to_square)] - svg_board, fig = visualisation.render_heatmap( - current_board, heatmap, arrows=arrows - ) + svg_board, fig = visualisation.render_heatmap(current_board, heatmap, arrows=arrows) with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f: f.write(svg_board) fig_dist = visualisation.render_policy_distribution( current_raw_policy, - [ - move_utils.encode_move(move, us_them) - for move in current_board.legal_moves - ], + [move_encodings.encode_move(move, us_them) for move in current_board.legal_moves], ) return ( f"{constants.FIGURE_DIRECTORY}/policy.svg", @@ -171,7 +159,7 @@ def play_move( global current_board global current_policy - move = move_utils.decode_move( + move = move_encodings.decode_move( current_policy.topk(50).indices[move_to_play - 1], (current_board.turn, not current_board.turn), ) @@ -205,9 +193,7 @@ def play_move( ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, None, @@ -225,10 +211,7 @@ def play_move( action_seq = gr.Textbox( label="Action sequence", lines=1, - value=( - "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " - "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6" - ), + value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) with gr.Group(): with gr.Row(): @@ -259,9 +242,7 @@ def play_move( policy_button = gr.Button("Compute policy") colorbar = gr.Plot(label="Colorbar") - game_info = gr.Textbox( - label="Game info", lines=1, max_lines=1, value="" - ) + game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="") with gr.Column(): image = gr.Image(label="Board") density_plot = gr.Plot(label="Density") @@ -275,24 +256,16 @@ def play_move( move_to_play, ] policy_outputs = [image, colorbar, game_info, density_plot] - policy_button.click( - make_policy_plot, inputs=policy_inputs, outputs=policy_outputs - ) - board_fen.submit( - make_policy_plot, inputs=policy_inputs, outputs=policy_outputs - ) - action_seq.submit( - make_policy_plot, inputs=policy_inputs, outputs=policy_outputs - ) + policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs) + board_fen.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs) + action_seq.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs) fast_inputs = [ view, aggregate_topk, move_to_play, ] - aggregate_topk.change( - make_plot, inputs=fast_inputs, outputs=policy_outputs - ) + aggregate_topk.change(make_plot, inputs=fast_inputs, outputs=policy_outputs) view.change(make_plot, inputs=fast_inputs, outputs=policy_outputs) move_to_play.change(make_plot, inputs=fast_inputs, outputs=policy_outputs) diff --git a/demo/statistics_interface.py b/demo/statistics_interface.py index 2d8cb95..1a40401 100644 --- a/demo/statistics_interface.py +++ 
b/demo/statistics_interface.py @@ -59,9 +59,7 @@ def make_policy_plot(): ) return None else: - return visualisation.render_policy_statistics( - current_policy_statistics - ) + return visualisation.render_policy_statistics(current_policy_statistics) def compute_lrp_statistics( @@ -89,9 +87,7 @@ def make_lrp_plot(): ) return None, None, None else: - return visualisation.render_relevance_proportion( - current_lrp_statistics - ) + return visualisation.render_relevance_proportion(current_lrp_statistics) def compute_probing_statistics( @@ -106,12 +102,8 @@ def compute_probing_statistics( "Please select a model.", ) return None - wrapper, lens = utils.get_wrapper_lens_from_state( - model_name, "probing", concept=check_concept - ) - current_probing_statistics = lens.compute_statistics( - unique_check_dataset, wrapper, 10 - ) + wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept) + current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10) return make_probing_plot() @@ -124,9 +116,7 @@ def make_probing_plot(): ) return None else: - return visualisation.render_probing_statistics( - current_probing_statistics - ) + return visualisation.render_probing_statistics(current_probing_statistics) with gr.Blocks() as interface: @@ -141,9 +131,7 @@ def make_probing_plot(): ) with gr.Column(scale=1): with gr.Row(): - model_name = gr.Textbox( - label="Selected model", lines=1, interactive=False, scale=7 - ) + model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, None, @@ -153,9 +141,7 @@ def make_probing_plot(): with gr.Row(): with gr.Column(): policy_plot = gr.Plot(label="Policy statistics") - policy_compute_button = gr.Button( - value="Compute policy statistics" - ) + policy_compute_button = gr.Button(value="Compute policy statistics") policy_plot_button = gr.Button(value="Plot policy statistics") policy_compute_button.click( diff --git a/demo/utils.py b/demo/utils.py index 931a93d..b4cb93e 100644 --- a/demo/utils.py +++ b/demo/utils.py @@ -8,7 +8,7 @@ from demo import constants, state from lczerolens import Lens, ModelWrapper -from lczerolens.utils import lczero as lczero_utils +from lczerolens.model import lczero as lczero_utils def get_models_info(onnx=True, leela=True): @@ -68,9 +68,7 @@ def save_model(tmp_file_path): popen.wait() if popen.returncode != 0: raise RuntimeError - file_desc = ( - popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip() - ) + file_desc = popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip() rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc) type_match = re.search(r"\:\s(?P<type>[a-zA-Z]+)", file_desc) if rename_match is None or type_match is None: @@ -99,33 +97,25 @@ def get_wrapper_from_state(model_name): if model_name in state.wrappers: return state.wrappers[model_name] else: - wrapper = ModelWrapper.from_path( - f"{constants.MODEL_DIRECTORY}/{model_name}" - ) + wrapper = ModelWrapper.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}") state.wrappers[model_name] = wrapper return wrapper -def get_wrapper_lens_from_state( - model_name, lens_type, lens_name="lens", **kwargs -): +def get_wrapper_lens_from_state(model_name, lens_type, lens_name="lens", **kwargs): """ Get the model wrapper and lens from the state.
""" if model_name in state.wrappers: wrapper = state.wrappers[model_name] else: - wrapper = ModelWrapper.from_path( - f"{constants.MODEL_DIRECTORY}/{model_name}" - ) + wrapper = ModelWrapper.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}") state.wrappers[model_name] = wrapper if lens_name in state.lenses[lens_type]: lens = state.lenses[lens_type][lens_name] else: lens = Lens.from_name(lens_type, **kwargs) if not lens.is_compatible(wrapper): - raise ValueError( - f"Lens of type {lens_type} not compatible with model." - ) + raise ValueError(f"Lens of type {lens_type} not compatible with model.") state.lenses[lens_type][lens_name] = lens return wrapper, lens diff --git a/demo/visualisation.py b/demo/visualisation.py index ae18fc8..a590a09 100644 --- a/demo/visualisation.py +++ b/demo/visualisation.py @@ -44,9 +44,7 @@ def render_heatmap( for square_index in range(64): color = COLOR_MAP(norm(heatmap[square_index])) color = (*color[:3], ALPHA) - color_dict[square_index] = matplotlib.colors.to_hex( - color, keep_alpha=True - ) + color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True) fig = plt.figure(figsize=(6, 0.6)) ax = plt.gca() ax.axis("off") @@ -88,15 +86,15 @@ def render_architecture(model, name: str = "model", directory: str = ""): value = torch.zeros(outcome_probs.shape[0], 1) else: policy, outcome_probs, value = out - torchviz.make_dot( - policy, params=dict(list(model.named_parameters())) - ).render(f"{directory}/{name}_policy", format="svg") - torchviz.make_dot( - outcome_probs, params=dict(list(model.named_parameters())) - ).render(f"{directory}/{name}_outcome_probs", format="svg") - torchviz.make_dot( - value, params=dict(list(model.named_parameters())) - ).render(f"{directory}/{name}_value", format="svg") + torchviz.make_dot(policy, params=dict(list(model.named_parameters()))).render( + f"{directory}/{name}_policy", format="svg" + ) + torchviz.make_dot(outcome_probs, params=dict(list(model.named_parameters()))).render( + f"{directory}/{name}_outcome_probs", format="svg" + ) + torchviz.make_dot(value, params=dict(list(model.named_parameters()))).render( + f"{directory}/{name}_value", format="svg" + ) def render_policy_distribution( @@ -107,9 +105,7 @@ def render_policy_distribution( """ Render the policy distribution histogram. 
""" - legal_mask = torch.Tensor( - [move in legal_moves for move in range(1858)] - ).bool() + legal_mask = torch.Tensor([move in legal_moves for move in range(1858)]).bool() fig = plt.figure(figsize=(6, 6)) ax = plt.gca() _, bins = np.histogram(policy, bins=n_bins) @@ -143,22 +139,10 @@ def render_policy_statistics( fig = plt.figure(figsize=(6, 6)) ax = plt.gca() move_indices = list(statistics["mean_legal_logits"].keys()) - legal_means_avg = [ - np.mean(statistics["mean_legal_logits"][move_idx]) - for move_idx in move_indices - ] - illegal_means_avg = [ - np.mean(statistics["mean_illegal_logits"][move_idx]) - for move_idx in move_indices - ] - legal_means_std = [ - np.std(statistics["mean_legal_logits"][move_idx]) - for move_idx in move_indices - ] - illegal_means_std = [ - np.std(statistics["mean_illegal_logits"][move_idx]) - for move_idx in move_indices - ] + legal_means_avg = [np.mean(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices] + illegal_means_avg = [np.mean(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices] + legal_means_std = [np.std(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices] + illegal_means_std = [np.std(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices] ax.errorbar( move_indices, legal_means_avg, @@ -187,25 +171,11 @@ def render_relevance_proportion(statistics, scaled=True): move_indices = list(statistics["planes_relevance_proportion"].keys()) for h in range(8): relevance_proportion_avg = [ - np.mean( - [ - rel[13 * h : 13 * (h + 1)].sum() - for rel in statistics["planes_relevance_proportion"][ - move_idx - ] - ] - ) + np.mean([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ - np.std( - [ - rel[13 * h : 13 * (h + 1)].sum() - for rel in statistics["planes_relevance_proportion"][ - move_idx - ] - ] - ) + np.std([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] ax.errorbar( @@ -217,21 +187,11 @@ def render_relevance_proportion(statistics, scaled=True): ) relevance_proportion_avg = [ - np.mean( - [ - rel[104:108].sum() - for rel in statistics["planes_relevance_proportion"][move_idx] - ] - ) + np.mean([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ - np.std( - [ - rel[104:108].sum() - for rel in statistics["planes_relevance_proportion"][move_idx] - ] - ) + np.std([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] ax.errorbar( @@ -242,21 +202,11 @@ def render_relevance_proportion(statistics, scaled=True): c=COLOR_MAP(norm(8 / 9)), ) relevance_proportion_avg = [ - np.mean( - [ - rel[108:].sum() - for rel in statistics["planes_relevance_proportion"][move_idx] - ] - ) + np.mean([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ - np.std( - [ - rel[108:].sum() - for rel in statistics["planes_relevance_proportion"][move_idx] - ] - ) + np.std([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]]) for move_idx in move_indices ] ax.errorbar( @@ -280,12 +230,10 @@ def render_relevance_proportion(statistics, scaled=True): move_indices = list(statistics[stat_key].keys()) for p in range(13): relevance_proportion_avg = [ - 
np.mean([rel[p].item() for rel in statistics[stat_key][move_idx]]) - for move_idx in move_indices + np.mean([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices ] relevance_proportion_std = [ - np.std([rel[p].item() for rel in statistics[stat_key][move_idx]]) - for move_idx in move_indices + np.std([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices ] ax.errorbar( move_indices, @@ -306,16 +254,9 @@ def render_relevance_proportion(statistics, scaled=True): stat_key = f"configuration_relevance_proportion_threatened_piece{p}" n_attackers = list(statistics[stat_key].keys()) relevance_proportion_avg = [ - np.mean( - statistics[ - f"configuration_relevance_proportion_threatened_piece{p}" - ][n] - ) - for n in n_attackers - ] - relevance_proportion_std = [ - np.std(statistics[stat_key][n]) for n in n_attackers + np.mean(statistics[f"configuration_relevance_proportion_threatened_piece{p}"][n]) for n in n_attackers ] + relevance_proportion_std = [np.std(statistics[stat_key][n]) for n in n_attackers] ax.errorbar( n_attackers, relevance_proportion_avg, diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index b237c3c..0000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,22 +0,0 @@ -# syntax=docker/dockerfile:1 -FROM python:3.9.18 - -WORKDIR /code - -RUN apt-get update && apt-get install -y \ - ocl-icd-opencl-dev \ - libopenblas-dev \ - zip - -COPY poetry.lock pyproject.toml /code/ -RUN pip install --upgrade pip -RUN pip install poetry==1.6.1 -RUN poetry config virtualenvs.create false -RUN poetry install --no-interaction --no-ansi --with demo - -RUN mkdir -p /service/demo -COPY demo /service/demo -EXPOSE 8000 -COPY docker/start.sh ./ -RUN chmod +x start.sh -ENTRYPOINT ["/code/start.sh"] diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml deleted file mode 100644 index 276d7ed..0000000 --- a/docker/docker-compose.yml +++ /dev/null @@ -1,10 +0,0 @@ -version: '3' -services: - demo: - build: - context: ../ - dockerfile: docker/Dockerfile - ports: - - "8002:8000" - stdin_open: true - tty: true diff --git a/docker/start.sh b/docker/start.sh deleted file mode 100644 index 9dbf543..0000000 --- a/docker/start.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -echo "----------- Launching Demo --------- " -cd /service/ -python -m demo.main diff --git a/pyproject.toml b/pyproject.toml index fc60688..d3a70fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,10 @@ -[tool.black] -line-length = 79 - -[tool.isort] -profile = "black" -line_length = 79 -src_paths = ["src", "tests", "scripts", "docs", "demo"] +[tool.ruff] +line-length = 119 +target-version = "py39" [tool.poetry] name = "lczerolens" -version = "0.1.3" +version = "0.2.0" description = "Interpretability for LeelaChessZero networks." 
readme = "README.md" license = "MIT" diff --git a/scripts/cluster_latent_relevances.py b/scripts/cluster_latent_relevances.py index bbeb524..5dc7e6c 100644 --- a/scripts/cluster_latent_relevances.py +++ b/scripts/cluster_latent_relevances.py @@ -35,9 +35,7 @@ dataset_name = "TCEC_game_collection_random_boards_bestlegal_knight.jsonl" only_config_rel = True best_legal = True -run_name = ( - f"bestres_tcec_bestlegal_knight_{'expbest' if best_legal else 'full'}" -) +run_name = f"bestres_tcec_bestlegal_knight_{'expbest' if best_legal else 'full'}" n_samples = 1000 conv_sum_dims = () cosine_sim = False @@ -54,10 +52,7 @@ def legal_init_rel(board_list, out_tensor): legal_move_mask = torch.zeros((len(board_list), 1858)) for idx, board in enumerate(board_list): - legal_moves = [ - move_utils.encode_move(move, (board.turn, not board.turn)) - for move in board.legal_moves - ] + legal_moves = [move_utils.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves] legal_move_mask[idx, legal_moves] = 1 return legal_move_mask * out_tensor @@ -140,9 +135,7 @@ def init_rel_fn(out_tensor): for layer_name, relevances in all_relevances.items(): relevances = relevances[:n_samples] if conv_sum_dims: - relevances = relevances.sum(dim=conv_sum_dims).view( - relevances.shape[0], -1 - ) + relevances = relevances.sum(dim=conv_sum_dims).view(relevances.shape[0], -1) else: relevances = relevances.view(relevances.shape[0], -1) @@ -160,10 +153,7 @@ def init_rel_fn(out_tensor): plt.title("Clustered Latent Relevances") plt.xlabel("Dimension 1") plt.ylabel("Dimension 2") - plt.savefig( - f"scripts/results/{run_name}/{viz_name}/" - f"{layer_name.replace('/','.')}_t-sne.png" - ) + plt.savefig(f"scripts/results/{run_name}/{viz_name}/" f"{layer_name.replace('/','.')}_t-sne.png") plt.close() ####################################### @@ -177,10 +167,7 @@ def init_rel_fn(out_tensor): cluster_center = kmeans.cluster_centers_[idx_cluster] if cosine_sim: dot_prod = relevances @ cluster_center.T - similarities = dot_prod / ( - np.linalg.norm(relevances, axis=1) - * np.linalg.norm(cluster_center) - ) + similarities = dot_prod / (np.linalg.norm(relevances, axis=1) * np.linalg.norm(cluster_center)) nearest_neighbors = np.argsort(similarities)[-8:] else: distances = np.linalg.norm(relevances - cluster_center, axis=1) @@ -199,22 +186,16 @@ def init_rel_fn(out_tensor): # compute heatmap for each nearest neighbor for idx_sample in nearest_neighbors: _, board, label = concept_dataset[idx_sample] - _, board_tensor, _ = ConceptDataset.collate_fn_tensor( - [concept_dataset[idx_sample]] - ) + _, board_tensor, _ = ConceptDataset.collate_fn_tensor([concept_dataset[idx_sample]]) label_tensor = torch.tensor([label]) def init_rel_fn(out_tensor): rel = torch.zeros_like(out_tensor) for i in range(rel.shape[0]): - rel[i, label_tensor[i]] = out_tensor[ - i, label_tensor[i] - ] + rel[i, label_tensor[i]] = out_tensor[i, label_tensor[i]] return rel - move = move_utils.decode_move( - label, (board.turn, not board.turn), board - ) + move = move_utils.decode_move(label, (board.turn, not board.turn), board) uci_move = move.uci() if viz_latent: @@ -245,22 +226,12 @@ def init_rel_fn(out_tensor): ) input_relevances = board_tensor.grad if not board.turn: - input_relevances = ( - input_relevances.view(112, 8, 8) - .flip(1) - .view(112, 64) - ) + input_relevances = input_relevances.view(112, 8, 8).flip(1).view(112, 64) input_relevances = input_relevances.view(112, 64) heatmap_str_list = [ - create_heatmap_string( - input_relevances.sum(dim=0), 
abs_max=True - ), - create_heatmap_string( - input_relevances[:13].sum(dim=0), abs_max=True - ), - create_heatmap_string( - input_relevances[104:].sum(dim=0), abs_max=True - ), + create_heatmap_string(input_relevances.sum(dim=0), abs_max=True), + create_heatmap_string(input_relevances[:13].sum(dim=0), abs_max=True), + create_heatmap_string(input_relevances[104:].sum(dim=0), abs_max=True), ] heatmap_caption_list = [ "Total relevance", @@ -271,9 +242,7 @@ def init_rel_fn(out_tensor): hist = input_relevances[13:104].abs().sum() meta = input_relevances[104:].abs().sum() total = (h0 + hist + meta) / 100 - add_caption = ( - f"{h0/total:.0f}%|{hist/total:.0f}%|{meta/total:.0f}%" - ) + add_caption = f"{h0/total:.0f}%|{hist/total:.0f}%|{meta/total:.0f}%" doc = add_plot( doc, @@ -287,8 +256,6 @@ def init_rel_fn(out_tensor): # Generate pdf doc.generate_pdf( - f"scripts/results/{run_name}" - f"/{viz_name}/{layer_name.replace('/','.')}" - f"_cluster_{idx_cluster}", + f"scripts/results/{run_name}" f"/{viz_name}/{layer_name.replace('/','.')}" f"_cluster_{idx_cluster}", clean_tex=True, ) diff --git a/scripts/create_figure.py b/scripts/create_figure.py index 00f6ab9..d546be0 100644 --- a/scripts/create_figure.py +++ b/scripts/create_figure.py @@ -1,5 +1,4 @@ -"""Nice plotting of chessboard and heatmap with arrows. -""" +"""Nice plotting of chessboard and heatmap with arrows.""" import chess from pylatex import Figure, NoEscape, SubFigure @@ -19,9 +18,7 @@ def add_plot( doc.append(NoEscape(r"\centering")) if caption is not None: fig.add_caption(caption) - verbatim = NoEscape( - r"\storechessboardstyle{8x8}{maxfield=h8,showmover=true}" - ) + verbatim = NoEscape(r"\storechessboardstyle{8x8}{maxfield=h8,showmover=true}") doc.append(verbatim) with doc.create( @@ -33,9 +30,7 @@ def add_plot( doc.append(NoEscape(r"\chessboard[style=8x8,")) if current_piece_pos is not None: markmove = current_piece_pos + "-" + next_move - markfields = ( - "{{" + current_piece_pos + "},{" + next_move + "}}" - ) + markfields = "{{" + current_piece_pos + "},{" + next_move + "}}" chessboard_fen = NoEscape( rf"setfen={label},showmover=true," rf"color=green,pgfstyle=straightmove,markmove={markmove}," @@ -43,19 +38,14 @@ def add_plot( ) else: chessboard_fen = NoEscape( - rf"\chessboard[style=8x8,setfen={label}," - "showmover=true,pgfstyle=straightmove,color=green,]" + rf"\chessboard[style=8x8,setfen={label}," "showmover=true,pgfstyle=straightmove,color=green,]" ) doc.append(chessboard_fen) for i, heatmap_str in enumerate(heatmap_str_list): doc.append(NoEscape(r"\hfill")) - with doc.create( - SubFigure(width=NoEscape(r"0.45\textwidth")) - ) as subfig: + with doc.create(SubFigure(width=NoEscape(r"0.45\textwidth"))) as subfig: subfig.add_caption(heatmap_caption_list[i]) - heatmap_begin = NoEscape( - r"\chessboard[style=8x8,showmover=false," - ) + heatmap_begin = NoEscape(r"\chessboard[style=8x8,showmover=false,") doc.append(heatmap_begin) heatmap_end = NoEscape(heatmap_str) + NoEscape(r"]") @@ -70,15 +60,9 @@ def create_heatmap_string(heatmap, abs_max=True): for idx, name in enumerate(chess.SQUARE_NAMES): colorcode = heatmap[idx] if colorcode >= 0: - heatmap_str += ( - "pgfstyle=color, color=red!" - f"{colorcode*100:.0f}!white, markfield={name},\n" - ) + heatmap_str += "pgfstyle=color, color=red!" f"{colorcode*100:.0f}!white, markfield={name},\n" elif colorcode < 0: - heatmap_str += ( - "pgfstyle=color, color=blue!" - f"{-colorcode*100:.0f}!white, markfield={name},\n" - ) + heatmap_str += "pgfstyle=color, color=blue!" 
f"{-colorcode*100:.0f}!white, markfield={name},\n" else: raise TypeError return heatmap_str diff --git a/scripts/find_concepts.py b/scripts/find_concepts.py index 2424a43..4ecaaa2 100644 --- a/scripts/find_concepts.py +++ b/scripts/find_concepts.py @@ -59,9 +59,7 @@ def get_n_concepts(l_name, model): fv_path = f"scripts/im_viz/{model_name}-{dataset_name}" -fv = crp_helpers.ModifiedFeatureVisualization( - attribution, unique_dataset, layer_map, preprocess_fn=None, path=fv_path -) +fv = crp_helpers.ModifiedFeatureVisualization(attribution, unique_dataset, layer_map, preprocess_fn=None, path=fv_path) def collate_fn_tensor(batch): @@ -76,9 +74,7 @@ def collate_fn_tuple(batch): if save_files: - saved_files = fv.run( - composite, batch_size, 100, custom_collate_fn=collate_fn_tensor - ) + saved_files = fv.run(composite, batch_size, 100, custom_collate_fn=collate_fn_tensor) print("[INFO] Files saved!") concepts = { @@ -90,9 +86,7 @@ def collate_fn_tuple(batch): for case, concept in concepts.items(): unique_dataset.concept = concept - concept_fen_strings = set( - [b.fen() for _, b, label in unique_dataset if label == 1] - ) + concept_fen_strings = set([b.fen() for _, b, label in unique_dataset if label == 1]) print(f"[INFO] Concept '{case}' positives: {len(concept_fen_strings)}") for l_name in layer_names: diff --git a/scripts/make_datasets.py b/scripts/make_datasets.py index b775c58..dfae80d 100644 --- a/scripts/make_datasets.py +++ b/scripts/make_datasets.py @@ -58,8 +58,7 @@ written_boards = 0 print(f"[INFO] Converting games to boards: {dataset_name}") with jsonlines.open( - f"{ARGS.output_root}/assets/" - f"{dataset_name.replace('.jsonl', '_boards.jsonl')}", + f"{ARGS.output_root}/assets/" f"{dataset_name.replace('.jsonl', '_boards.jsonl')}", "w", ) as writer: for game in tqdm.tqdm(dataset.games): @@ -80,8 +79,7 @@ written_boards = 0 random.seed(tcec_random_seed) with jsonlines.open( - f"{ARGS.output_root}/assets/" - f"{dataset_name.replace('.jsonl', '_random_boards.jsonl')}", + f"{ARGS.output_root}/assets/" f"{dataset_name.replace('.jsonl', '_random_boards.jsonl')}", "w", ) as writer: for game in tqdm.tqdm(dataset.games): @@ -112,8 +110,7 @@ concept_dataset = ConceptDataset.from_board_dataset(dataset, concept) concept_dataset.save( - f"{ARGS.output_root}/assets/" - f"{dataset_name.replace('.jsonl', '_bestlegal.jsonl')}", + f"{ARGS.output_root}/assets/" f"{dataset_name.replace('.jsonl', '_bestlegal.jsonl')}", n_history=n_history, ) print(f"[INFO] Concept dataset written: {len(concept_dataset)}") @@ -130,18 +127,13 @@ concept_dataset = ConceptDataset(f"./assets/{dataset_name}") def filter_fn(board, label, gameid): - move = move_utils.decode_move( - label, (board.turn, not board.turn), board - ) + move = move_utils.decode_move(label, (board.turn, not board.turn), board) from_piece = board.piece_at(move.from_square) - return (from_piece == chess.Piece.from_symbol("N")) or ( - from_piece == chess.Piece.from_symbol("n") - ) + return (from_piece == chess.Piece.from_symbol("N")) or (from_piece == chess.Piece.from_symbol("n")) concept_dataset.filter_(filter_fn) concept_dataset.save( - f"{ARGS.output_root}/assets/" - f"{dataset_name.replace('.jsonl', '_knight.jsonl')}", + f"{ARGS.output_root}/assets/" f"{dataset_name.replace('.jsonl', '_knight.jsonl')}", n_history=n_history, ) print(f"[INFO] Concept dataset written: {len(concept_dataset)}") @@ -161,8 +153,7 @@ def filter_fn(board, label, gameid): concept_dataset.filter_(filter_fn) concept_dataset.save( - f"{ARGS.output_root}/assets/" - 
f"{dataset_name.replace('.jsonl', '_10.jsonl')}", + f"{ARGS.output_root}/assets/" f"{dataset_name.replace('.jsonl', '_10.jsonl')}", n_history=n_history, ) print(f"[INFO] Concept dataset written: {len(concept_dataset)}") @@ -175,13 +166,9 @@ def filter_fn(board, label, gameid): model.to(DEVICE) concept = BestLegalMoveConcept(model) - concept_dataset = ConceptDataset.from_game_dataset( - dataset, n_history=n_history - ) + concept_dataset = ConceptDataset.from_game_dataset(dataset, n_history=n_history) concept_dataset.set_concept(concept, mininterval=10) - new_dataset_name = dataset_name.replace( - ".jsonl", "_boards_bestlegal.jsonl" - ) + new_dataset_name = dataset_name.replace(".jsonl", "_boards_bestlegal.jsonl") concept_dataset.save( f"{ARGS.output_root}/assets/{new_dataset_name}", n_history=n_history, diff --git a/scripts/pixel_flipping.py b/scripts/pixel_flipping.py index fb76c88..6186853 100644 --- a/scripts/pixel_flipping.py +++ b/scripts/pixel_flipping.py @@ -43,14 +43,8 @@ indices, board_tensor, labels = next(iter(dataloader)) rule_names = ["default", "no_onnx"] -morf_logit_dict = { - rule_name: {layer_name: [] for layer_name in layer_names} - for rule_name in rule_names -} -lerf_logit_dict = { - rule_name: {layer_name: [] for layer_name in layer_names} - for rule_name in rule_names -} +morf_logit_dict = {rule_name: {layer_name: [] for layer_name in layer_names} for rule_name in rule_names} +lerf_logit_dict = {rule_name: {layer_name: [] for layer_name in layer_names} for rule_name in rule_names} def mask_fn(output, modify_data): @@ -84,9 +78,7 @@ def mask_fn(output, modify_data): def init_rel_fn(out_tensor): rel = torch.zeros_like(out_tensor) for i in range(rel.shape[0]): - rel[i, label_tensor[i]] = out_tensor[ - i, label_tensor[i] - ] + rel[i, label_tensor[i]] = out_tensor[i, label_tensor[i]] return rel board_tensor.requires_grad = True @@ -105,44 +97,27 @@ def init_rel_fn(out_tensor): ) latent_rel = attr.relevances[layer_name] if morf: - to_flip = latent_rel.view( - board_tensor.shape[0], -1 - ).argmax(dim=1) + to_flip = latent_rel.view(board_tensor.shape[0], -1).argmax(dim=1) else: - to_flip = latent_rel.view( - board_tensor.shape[0], -1 - ).argmin(dim=1) + to_flip = latent_rel.view(board_tensor.shape[0], -1).argmin(dim=1) if hook.config.data[layer_name] is None: - mask_flat = torch.ones_like(latent_rel).view( - board_tensor.shape[0], -1 - ) + mask_flat = torch.ones_like(latent_rel).view(board_tensor.shape[0], -1) for i in range(mask_flat.shape[0]): mask_flat[i, to_flip[i]] = 0 - hook.config.data[layer_name] = mask_flat.view_as( - latent_rel - ) + hook.config.data[layer_name] = mask_flat.view_as(latent_rel) else: - old_mask_flat = hook.config.data[layer_name].view( - board_tensor.shape[0], -1 - ) + old_mask_flat = hook.config.data[layer_name].view(board_tensor.shape[0], -1) for i in range(old_mask_flat.shape[0]): old_mask_flat[i, to_flip[i]] = 0 - hook.config.data[layer_name] = old_mask_flat.view_as( - latent_rel - ) + hook.config.data[layer_name] = old_mask_flat.view_as(latent_rel) if debug: print(f"[INFO] Most relevant pixels: {to_flip}") logit_dict[rule_name][layer_name].append( - attr.prediction.gather( - 1, label_tensor.view(-1, 1) - ).detach() + attr.prediction.gather(1, label_tensor.view(-1, 1)).detach() ) print(f"[INFO] Layer: {layer_name} done") if debug: - print( - "[INFO] Logits: " - f"{torch.cat(logit_dict[rule_name][layer_name], dim=1)}" - ) + print("[INFO] Logits: " f"{torch.cat(logit_dict[rule_name][layer_name], dim=1)}") hook.remove() hook.clear() print(f"[INFO] 
Rule: {rule_name} done") diff --git a/scripts/register_wandb_datasets.py b/scripts/register_wandb_datasets.py index 50bca88..51697c2 100644 --- a/scripts/register_wandb_datasets.py +++ b/scripts/register_wandb_datasets.py @@ -20,16 +20,12 @@ ####################################### parser = argparse.ArgumentParser("register-wandb-datasets") parser.add_argument("--output_root", type=str, default=".") -parser.add_argument( - "--make_datasets", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("--make_datasets", action=argparse.BooleanOptionalAction, default=False) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--train_samples", type=int, default=100_000) parser.add_argument("--val_samples", type=int, default=5_000) parser.add_argument("--test_samples", type=int, default=5_000) -parser.add_argument( - "--log_datasets", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("--log_datasets", action=argparse.BooleanOptionalAction, default=False) ####################################### ARGS = parser.parse_args() @@ -47,41 +43,27 @@ test_indices = all_indices[val_slice:test_slice] dataset.save( - f"{ARGS.output_root}/assets/" - "TCEC_game_collection_random_boards_train.jsonl", + f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_train.jsonl", indices=train_indices, ) dataset.save( - f"{ARGS.output_root}/assets/" - "TCEC_game_collection_random_boards_val.jsonl", + f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_val.jsonl", indices=val_indices, ) dataset.save( - f"{ARGS.output_root}/assets/" - "TCEC_game_collection_random_boards_test.jsonl", + f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_test.jsonl", indices=test_indices, ) if ARGS.log_datasets: wandb.login() - with wandb.init( - project="lczerolens-saes", job_type="make-datasets" - ) as run: + with wandb.init(project="lczerolens-saes", job_type="make-datasets") as run: artifact = wandb.Artifact("tcec_train", type="dataset") - artifact.add_file( - f"{ARGS.output_root}/assets/" - "TCEC_game_collection_random_boards_train.jsonl" - ) + artifact.add_file(f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_train.jsonl") run.log_artifact(artifact) artifact = wandb.Artifact("tcec_val", type="dataset") - artifact.add_file( - f"{ARGS.output_root}/assets/" - "TCEC_game_collection_random_boards_val.jsonl" - ) + artifact.add_file(f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_val.jsonl") run.log_artifact(artifact) artifact = wandb.Artifact("tcec_test", type="dataset") - artifact.add_file( - f"{ARGS.output_root}/assets/" - "TCEC_game_collection_random_boards_test.jsonl" - ) + artifact.add_file(f"{ARGS.output_root}/assets/" "TCEC_game_collection_random_boards_test.jsonl") run.log_artifact(artifact) diff --git a/scripts/register_wandb_models.py b/scripts/register_wandb_models.py index 0e3a9ea..61ae532 100644 --- a/scripts/register_wandb_models.py +++ b/scripts/register_wandb_models.py @@ -17,9 +17,7 @@ models = { "maia-1100": "maia-1100.onnx", } -parser.add_argument( - "--log_models", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("--log_models", action=argparse.BooleanOptionalAction, default=False) ####################################### ARGS = parser.parse_args() diff --git a/scripts/sae_training.py b/scripts/sae_training.py index 6a443f5..cc847af 100644 --- a/scripts/sae_training.py +++ b/scripts/sae_training.py @@ -48,9 +48,7 @@ def sae_loss( ghost_grads = False if 
ghost_threshold is not None: if num_samples_since_activated is None: - raise ValueError( - "num_samples_since_activated must be provided for ghost grads" - ) + raise ValueError("num_samples_since_activated must be provided for ghost grads") ghost_mask = num_samples_since_activated > ghost_threshold if ghost_mask.sum() > 0: # if there are dead neurons ghost_grads = True @@ -79,13 +77,9 @@ def sae_loss( ) ghost_loss = t.nn.MSELoss()(residual.detach(), x_ghost).sqrt() - if ( - num_samples_since_activated is not None - ): # update the number of samples since each neuron was last activated + if num_samples_since_activated is not None: # update the number of samples since each neuron was last activated deads = (f == 0).all(dim=0) - num_samples_since_activated.copy_( - t.where(deads, num_samples_since_activated + 1, 0) - ) + num_samples_since_activated.copy_(t.where(deads, num_samples_since_activated + 1, 0)) if use_entropy: sparsity_loss = entropy(f) @@ -98,17 +92,11 @@ def sae_loss( if ghost_loss is None: out_losses["total_loss"] = classical_loss else: - out_losses["total_loss"] = classical_loss + ghost_loss * ( - mse_loss.detach() / (ghost_loss.detach() + EPS) - ) + out_losses["total_loss"] = classical_loss + ghost_loss * (mse_loss.detach() / (ghost_loss.detach() + EPS)) if explained_variance: - out_losses["explained_variance"] = explained_variance_score( - out_acts.detach().cpu(), x_hat.detach().cpu() - ) + out_losses["explained_variance"] = explained_variance_score(out_acts.detach().cpu(), x_hat.detach().cpu()) if r2: - out_losses["r2_score"] = r2_score( - out_acts.detach().cpu(), x_hat.detach().cpu() - ) + out_losses["r2_score"] = r2_score(out_acts.detach().cpu(), x_hat.detach().cpu()) return out_losses @@ -255,20 +243,14 @@ def lr_fn(step): ae, sparsity_penalty, entropy, - num_samples_since_activated=( - num_samples_since_activated - ), + num_samples_since_activated=(num_samples_since_activated), ghost_threshold=ghost_threshold, ) if wandb is not None: - wandb.log({f"train/{k}": l for k, l in losses.items()}) + wandb.log({f"train/{k}": v for k, v in losses.items()}) if do_print: print(f"[INFO] Train step {step}: {losses}") - if ( - save_steps is not None - and save_dir is not None - and step % save_steps == 0 - ): + if save_steps is not None and save_dir is not None and step % save_steps == 0: if not os.path.exists(os.path.join(save_dir, "checkpoints")): os.mkdir(os.path.join(save_dir, "checkpoints")) t.save( @@ -291,22 +273,18 @@ def lr_fn(step): ae, sparsity_penalty, use_entropy=entropy, - num_samples_since_activated=( - num_samples_since_activated - ), + num_samples_since_activated=(num_samples_since_activated), ghost_threshold=ghost_threshold, explained_variance=True, r2=True, ) - for k, _ in val_losses.items(): + for k in val_losses.keys(): val_losses[k] += losses[k] - for k, v in val_losses.items(): + for k in val_losses.keys(): val_losses[k] /= len(val_dataloader) if wandb is not None: - wandb.log( - {f"val/{k}": l for k, l in val_losses.items()} - ) + wandb.log({f"val/{k}": v for k, v in val_losses.items()}) if do_print: print(f"[INFO] Val step {step}: {val_losses}") diff --git a/scripts/sample_exploration.py b/scripts/sample_exploration.py index d10b0eb..a64c767 100644 --- a/scripts/sample_exploration.py +++ b/scripts/sample_exploration.py @@ -31,9 +31,7 @@ ####################################### -concept_dataset = ConceptDataset( - "./assets/TCEC_game_collection_random_boards_bestlegal_knight_10.jsonl" -) +concept_dataset = 
ConceptDataset("./assets/TCEC_game_collection_random_boards_bestlegal_knight_10.jsonl") lens = LrpLens() all_relevances = {} @@ -78,9 +76,7 @@ def collate_fn(batch): ) doc.packages.append(Package("xskak")) for elo, relevances in all_relevances.items(): - move = move_utils.decode_move( - label, (board.turn, not board.turn), board - ) + move = move_utils.decode_move(label, (board.turn, not board.turn), board) uci_move = move.uci() input_relevances = relevances[i] # type: ignore if not board.turn: @@ -88,12 +84,8 @@ def collate_fn(batch): input_relevances = input_relevances.view(112, 64) heatmap_str_list = [ create_heatmap_string(input_relevances.sum(dim=0), abs_max=True), - create_heatmap_string( - input_relevances[:12].sum(dim=0), abs_max=True - ), - create_heatmap_string( - input_relevances[104:].sum(dim=0), abs_max=True - ), + create_heatmap_string(input_relevances[:12].sum(dim=0), abs_max=True), + create_heatmap_string(input_relevances[104:].sum(dim=0), abs_max=True), ] heatmap_caption_list = [ "Total relevance", @@ -112,14 +104,11 @@ def collate_fn(batch): heatmap_str_list, current_piece_pos=uci_move[:2], next_move=uci_move[2:4], - caption=f"Sample {i} - Model ELO {elo} " - f"- {h0/total:.0f}%|{hist/total:.0f}%|{meta/total:.0f}%", + caption=f"Sample {i} - Model ELO {elo} " f"- {h0/total:.0f}%|{hist/total:.0f}%|{meta/total:.0f}%", heatmap_caption_list=heatmap_caption_list, ) doc.generate_pdf( - "scripts/results/exploration/" - f"{'best' if best_legal else 'full'}" - f"_{target}_{i}", + "scripts/results/exploration/" f"{'best' if best_legal else 'full'}" f"_{target}_{i}", clean_tex=True, ) diff --git a/scripts/simple_sae.py b/scripts/simple_sae.py index ceb2cd1..3321608 100644 --- a/scripts/simple_sae.py +++ b/scripts/simple_sae.py @@ -125,9 +125,7 @@ parser.add_argument("--act_batch_size", type=int, default=100) parser.add_argument("--model_name", type=str, default="maia-1100.onnx") # SAE training -parser.add_argument( - "--train_sae", action=argparse.BooleanOptionalAction, default=True -) +parser.add_argument("--train_sae", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--from_checkpoint", type=str, default=None) parser.add_argument("--freeze_dict", type=bool, default=False) parser.add_argument("--sae_module_name", type=str, default="block1/conv2/relu") @@ -154,9 +152,7 @@ parser.add_argument("--log_steps", type=int, default=50) parser.add_argument("--val_steps", type=int, default=200) # Test -parser.add_argument( - "--compute_evals", action=argparse.BooleanOptionalAction, default=True -) +parser.add_argument("--compute_evals", action=argparse.BooleanOptionalAction, default=True) ####################################### ARGS = parser.parse_args() @@ -221,13 +217,10 @@ collate_fn=BoardDataset.collate_fn_tuple, ) - os.makedirs( - f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}", exist_ok=True - ) + os.makedirs(f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}", exist_ok=True) save_file( activations, - f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/" - f"{dataset_type}_activations.safetensors", + f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/" f"{dataset_type}_activations.safetensors", ) @@ -247,9 +240,7 @@ def rearrange_activations(activations): return einops.rearrange(patches, "b c h w -> (b h w) c") def invert_rearrange_activations(activations): - patches = einops.rearrange( - activations, "(b h w) c -> b c h w", h=4, w=4 - ) + patches = einops.rearrange(activations, "(b h w) c -> b c h w", h=4, w=4) p1 = patches[:, :64] p2 = patches[:, 
64:128].flip(dims=(3,))
         p3 = patches[:, 128:192].flip(dims=(2,))
@@ -271,9 +262,7 @@ def rearrange_activations(activations):
             ph=ARGS.h_patch_size,
             pw=ARGS.w_patch_size,
         )
-        return einops.rearrange(
-            split_batch, "b c h ph w pw -> (b h w) (c ph pw)"
-        )
+        return einops.rearrange(split_batch, "b c h ph w pw -> (b h w) (c ph pw)")

     def invert_rearrange_activations(activations):
         split_batch = einops.rearrange(
@@ -284,21 +273,17 @@ def invert_rearrange_activations(activations):
             ph=ARGS.h_patch_size,
             pw=ARGS.w_patch_size,
         )
-        return einops.rearrange(
-            split_batch, "b c h ph w pw -> b c (h ph) (w pw)"
-        )
+        return einops.rearrange(split_batch, "b c h ph w pw -> b c (h ph) (w pw)")

 if ARGS.train_sae:
     with safe_open(
-        f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/"
-        "train_activations.safetensors",
+        f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/" "train_activations.safetensors",
         framework="pt",
     ) as f:
         train_activations = f.get_tensor(ARGS.sae_module_name)

     with safe_open(
-        f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/"
-        "val_activations.safetensors",
+        f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/" "val_activations.safetensors",
         framework="pt",
     ) as f:
         val_activations = f.get_tensor(ARGS.sae_module_name)
@@ -344,10 +329,7 @@ def invert_rearrange_activations(activations):
             freeze_dict=ARGS.freeze_dict,
             wandb=wandb,
         )
-        model_path = (
-            f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}"
-            f"/{ARGS.sae_module_name.replace('/', '_')}.pt"
-        )
+        model_path = f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}" f"/{ARGS.sae_module_name.replace('/', '_')}.pt"

         torch.save(
             ae,
@@ -367,13 +349,11 @@ def invert_rearrange_activations(activations):
 if ARGS.compute_evals:
     if not ARGS.train_sae:
         ae = torch.load(
-            f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/"
-            f"{ARGS.sae_module_name.replace('/', '_')}.pt",
+            f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/" f"{ARGS.sae_module_name.replace('/', '_')}.pt",
             map_location=torch.device(DEVICE),
         )
     with safe_open(
-        f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/"
-        "test_activations.safetensors",
+        f"{ARGS.output_root}/scripts/saes/{ARGS.model_name}/" "test_activations.safetensors",
         framework="pt",
     ) as f:
         test_activations = f.get_tensor(ARGS.sae_module_name)
@@ -397,9 +377,7 @@ def invert_rearrange_activations(activations):
             out = ae(acts.to(DEVICE), output_features=True)
             f = out["features"]
             x_hat = out["x_hat"]
-            test_losses["explained_variance"] += explained_variance_score(
-                acts.cpu(), x_hat.cpu()
-            )
+            test_losses["explained_variance"] += explained_variance_score(acts.cpu(), x_hat.cpu())
             test_losses["r2_score"] += r2_score(acts.cpu(), x_hat.cpu())
             feature_act_count += (f > 0).sum(dim=0).cpu()
             activated_features += (f > 0).sum().cpu()
@@ -417,11 +395,8 @@ def invert_rearrange_activations(activations):
         )
         wandb.log(  # type: ignore
             {
-                "test/ativated_features": activated_features
-                / len(test_dataset),
-                "test/frac_activated_features": activated_features
-                / ae.dict_size
-                / len(test_dataset),
+                "test/activated_features": activated_features / len(test_dataset),
+                "test/frac_activated_features": activated_features / ae.dict_size / len(test_dataset),
             }
         )
         for k in test_losses.keys():
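The two rearrangement helpers above are pure permutations, so they invert each other exactly. A minimal round-trip sketch of the same einops patterns, with hypothetical tensor sizes standing in for ARGS.h_patch_size and ARGS.w_patch_size:

import einops
import torch

acts = torch.randn(10, 64, 8, 8)  # hypothetical conv activations (b, c, h*ph, w*pw)
# rearrange_activations: split each feature map into 4x4 patches, one SAE input per patch
split = einops.rearrange(acts, "b c (h ph) (w pw) -> b c h ph w pw", ph=4, pw=4)
flat = einops.rearrange(split, "b c h ph w pw -> (b h w) (c ph pw)")  # shape (40, 1024)
# invert_rearrange_activations: undo both steps
back = einops.rearrange(flat, "(b h w) (c ph pw) -> b c (h ph) (w pw)", h=2, w=2, ph=4, pw=4)
assert torch.equal(acts, back)  # exact round trip, no values are mixed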
-""" +"""Main module for the lczerolens package.""" -__version__ = "0.1.3" +__version__ = "0.2.0" -from .game import BoardDataset, GameDataset, ModelWrapper -from .utils import board as board_utils -from .utils import move as move_utils +from .encodings import board as board_encodings +from .encodings import move as move_encodings +from .game import BoardDataset, GameDataset +from .model import ModelWrapper from .xai import Lens + +__all__ = ["BoardDataset", "GameDataset", "ModelWrapper", "Lens", "board_encodings", "move_encodings"] diff --git a/src/lczerolens/_native_builder/__init__.py b/src/lczerolens/_native_builder/__init__.py index ee1f581..46b469b 100644 --- a/src/lczerolens/_native_builder/__init__.py +++ b/src/lczerolens/_native_builder/__init__.py @@ -1,4 +1,5 @@ -"""Native builder for the lc0 models. -""" +"""Native builder for the lc0 models.""" from .builder import NativeBuilder + +__all__ = ["NativeBuilder"] diff --git a/src/lczerolens/_native_builder/builder.py b/src/lczerolens/_native_builder/builder.py index a8837be..d889ce5 100644 --- a/src/lczerolens/_native_builder/builder.py +++ b/src/lczerolens/_native_builder/builder.py @@ -1,5 +1,4 @@ -"""LCZero model builder. -""" +"""LCZero model builder.""" import os import re @@ -20,10 +19,7 @@ class BuilderError(Exception): class NativeBuilder: """Class for automatically building a model.""" - _module_exp = re.compile( - r"\/?(?P[a-z_\-]+)(?P[0-9]*)" - r"(\/(?P.*))?" - ) + _module_exp = re.compile(r"\/?(?P[a-z_\-]+)(?P[0-9]*)" r"(\/(?P.*))?") @staticmethod def _translate(name): @@ -71,9 +67,7 @@ def build_from_path(cls, model_path: str): elif model_path.endswith(".pt"): return cls.build_from_torch_path(model_path) else: - raise NotImplementedError( - f"Model path {model_path} is not supported." - ) + raise NotImplementedError(f"Model path {model_path} is not supported.") @classmethod def build_from_onnx_path(cls, onnx_model_path: str): @@ -81,9 +75,7 @@ def build_from_onnx_path(cls, onnx_model_path: str): Builds a model from a given path. """ if not os.path.exists(onnx_model_path): - raise FileExistsError( - f"Model path {onnx_model_path} does not exist." - ) + raise FileExistsError(f"Model path {onnx_model_path} does not exist.") try: onnx_model = safe_shape_inference(onnx_model_path) onnx_graph = OnnxGraph(onnx_model.graph) @@ -101,9 +93,7 @@ def build_from_onnx_path(cls, onnx_model_path: str): def make_onnx_td_forward(onnx_model): old_forward = onnx_model.forward output_node = list(onnx_model.graph.nodes)[-1] - output_names = [ - n.name.replace("output_", "") for n in output_node.all_input_nodes - ] + output_names = [n.name.replace("output_", "") for n in output_node.all_input_nodes] def td_forward(x): old_out = old_forward(x) @@ -120,9 +110,7 @@ def build_from_torch_path(cls, torch_model_path: str): Builds a model from a given path. """ if not os.path.exists(torch_model_path): - raise FileExistsError( - f"Model path {torch_model_path} does not exist." 
- ) + raise FileExistsError(f"Model path {torch_model_path} does not exist.") try: torch_model = torch.load(torch_model_path) except Exception: @@ -146,9 +134,7 @@ def _build_senet_from_onnx(cls, onnx_graph): for name, onnx_tensor in onnx_graph.initializers.items(): parsed_name = name.replace("/w", "") try: - module_name, module_index, remaining = cls._parse_remaining( - parsed_name - ) + module_name, module_index, remaining = cls._parse_remaining(parsed_name) except BuilderError: continue @@ -166,20 +152,14 @@ def _build_senet_from_onnx(cls, onnx_graph): submodule_index, subremaining, ) = cls._parse_remaining(remaining) - state_dict_name = ( - f"block{module_index}.se_layer." - f"linear{submodule_index}.{submodule_name}" - ) + state_dict_name = f"block{module_index}.se_layer." f"linear{submodule_index}.{submodule_name}" else: ( submodule_name, submodule_index, subremaining, ) = cls._parse_remaining(remaining) - state_dict_name = ( - f"block{module_index}." - f"{submodule_name}{submodule_index}.{subremaining}" - ) + state_dict_name = f"block{module_index}." f"{submodule_name}{submodule_index}.{subremaining}" elif module_name in ["mlh", "wdl", "policy", "value"]: if heads is None: heads = [module_name] @@ -190,10 +170,7 @@ def _build_senet_from_onnx(cls, onnx_graph): submodule_index, subremaining, ) = cls._parse_remaining(remaining) - state_dict_name = ( - f"{module_name}." - f"{submodule_name}{submodule_index}.{subremaining}" - ) + state_dict_name = f"{module_name}." f"{submodule_name}{submodule_index}.{subremaining}" elif module_name == "const": continue else: @@ -209,27 +186,17 @@ def _build_senet_from_onnx(cls, onnx_graph): if n_hidden is None: n_hidden = tmp_n_hidden elif n_hidden != tmp_n_hidden: - raise BuilderError( - "n_hidden mismatch: " f"{n_hidden} != {tmp_n_hidden}" - ) + raise BuilderError("n_hidden mismatch: " f"{n_hidden} != {tmp_n_hidden}") if n_hidden_red is None: n_hidden_red = tmp_n_hidden_red elif n_hidden_red != tmp_n_hidden_red: - raise BuilderError( - "n_hidden_red mismatch: " - f"{n_hidden_red} != {tmp_n_hidden_red}" - ) + raise BuilderError("n_hidden_red mismatch: " f"{n_hidden_red} != {tmp_n_hidden_red}") if state_dict_name == "value.linear2.bias": if torch_tensor.shape[0] != 1: convert_value_to_wdl = True state_dict[state_dict_name] = torch_tensor - if ( - n_hidden is None - or n_hidden_red is None - or heads is None - or n_blocks == 0 - ): + if n_hidden is None or n_hidden_red is None or heads is None or n_blocks == 0: raise BuilderError("Could not build SeNet from onnx graph.") if convert_value_to_wdl: diff --git a/src/lczerolens/_native_builder/senet.py b/src/lczerolens/_native_builder/senet.py index 364e8d2..8cebfa6 100644 --- a/src/lczerolens/_native_builder/senet.py +++ b/src/lczerolens/_native_builder/senet.py @@ -1,5 +1,4 @@ -"""Custom SE ResNet. 
-""" +"""Custom SE ResNet.""" import torch from tensordict import TensorDict @@ -52,9 +51,7 @@ def forward(self, x): out1, out2 = out.split(self.n_hidden, dim=1) non_lin = self.sigmoid(out1) out1 = MulUniformFunction.apply(x, non_lin) - return self.sum_layer( - torch.stack([out1, out2.repeat(1, 1, 8, 8)], dim=-1) - ) + return self.sum_layer(torch.stack([out1, out2.repeat(1, 1, 8, 8)], dim=-1)) class SeBlock(nn.Module): @@ -101,10 +98,7 @@ def forward(self, x): out = out.view(-1, 80 * 64) out = out.gather( 1, - torch.tensor(constants.GATHER_INDICES) - .unsqueeze(0) - .repeat(out.shape[0], 1) - .to(out.device), + torch.tensor(constants.GATHER_INDICES).unsqueeze(0).repeat(out.shape[0], 1).to(out.device), ) return out @@ -186,9 +180,7 @@ def forward(self, x): class SeNet(nn.Module): """ResNet model.""" - def __init__( - self, n_blocks, n_hidden, n_hidden_red=32, heads=None - ) -> None: + def __init__(self, n_blocks, n_hidden, n_hidden_red=32, heads=None) -> None: super().__init__() self.n_blocks = n_blocks self.n_hidden = n_hidden diff --git a/src/lczerolens/utils/__init__.py b/src/lczerolens/encodings/__init__.py similarity index 100% rename from src/lczerolens/utils/__init__.py rename to src/lczerolens/encodings/__init__.py diff --git a/src/lczerolens/utils/board.py b/src/lczerolens/encodings/board.py similarity index 84% rename from src/lczerolens/utils/board.py rename to src/lczerolens/encodings/board.py index 53624cf..8957329 100644 --- a/src/lczerolens/utils/board.py +++ b/src/lczerolens/encodings/board.py @@ -1,5 +1,4 @@ -"""Board utilities. -""" +"""Board utilities.""" import re from copy import deepcopy @@ -35,9 +34,7 @@ def get_plane_order(us_them: Tuple[bool, bool]): return plane_order -def get_piece_index( - piece: str, us_them: Tuple[bool, bool], plane_order: Optional[str] = None -): +def get_piece_index(piece: str, us_them: Tuple[bool, bool], plane_order: Optional[str] = None): """Converts a piece to its index in the plane order. Parameters @@ -81,9 +78,7 @@ def board_to_config_tensor( The 13x8x8 tensor. """ if input_encoding != InputEncoding.INPUT_CLASSICAL_112_PLANE: - raise NotImplementedError( - f"Input encoding {input_encoding} not implemented." - ) + raise NotImplementedError(f"Input encoding {input_encoding} not implemented.") if us_them is None: us = board.turn them = not us @@ -101,18 +96,12 @@ def piece_to_index(piece: str): ordered_fen = "".join(rev_rows) config_tensor = torch.zeros((13, 8, 8), dtype=torch.float) - ordinal_board = torch.tensor( - tuple(map(piece_to_index, ordered_fen)), dtype=torch.float - ) + ordinal_board = torch.tensor(tuple(map(piece_to_index, ordered_fen)), dtype=torch.float) ordinal_board = ordinal_board.reshape((8, 8)).unsqueeze(0) - piece_tensor = torch.tensor( - tuple(map(piece_to_index, plane_order)), dtype=torch.float - ) + piece_tensor = torch.tensor(tuple(map(piece_to_index, plane_order)), dtype=torch.float) piece_tensor = piece_tensor.reshape((12, 1, 1)) config_tensor[:12] = (ordinal_board == piece_tensor).float() - if board.is_repetition( - 2 - ): # Might be wrong if the full history is not available + if board.is_repetition(2): # Might be wrong if the full history is not available config_tensor[12] = torch.ones((8, 8), dtype=torch.float) return config_tensor if us == chess.WHITE else config_tensor.flip(1) @@ -139,9 +128,7 @@ def board_to_input_tensor( The 112x8x8 tensor. """ if input_encoding != InputEncoding.INPUT_CLASSICAL_112_PLANE: - raise NotImplementedError( - f"Input encoding {input_encoding} not implemented." 
-        )
+        raise NotImplementedError(f"Input encoding {input_encoding} not implemented.")
     board = deepcopy(last_board)
     input_tensor = torch.zeros((112, 8, 8), dtype=torch.float)
     us = last_board.turn
@@ -164,8 +151,6 @@ def board_to_input_tensor(
         input_tensor[107] = torch.ones((8, 8), dtype=torch.float)
     if us == chess.BLACK:
         input_tensor[108] = torch.ones((8, 8), dtype=torch.float)
-    input_tensor[109] = (
-        torch.ones((8, 8), dtype=torch.float) * last_board.halfmove_clock
-    )
+    input_tensor[109] = torch.ones((8, 8), dtype=torch.float) * last_board.halfmove_clock
     input_tensor[111] = torch.ones((8, 8), dtype=torch.float)

     return input_tensor
diff --git a/src/lczerolens/utils/constants.py b/src/lczerolens/encodings/constants.py
similarity index 99%
rename from src/lczerolens/utils/constants.py
rename to src/lczerolens/encodings/constants.py
index 28cefdc..ee2d2a9 100644
--- a/src/lczerolens/utils/constants.py
+++ b/src/lczerolens/encodings/constants.py
@@ -1889,14 +1889,10 @@
             INVERTED_TO_INDEX[to_square] = [i]

 HISTORY_PLANE_NAMES = (
-    [f"{piece} (us)" for piece in "PNBRQK"]
-    + [f"{piece} (them)" for piece in "pnbrqk"]
-    + ["repetition"]
+    [f"{piece} (us)" for piece in "PNBRQK"] + [f"{piece} (them)" for piece in "pnbrqk"] + ["repetition"]
 )

-PLANE_NAMES = [
-    f"H{i}: {h_name}" for i in range(8) for h_name in HISTORY_PLANE_NAMES
-] + [
+PLANE_NAMES = [f"H{i}: {h_name}" for i in range(8) for h_name in HISTORY_PLANE_NAMES] + [
     "QCR (us)",
     "KCR (us)",
     "QCR (them)",
diff --git a/src/lczerolens/utils/move.py b/src/lczerolens/encodings/move.py
similarity index 88%
rename from src/lczerolens/utils/move.py
rename to src/lczerolens/encodings/move.py
index f3a6c76..6ba64a7 100644
--- a/src/lczerolens/utils/move.py
+++ b/src/lczerolens/encodings/move.py
@@ -1,5 +1,4 @@
-"""Utils for the move module.
-"""
+"""Utils for the move module."""

 from typing import Tuple

@@ -26,9 +25,7 @@ def encode_move(
     to_square_row = to_square // 8
     to_square_col = to_square % 8
     to_square = 8 * (7 - to_square_row) + to_square_col
-    us_uci_move = (
-        chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square]
-    )
+    us_uci_move = chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square]
     if move.promotion is not None:
         if move.promotion == chess.BISHOP:
             us_uci_move += "b"
@@ -62,8 +59,6 @@ def decode_move(
     uci_move = chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square]
     from_piece = board.piece_at(from_square)
-    if (
-        from_piece == chess.PAWN and to_square >= 56
-    ):  # Knight promotion is the default
+    if from_piece is not None and from_piece.piece_type == chess.PAWN and to_square >= 56:  # Knight promotion is the default
         uci_move += "n"
     return chess.Move.from_uci(uci_move)
diff --git a/src/lczerolens/game/__init__.py b/src/lczerolens/game/__init__.py
index 5abe8fa..c89ec4b 100644
--- a/src/lczerolens/game/__init__.py
+++ b/src/lczerolens/game/__init__.py
@@ -3,4 +3,6 @@
 """

 from .dataset import BoardDataset, GameDataset
-from .wrapper import MlhFlow, ModelWrapper, PolicyFlow, ValueFlow, WdlFlow
+from .play import WrapperSampler, SelfPlay, PolicySampler, BatchedPolicySampler
+
+__all__ = ["BoardDataset", "GameDataset", "WrapperSampler", "SelfPlay", "PolicySampler", "BatchedPolicySampler"]
diff --git a/src/lczerolens/game/dataset.py b/src/lczerolens/game/dataset.py
index 85cd050..d93f0bb 100644
--- a/src/lczerolens/game/dataset.py
+++ b/src/lczerolens/game/dataset.py
@@ -10,7 +10,7 @@
 A class for representing an iterable dataset of boards.
""" -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import chess import jsonlines @@ -18,9 +18,9 @@ import tqdm from torch.utils.data import Dataset -from lczerolens.utils import board as board_utils +from lczerolens.encodings import board as board_encodings -from .generate import Game +from .preprocess import Game, dict_to_game, game_to_boards class GameDataset(Dataset): @@ -45,29 +45,7 @@ def __init__( self.games = [] with jsonlines.open(file_name) as reader: for obj in reader: - *pre, post = obj["moves"].split("{ Book exit }") - if pre: - if len(pre) > 1: - raise ValueError("More than one book exit") - (pre,) = pre - parsed_pre_moves = [ - m for m in pre.split() if not m.endswith(".") - ] - book_exit = len(parsed_pre_moves) - else: - parsed_pre_moves = [] - book_exit = None - parsed_moves = parsed_pre_moves + [ - m for m in post.split() if not m.endswith(".") - ] - - self.games.append( - Game( - gameid=obj["gameid"], - moves=parsed_moves, - book_exit=book_exit, - ) - ) + self.games.append(dict_to_game(obj)) def __len__(self): return len(self.games) @@ -141,9 +119,7 @@ def save(self, file_name: str, n_history: int = 0, indices=None): writer.write( { "fen": working_board.root().fen(), - "moves": [ - move.uci() for move in working_board.move_stack - ], + "moves": [move.uci() for move in working_board.move_stack], "gameid": gameid, } ) @@ -160,81 +136,19 @@ def from_game_dataset( game_ids: List[str] = [] print("[INFO] Converting games to boards") for game in tqdm.tqdm(game_dataset.games, bar_format="{l_bar}{bar}"): - new_boards, new_ids = cls.game_to_board_list( - game, n_history, skip_book_exit, skip_first_n + new_boards = game_to_boards( + game, + n_history, + skip_book_exit, + skip_first_n, + output_dict=False, ) + new_ids = [game.gameid] * len(new_boards) boards.extend(new_boards) game_ids.extend(new_ids) return cls(boards=boards, game_ids=game_ids) - @staticmethod - def preprocess_game( - game: Game, - n_history: int = 0, - skip_book_exit: bool = False, - skip_first_n: int = 0, - ) -> List[Dict[str, Any]]: - working_board = chess.Board() - if skip_first_n > 0 or ( - skip_book_exit and (game.book_exit is not None) - ): - boards = [] - else: - boards = [ - { - "fen": working_board.fen(), - "moves": [], - "gameid": game.gameid, - } - ] - for i, move in enumerate( - game.moves[:-1] - ): # skip the last move as it can be over - working_board.push_san(move) - if (i < skip_first_n) or ( - skip_book_exit - and (game.book_exit is not None) - and (i < game.book_exit) - ): - continue - save_board = working_board.copy(stack=n_history) - boards.append( - { - "fen": save_board.root().fen(), - "moves": [move.uci() for move in save_board.move_stack], - "gameid": game.gameid, - } - ) - return boards - - @staticmethod - def game_to_board_list( - game: Game, - n_history: int = 0, - skip_book_exit: bool = False, - skip_first_n: int = 0, - ) -> Tuple[List[chess.Board], List[str]]: - working_board = chess.Board() - if skip_first_n > 0 or ( - skip_book_exit and (game.book_exit is not None) - ): - boards = [] - else: - boards = [working_board.copy(stack=n_history)] - for i, move in enumerate( - game.moves[:-1] - ): # skip the last move as it can be over - working_board.push_san(move) - if (i < skip_first_n) or ( - skip_book_exit - and (game.book_exit is not None) - and (i < game.book_exit) - ): - continue - boards.append(working_board.copy(stack=n_history)) - return boards, [game.gameid] * len(boards) - @staticmethod def collate_fn_tuple(batch): indices, boards = 
zip(*batch) @@ -242,9 +156,6 @@ def collate_fn_tuple(batch): @staticmethod def collate_fn_tensor(batch): - tensor_list = [ - board_utils.board_to_input_tensor(board).unsqueeze(0) - for board in batch - ] + tensor_list = [board_encodings.board_to_input_tensor(board).unsqueeze(0) for board in batch] batched_tensor = torch.cat(tensor_list, dim=0) return batched_tensor diff --git a/src/lczerolens/game/generate.py b/src/lczerolens/game/generate.py deleted file mode 100644 index 838d682..0000000 --- a/src/lczerolens/game/generate.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Classes for generating games. - -Classes -------- -Game - A class for representing a game. -GameGenerator - A class for generating games. -""" - -from dataclasses import dataclass -from typing import List, Optional - -from .search import SearchAlgorithm - - -@dataclass -class Game: - gameid: str - moves: List[str] - book_exit: Optional[int] = None - - -class GameGenerator: - """A class for generating games.""" - - def __init__(self, white: SearchAlgorithm, black: SearchAlgorithm): - """ - Initializes the game generator. - """ - self.white = white - self.black = black - - def play(self): - """ - Plays a game. - """ - raise NotImplementedError diff --git a/src/lczerolens/game/play.py b/src/lczerolens/game/play.py new file mode 100644 index 0000000..484ac0c --- /dev/null +++ b/src/lczerolens/game/play.py @@ -0,0 +1,215 @@ +"""Classes for playing.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Callable, List + +import chess +import torch +from torch.distributions import Categorical + +from lczerolens.encodings import move as move_encodings +from lczerolens.model.wrapper import ModelWrapper + + +def get_next_legal_boards(board: chess.Board): + working_board = board.copy(stack=7) + legal_moves = working_board.legal_moves + next_legal_boards = [] + for move in legal_moves: + working_board.push(move) + next_legal_boards.append(working_board.copy(stack=7)) + working_board.pop() + return legal_moves, next_legal_boards + + +class Sampler(ABC): + @abstractmethod + def get_next_move(self, board: chess.Board): + pass + + +@dataclass +class WrapperSampler(Sampler): + wrapper: ModelWrapper + use_argmax: bool = True + alpha: float = 1.0 + beta: float = 1.0 + gamma: float = 1.0 + draw_score: float = 0.0 + m_max: float = 0.0345 + m_slope: float = 0.0027 + k_0: float = 0.0 + k_1: float = 1.6521 + k_2: float = -0.6521 + q_threshold: float = 0.8 + + @torch.no_grad + def get_utility( + self, + board: chess.Board, + ): + to_log = {} + legal_moves, next_legal_boards = get_next_legal_boards(board) + all_stats = self.wrapper.predict([board, *next_legal_boards])[0] + utility = 0 + q_values = self._get_q_values(all_stats, to_log) + utility += self.alpha * q_values + utility += self.beta * self._get_m_values(all_stats, q_values, to_log) + us = board.turn + utility += self.gamma * self._get_p_values(all_stats, legal_moves, us, to_log) + to_log["max_utility"] = utility.max().item() + return utility, legal_moves, to_log + + def _get_q_values(self, all_stats, to_log): + if "value" in all_stats.keys(): + to_log["value"] = all_stats["value"][0].item() + return all_stats["value"][1:, 0] + elif "wdl" in all_stats.keys(): + to_log["wdl_w"] = all_stats["wdl"][0][0].item() + to_log["wdl_d"] = all_stats["wdl"][0][1].item() + to_log["wdl_l"] = all_stats["wdl"][0][2].item() + scores = torch.tensor([1, self.draw_score, -1]) + return all_stats["wdl"][1:] @ scores + else: + return torch.zeros(all_stats.batch_size[0] - 
1) + + def _get_m_values(self, all_stats, q_values, to_log): + if "mlh" in all_stats.keys(): + to_log["mlh"] = all_stats["mlh"][0].item() + delta_m_values = self.m_slope * (all_stats["mlh"][1:, 0] - all_stats["mlh"][0, 0]) + delta_m_values.clamp_(-self.m_max, self.m_max) + scaled_q_values = torch.relu(q_values.abs() - self.q_threshold) / (1 - self.q_threshold) + poly_q_values = self.k_0 + self.k_1 * scaled_q_values + self.k_2 * scaled_q_values**2 + return -q_values.sign() * delta_m_values * poly_q_values + else: + return torch.zeros(all_stats.batch_size[0] - 1) + + def _get_p_values( + self, + all_stats, + legal_moves, + us, + to_log, + ): + if "policy" in all_stats.keys(): + indices = torch.tensor([move_encodings.encode_move(move, (us, not us)) for move in legal_moves]) + legal_policy = all_stats["policy"][0].gather(0, indices) + to_log["max_legal_policy"] = legal_policy.max().item() + return legal_policy + else: + return torch.zeros(all_stats.batch_size[0] - 1) + + def get_next_move(self, board: chess.Board): + utility, legal_moves, to_log = self.get_utility(board) + if self.use_argmax: + idx = utility.argmax() + else: + m = Categorical(logits=utility) + idx = m.sample() + return list(legal_moves)[idx], to_log + + +class PolicySampler(WrapperSampler): + @torch.no_grad + def get_utility( + self, + board: chess.Board, + ): + to_log = {} + legal_moves = board.legal_moves + all_stats = self.wrapper.predict([board])[0] + us = board.turn + utility = self._get_p_values(all_stats, legal_moves, us, to_log) + to_log["max_utility"] = utility.max().item() + if "value" in all_stats.keys(): + to_log["value"] = all_stats["value"][0].item() + elif "wdl" in all_stats.keys(): + to_log["wdl_w"] = all_stats["wdl"][0][0].item() + to_log["wdl_d"] = all_stats["wdl"][0][1].item() + to_log["wdl_l"] = all_stats["wdl"][0][2].item() + return utility, legal_moves, to_log + + def _get_p_values( + self, + all_stats, + legal_moves, + us, + to_log, + ): + if "policy" in all_stats.keys(): + indices = torch.tensor([move_encodings.encode_move(move, (us, not us)) for move in legal_moves]) + legal_policy = all_stats["policy"][0].gather(0, indices) + to_log["max_legal_policy"] = legal_policy.max().item() + return legal_policy + else: + return torch.zeros(len(legal_moves)) + + +@dataclass +class SelfPlay: + """A class for generating games.""" + + white: Sampler + black: Sampler + + def play( + self, + board: Optional[chess.Board] = None, + max_moves: int = 100, + to_play: chess.Color = chess.WHITE, + report_fn: Optional[Callable[[dict, chess.Color], None]] = None, + ): + """ + Plays a game. 
+ """ + if board is None: + board = chess.Board() + game = [] + if to_play == chess.BLACK: + move, _ = self.black.get_next_move(board) + board.push(move) + game.append(move) + for _ in range(max_moves): + if board.is_game_over() or len(game) >= max_moves: + break + move, to_log = self.white.get_next_move(board) + if report_fn is not None: + report_fn(to_log, board.turn) + board.push(move) + game.append(move) + + if board.is_game_over() or len(game) >= max_moves: + break + move, to_log = self.black.get_next_move(board) + if report_fn is not None: + report_fn(to_log, board.turn) + board.push(move) + game.append(move) + if board.is_game_over() or len(game) >= max_moves: + break + return game, board + + +@dataclass +class BatchedPolicySampler: + wrapper: ModelWrapper + use_argmax: bool = True + + @torch.no_grad + def get_next_moves( + self, + boards: List[chess.Board], + ): + all_stats = self.wrapper.predict(boards)[0] + for board, policy in zip(boards, all_stats["policy"]): + us = board.turn + indices = torch.tensor([move_encodings.encode_move(move, (us, not us)) for move in board.legal_moves]) + legal_policy = all_stats["policy"][0].gather(0, indices) + if self.use_argmax: + idx = legal_policy.argmax() + else: + m = Categorical(logits=legal_policy) + print(m.probs) + idx = m.sample() + yield list(board.legal_moves)[idx] diff --git a/src/lczerolens/game/preprocess.py b/src/lczerolens/game/preprocess.py new file mode 100644 index 0000000..59465b3 --- /dev/null +++ b/src/lczerolens/game/preprocess.py @@ -0,0 +1,82 @@ +"""Preproces functions for chess games. + +Classes +------- +Game + A class for representing a game. +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import chess + + +@dataclass +class Game: + gameid: str + moves: List[str] + book_exit: Optional[int] = None + + +def dict_to_game(obj: Dict[str, str]) -> Game: + if "moves" not in obj: + ValueError("The dict should contain `moves`.") + if "gameid" not in obj: + ValueError("The dict should contain `gameid`.") + *pre, post = obj["moves"].split("{ Book exit }") + if pre: + if len(pre) > 1: + raise ValueError("More than one book exit") + (pre,) = pre + parsed_pre_moves = [m for m in pre.split() if not m.endswith(".")] + book_exit = len(parsed_pre_moves) + else: + parsed_pre_moves = [] + book_exit = None + parsed_moves = parsed_pre_moves + [m for m in post.split() if not m.endswith(".")] + return Game( + gameid=obj["gameid"], + moves=parsed_moves, + book_exit=book_exit, + ) + + +def game_to_boards( + game: Game, + n_history: int = 0, + skip_book_exit: bool = False, + skip_first_n: int = 0, + output_dict=True, +) -> List[Union[Dict[str, Any], chess.Board]]: + working_board = chess.Board() + if skip_first_n > 0 or (skip_book_exit and (game.book_exit is not None)): + boards = [] + else: + if output_dict: + boards = [ + { + "fen": working_board.fen(), + "moves": [], + "gameid": game.gameid, + } + ] + else: + boards = [working_board.copy(stack=n_history)] + + for i, move in enumerate(game.moves[:-1]): # skip the last move as it can be over + working_board.push_san(move) + if (i < skip_first_n) or (skip_book_exit and (game.book_exit is not None) and (i < game.book_exit)): + continue + if output_dict: + save_board = working_board.copy(stack=n_history) + boards.append( + { + "fen": save_board.root().fen(), + "moves": [move.uci() for move in save_board.move_stack], + "gameid": game.gameid, + } + ) + else: + boards.append(working_board.copy(stack=n_history)) + return boards diff --git 
diff --git a/src/lczerolens/game/preprocess.py b/src/lczerolens/game/preprocess.py
new file mode 100644
index 0000000..59465b3
--- /dev/null
+++ b/src/lczerolens/game/preprocess.py
@@ -0,0 +1,82 @@
+"""Preprocess functions for chess games.
+
+Classes
+-------
+Game
+    A class for representing a game.
+"""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import chess
+
+
+@dataclass
+class Game:
+    gameid: str
+    moves: List[str]
+    book_exit: Optional[int] = None
+
+
+def dict_to_game(obj: Dict[str, str]) -> Game:
+    if "moves" not in obj:
+        raise ValueError("The dict should contain `moves`.")
+    if "gameid" not in obj:
+        raise ValueError("The dict should contain `gameid`.")
+    *pre, post = obj["moves"].split("{ Book exit }")
+    if pre:
+        if len(pre) > 1:
+            raise ValueError("More than one book exit")
+        (pre,) = pre
+        parsed_pre_moves = [m for m in pre.split() if not m.endswith(".")]
+        book_exit = len(parsed_pre_moves)
+    else:
+        parsed_pre_moves = []
+        book_exit = None
+    parsed_moves = parsed_pre_moves + [m for m in post.split() if not m.endswith(".")]
+    return Game(
+        gameid=obj["gameid"],
+        moves=parsed_moves,
+        book_exit=book_exit,
+    )
+
+
+def game_to_boards(
+    game: Game,
+    n_history: int = 0,
+    skip_book_exit: bool = False,
+    skip_first_n: int = 0,
+    output_dict=True,
+) -> List[Union[Dict[str, Any], chess.Board]]:
+    working_board = chess.Board()
+    if skip_first_n > 0 or (skip_book_exit and (game.book_exit is not None)):
+        boards = []
+    else:
+        if output_dict:
+            boards = [
+                {
+                    "fen": working_board.fen(),
+                    "moves": [],
+                    "gameid": game.gameid,
+                }
+            ]
+        else:
+            boards = [working_board.copy(stack=n_history)]
+
+    for i, move in enumerate(game.moves[:-1]):  # skip the last move as it can be over
+        working_board.push_san(move)
+        if (i < skip_first_n) or (skip_book_exit and (game.book_exit is not None) and (i < game.book_exit)):
+            continue
+        if output_dict:
+            save_board = working_board.copy(stack=n_history)
+            boards.append(
+                {
+                    "fen": save_board.root().fen(),
+                    "moves": [move.uci() for move in save_board.move_stack],
+                    "gameid": game.gameid,
+                }
+            )
+        else:
+            boards.append(working_board.copy(stack=n_history))
+    return boards
diff --git a/src/lczerolens/game/search.py b/src/lczerolens/game/search.py
deleted file mode 100644
index 57de855..0000000
--- a/src/lczerolens/game/search.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Classes for search algorithms.
-"""
-
-
-class Node:
-    pass
-
-
-class SearchAlgorithm:
-    pass
-
-
-class MCTS(SearchAlgorithm):
-    pass
-
-
-class Heuristic(SearchAlgorithm):
-    pass
diff --git a/src/lczerolens/model/__init__.py b/src/lczerolens/model/__init__.py
new file mode 100644
index 0000000..6ef44ad
--- /dev/null
+++ b/src/lczerolens/model/__init__.py
@@ -0,0 +1,7 @@
+"""
+Import the model module.
+"""
+
+from .wrapper import MlhFlow, ModelWrapper, PolicyFlow, ValueFlow, WdlFlow
+
+__all__ = ["ModelWrapper", "MlhFlow", "PolicyFlow", "ValueFlow", "WdlFlow"]
diff --git a/src/lczerolens/utils/lczero.py b/src/lczerolens/model/lczero.py
similarity index 81%
rename from src/lczerolens/utils/lczero.py
rename to src/lczerolens/model/lczero.py
index 5ff53ac..3b98213 100644
--- a/src/lczerolens/utils/lczero.py
+++ b/src/lczerolens/model/lczero.py
@@ -11,15 +11,12 @@
 import chess
 import torch

-from lczerolens import move_utils
+from lczerolens.encodings import move as move_encodings

 try:
     from lczero.backends import Backend, GameState
 except ImportError as e:
-    raise ImportError(
-        "LCZero bindings are not installed."
-        "See https://github.com/LeelaChessZero/lc0.git."
-    ) from e
+    raise ImportError("LCZero bindings are not installed. " "See https://github.com/LeelaChessZero/lc0.git.") from e


 def generic_command(args, verbose=False):
@@ -34,9 +31,7 @@ def generic_command(args, verbose=False):
     popen.wait()
     if popen.returncode != 0:
         if verbose:
-            stderr = (
-                f'\n[DEBUG] stderr:\n{popen.stderr.read().decode("utf-8")}'
-            )
+            stderr = f'\n[DEBUG] stderr:\n{popen.stderr.read().decode("utf-8")}'
         else:
             stderr = ""
         raise RuntimeError(f"Could not run `lc0 {' '.join(args)}`." + stderr)
@@ -70,9 +65,7 @@ def convert_to_leela(in_path, out_path, verbose=False):
     )


-def board_from_backend(
-    lczero_backend: Backend, lczero_game: GameState, planes: int = 112
-):
+def board_from_backend(lczero_backend: Backend, lczero_game: GameState, planes: int = 112):
     """
     Create a board from the lczero backend.
""" @@ -104,13 +97,9 @@ def prediction_from_backend( else: indices = torch.tensor(range(1858)) if softmax: - policy = torch.tensor( - lczero_output.p_softmax(*range(1858)), dtype=torch.float - ) + policy = torch.tensor(lczero_output.p_softmax(*range(1858)), dtype=torch.float) else: - policy = torch.tensor( - lczero_output.p_raw(*range(1858)), dtype=torch.float - ) + policy = torch.tensor(lczero_output.p_raw(*range(1858)), dtype=torch.float) value = torch.tensor(lczero_output.q()) filtered_policy[indices] = policy[indices] return filtered_policy, value @@ -132,12 +121,10 @@ def moves_with_castling_swap(lczero_game: GameState, board: chess.Board): lczero_legal_moves.remove(leela_uci_move) lczero_legal_moves.append(uci_move) lczero_policy_indices.remove( - move_utils.encode_move( + move_encodings.encode_move( chess.Move.from_uci(leela_uci_move), (board.turn, not board.turn), ) ) - lczero_policy_indices.append( - move_utils.encode_move(move, (board.turn, not board.turn)) - ) + lczero_policy_indices.append(move_encodings.encode_move(move, (board.turn, not board.turn))) return lczero_legal_moves, lczero_policy_indices diff --git a/src/lczerolens/game/wrapper.py b/src/lczerolens/model/wrapper.py similarity index 83% rename from src/lczerolens/game/wrapper.py rename to src/lczerolens/model/wrapper.py index 9bf0c09..11d2453 100644 --- a/src/lczerolens/game/wrapper.py +++ b/src/lczerolens/model/wrapper.py @@ -1,5 +1,4 @@ -"""Class for wrapping the LCZero models. -""" +"""Class for wrapping the LCZero models.""" import os from typing import Dict, Iterable, Type, Union @@ -11,7 +10,7 @@ from tensordict import TensorDict from torch import nn -from lczerolens.utils import board as board_utils +from lczerolens.encodings import board as board_encodings class ModelWrapper(nn.Module): @@ -48,9 +47,7 @@ def from_path(cls, model_path: str): elif model_path.endswith(".pt"): return cls.from_torch_path(model_path) else: - raise NotImplementedError( - f"Model path {model_path} is not supported." - ) + raise NotImplementedError(f"Model path {model_path} is not supported.") @classmethod def from_onnx_path(cls, onnx_model_path: str, check: bool = True): @@ -58,16 +55,12 @@ def from_onnx_path(cls, onnx_model_path: str, check: bool = True): Builds a model from a given path. """ if not os.path.exists(onnx_model_path): - raise FileExistsError( - f"Model path {onnx_model_path} does not exist." - ) + raise FileExistsError(f"Model path {onnx_model_path} does not exist.") try: if check: onnx_model = safe_shape_inference(onnx_model_path) onnx_torch_model = convert(onnx_model) - onnx_torch_model.forward = cls.make_onnx_td_forward( - onnx_torch_model - ) + onnx_torch_model.forward = cls.make_onnx_td_forward(onnx_torch_model) return cls(model=onnx_torch_model) except Exception: raise ValueError(f"Could not load model at {onnx_model_path}.") @@ -76,9 +69,7 @@ def from_onnx_path(cls, onnx_model_path: str, check: bool = True): def make_onnx_td_forward(onnx_model): old_forward = onnx_model.forward output_node = list(onnx_model.graph.nodes)[-1] - output_names = [ - n.name.replace("output_", "") for n in output_node.all_input_nodes - ] + output_names = [n.name.replace("output_", "") for n in output_node.all_input_nodes] def td_forward(x): old_out = old_forward(x) @@ -95,9 +86,7 @@ def from_torch_path(cls, torch_model_path: str): Builds a model from a given path. """ if not os.path.exists(torch_model_path): - raise FileExistsError( - f"Model path {torch_model_path} does not exist." 
- ) + raise FileExistsError(f"Model path {torch_model_path} does not exist.") try: torch_model = torch.load(torch_model_path) except Exception: @@ -124,10 +113,7 @@ def predict( else: raise ValueError("Invalid input type.") - tensor_list = [ - board_utils.board_to_input_tensor(board).unsqueeze(0) - for board in board_list - ] + tensor_list = [board_encodings.board_to_input_tensor(board).unsqueeze(0) for board in board_list] batched_tensor = torch.cat(tensor_list, dim=0) if input_requires_grad: batched_tensor.requires_grad = True @@ -151,9 +137,7 @@ def __init__( model: nn.Module, ): if not self.is_compatible(model): - raise ValueError( - f"The model does not have a {self._flow_type} head." - ) + raise ValueError(f"The model does not have a {self._flow_type} head.") super().__init__(model=model) @classmethod @@ -178,9 +162,7 @@ def get_subclass(cls, name: str) -> Type["Flow"]: @classmethod def is_compatible(cls, model: nn.Module): - return hasattr(model, cls._flow_type) or hasattr( - model, f"output/{cls._flow_type}" - ) + return hasattr(model, cls._flow_type) or hasattr(model, f"output/{cls._flow_type}") def forward(self, x): """Forward pass.""" diff --git a/src/lczerolens/xai/__init__.py b/src/lczerolens/xai/__init__.py index 830350d..5221e86 100644 --- a/src/lczerolens/xai/__init__.py +++ b/src/lczerolens/xai/__init__.py @@ -1,5 +1,4 @@ -"""XAI module. -""" +"""XAI module.""" from .concept import ( AndBinaryConcept, @@ -24,3 +23,23 @@ PolicyLens, ProbingLens, ) + +__all__ = [ + "Lens", + "ConceptDataset", + "BinaryConcept", + "AndBinaryConcept", + "OrBinaryConcept", + "HasPieceConcept", + "HasThreatConcept", + "HasMateThreatConcept", + "HasMaterialAdvantageConcept", + "BestLegalMoveConcept", + "PieceBestLegalMoveConcept", + "ActivationLens", + "CrpLens", + "LrpLens", + "PatchingLens", + "PolicyLens", + "ProbingLens", +] diff --git a/src/lczerolens/xai/concept.py b/src/lczerolens/xai/concept.py index 1add36b..6b3c3bf 100644 --- a/src/lczerolens/xai/concept.py +++ b/src/lczerolens/xai/concept.py @@ -1,5 +1,4 @@ -"""Class for concept-based XAI methods. 
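A short usage sketch of the relocated wrapper API; the model file is hypothetical, and per this diff predict returns a TensorDict of head outputs whose keys depend on the network:

import chess
from lczerolens import ModelWrapper
from lczerolens.model import PolicyFlow

wrapper = ModelWrapper.from_path("maia-1100.onnx")  # hypothetical local file
out = wrapper.predict([chess.Board()])[0]  # TensorDict of head outputs
print(out["policy"].shape)  # (1, 1858) move logits for nets with a policy head

if PolicyFlow.is_compatible(wrapper.model):
    policy_only = PolicyFlow(wrapper.model)  # restrict the forward pass to the policy head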
-""" +"""Class for concept-based XAI methods.""" import random from abc import ABC, abstractmethod @@ -11,8 +10,8 @@ import tqdm from sklearn import metrics +from lczerolens.encodings import board as board_encodings from lczerolens.game.dataset import BoardDataset -from lczerolens.utils import board as board_utils class Concept(ABC): @@ -135,12 +134,8 @@ def compute_metrics( """ return { "accuracy": metrics.accuracy_score(labels, predictions), - "precision": metrics.precision_score( - labels, predictions, average="weighted" - ), - "recall": metrics.recall_score( - labels, predictions, average="weighted" - ), + "precision": metrics.precision_score(labels, predictions, average="weighted"), + "recall": metrics.recall_score(labels, predictions, average="weighted"), "f1": metrics.f1_score(labels, predictions, average="weighted"), } @@ -202,8 +197,7 @@ def __init__( elif not hasattr(self, "labels"): print("[INFO] Computing labels") self.labels = [ - self._concept.compute_label(board) - for board in tqdm.tqdm(self.boards, bar_format="{l_bar}{bar}") + self._concept.compute_label(board) for board in tqdm.tqdm(self.boards, bar_format="{l_bar}{bar}") ] def __getitem__(self, idx) -> Tuple[int, chess.Board, Any]: # type: ignore @@ -229,9 +223,7 @@ def save(self, file_name: str, n_history: int = 0, indices=None): writer.write( { "fen": working_board.root().fen(), - "moves": [ - move.uci() for move in working_board.move_stack - ], + "moves": [move.uci() for move in working_board.move_stack], "gameid": gameid, "label": label, } @@ -246,21 +238,15 @@ def set_concept(self, concept: Concept, **pbar_kwargs): print("[INFO] Computing labels") self.labels = [ self._concept.compute_label(board) - for board in tqdm.tqdm( - self.boards, bar_format="{l_bar}{bar}", **pbar_kwargs - ) + for board in tqdm.tqdm(self.boards, bar_format="{l_bar}{bar}", **pbar_kwargs) ] @classmethod - def from_board_dataset( - cls, board_dataset: BoardDataset, concept: Concept, **pbar_kwargs - ): + def from_board_dataset(cls, board_dataset: BoardDataset, concept: Concept, **pbar_kwargs): print("[INFO] Computing labels") labels = [ concept.compute_label(board) - for board in tqdm.tqdm( - board_dataset.boards, bar_format="{l_bar}{bar}", **pbar_kwargs - ) + for board in tqdm.tqdm(board_dataset.boards, bar_format="{l_bar}{bar}", **pbar_kwargs) ] return cls( boards=board_dataset.boards, @@ -277,9 +263,7 @@ def collate_fn_tuple(batch): @staticmethod def collate_fn_tensor(batch): indices, boards, labels = zip(*batch) - tensor_list = [ - board_utils.board_to_input_tensor(board) for board in boards - ] + tensor_list = [board_encodings.board_to_input_tensor(board) for board in boards] batched_tensor = torch.stack(tensor_list, dim=0) return tuple(indices), batched_tensor, tuple(labels) @@ -287,9 +271,7 @@ def filter_(self, filter_fn: Callable): tuple_boards, tuple_labels, tuple_game_ids = zip( *[ (board, label, game_id) - for board, label, game_id in zip( - self.boards, self.labels, self.game_ids - ) + for board, label, game_id in zip(self.boards, self.labels, self.game_ids) if filter_fn(board, label, game_id) ] ) diff --git a/src/lczerolens/xai/concepts/__init__.py b/src/lczerolens/xai/concepts/__init__.py index 4f43a36..e78013e 100644 --- a/src/lczerolens/xai/concepts/__init__.py +++ b/src/lczerolens/xai/concepts/__init__.py @@ -1,6 +1,14 @@ -"""Concepts module. 
-""" +"""Concepts module.""" from .material import HasMaterialAdvantageConcept, HasPieceConcept from .move import BestLegalMoveConcept, PieceBestLegalMoveConcept from .threat import HasMateThreatConcept, HasThreatConcept + +__all__ = [ + "HasPieceConcept", + "HasThreatConcept", + "HasMateThreatConcept", + "HasMaterialAdvantageConcept", + "BestLegalMoveConcept", + "PieceBestLegalMoveConcept", +] diff --git a/src/lczerolens/xai/concepts/material.py b/src/lczerolens/xai/concepts/material.py index f5fb903..04d5fc7 100644 --- a/src/lczerolens/xai/concepts/material.py +++ b/src/lczerolens/xai/concepts/material.py @@ -1,5 +1,4 @@ -"""All concepts related to material. -""" +"""All concepts related to material.""" from typing import Dict, Optional diff --git a/src/lczerolens/xai/concepts/move.py b/src/lczerolens/xai/concepts/move.py index f76aa50..5df7bbd 100644 --- a/src/lczerolens/xai/concepts/move.py +++ b/src/lczerolens/xai/concepts/move.py @@ -1,11 +1,10 @@ -"""All concepts related to move. -""" +"""All concepts related to move.""" import chess import torch -from lczerolens.game.wrapper import ModelWrapper, PolicyFlow -from lczerolens.utils import move as move_utils +from lczerolens.encodings import move as move_encodings +from lczerolens.model.wrapper import ModelWrapper, PolicyFlow from lczerolens.xai.concept import BinaryConcept, MulticlassConcept @@ -28,8 +27,7 @@ def compute_label( policy = torch.softmax(policy.squeeze(0), dim=-1) legal_move_indices = [ - move_utils.encode_move(move, (board.turn, not board.turn)) - for move in board.legal_moves + move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves ] sub_index = policy[legal_move_indices].argmax().item() return legal_move_indices[sub_index] @@ -56,10 +54,7 @@ def compute_label( policy = torch.softmax(policy.squeeze(0), dim=-1) legal_moves = list(board.legal_moves) - legal_move_indices = [ - move_utils.encode_move(move, (board.turn, not board.turn)) - for move in legal_moves - ] + legal_move_indices = [move_encodings.encode_move(move, (board.turn, not board.turn)) for move in legal_moves] sub_index = policy[legal_move_indices].argmax().item() best_legal_move = legal_moves[sub_index] if board.piece_at(best_legal_move.from_square) == self.piece: diff --git a/src/lczerolens/xai/concepts/threat.py b/src/lczerolens/xai/concepts/threat.py index ac67ed2..ba64b9a 100644 --- a/src/lczerolens/xai/concepts/threat.py +++ b/src/lczerolens/xai/concepts/threat.py @@ -1,5 +1,4 @@ -"""All concepts related to threats. -""" +"""All concepts related to threats.""" import chess diff --git a/src/lczerolens/xai/helpers/crp.py b/src/lczerolens/xai/helpers/crp.py index 7264d41..53137a3 100644 --- a/src/lczerolens/xai/helpers/crp.py +++ b/src/lczerolens/xai/helpers/crp.py @@ -1,5 +1,4 @@ -"""Helpers to modify the default classes. 
-""" +"""Helpers to modify the default classes.""" import warnings from collections.abc import Iterable @@ -125,9 +124,7 @@ def collate_fn(batch): samples_batch = samples[b * batch_size : (b + 1) * batch_size] data_batch, targets_samples = batch - targets_samples = np.array( - targets_samples - ) # numpy operation needed + targets_samples = np.array(targets_samples) # numpy operation needed # convert multi target to single target if user defined the method data_broadcast, targets, sample_indices = [], [], [] @@ -152,9 +149,7 @@ def collate_fn(batch): samples_batch, ) - conditions = [ - {self.attribution.MODEL_OUTPUT_NAME: [t]} for t in targets - ] + conditions = [{self.attribution.MODEL_OUTPUT_NAME: [t]} for t in targets] # dict_inputs is linked to FeatHooks dict_inputs["sample_indices"] = sample_indices dict_inputs["targets"] = targets @@ -162,9 +157,7 @@ def collate_fn(batch): # composites are already registered before if on_device: data_broadcast = data_broadcast.to(on_device) # type: ignore - self.attribution( - data_broadcast, conditions, None, exclude_parallel=False - ) + self.attribution(data_broadcast, conditions, None, exclude_parallel=False) if b % checkpoint == checkpoint - 1: self._save_results((last_checkpoint, sample_indices[-1] + 1)) @@ -196,13 +189,9 @@ def get_max_reference( if not isinstance(concept_ids, Iterable): concept_ids = [concept_ids] if mode == "relevance": - d_c_sorted, _, rf_c_sorted = load_maximization( - self.RelMax.PATH, layer_name - ) + d_c_sorted, _, rf_c_sorted = load_maximization(self.RelMax.PATH, layer_name) elif mode == "activation": - d_c_sorted, _, rf_c_sorted = load_maximization( - self.ActMax.PATH, layer_name - ) + d_c_sorted, _, rf_c_sorted = load_maximization(self.ActMax.PATH, layer_name) else: raise ValueError("`mode` must be `relevance` or `activation`") @@ -308,9 +297,7 @@ def collate_fn(batch): if composite: data_batch = torch.cat(data_batch_list, dim=0) data_p = self.preprocess_data(data_batch) - heatmaps = self._attribution_on_reference( - data_p, c_id, layer_name, composite, rf, n_indices, batch_size - ) + heatmaps = self._attribution_on_reference(data_p, c_id, layer_name, composite, rf, n_indices, batch_size) if callable(plot_fn): return plot_fn(data_batch.detach(), heatmaps.detach(), rf) diff --git a/src/lczerolens/xai/helpers/lrp.py b/src/lczerolens/xai/helpers/lrp.py index 30f5711..47a6246 100644 --- a/src/lczerolens/xai/helpers/lrp.py +++ b/src/lczerolens/xai/helpers/lrp.py @@ -84,14 +84,8 @@ def backward(ctx, *grad_outputs): out_relevance = out_relevance / stabilize(2 * outputs, epsilon) - relevance_a = ( - torch.matmul(out_relevance, input_b.permute(0, 1, -1, -2)) - * input_a - ) - relevance_b = ( - torch.matmul(input_a.permute(0, 1, -1, -2), out_relevance) - * input_b - ) + relevance_a = torch.matmul(out_relevance, input_b.permute(0, 1, -1, -2)) * input_a + relevance_b = torch.matmul(input_a.permute(0, 1, -1, -2), out_relevance) * input_b return relevance_a, relevance_b, None @@ -134,8 +128,6 @@ def forward(ctx, inputs, dim): def backward(ctx, *grad_outputs): inputs, output = ctx.saved_tensors - relevance = ( - grad_outputs[0] - (output * grad_outputs[0].sum(-1, keepdim=True)) - ) * inputs + relevance = (grad_outputs[0] - (output * grad_outputs[0].sum(-1, keepdim=True))) * inputs return (relevance, None) diff --git a/src/lczerolens/xai/helpers/sae.py b/src/lczerolens/xai/helpers/sae.py index c60d32f..4580848 100644 --- a/src/lczerolens/xai/helpers/sae.py +++ b/src/lczerolens/xai/helpers/sae.py @@ -12,9 +12,7 @@ class 
AutoEncoder(nn.Module): A 3-layers autoencoder. """ - def __init__( - self, activation_dim, dict_size, pre_bias=False, less_than_1=False - ): + def __init__(self, activation_dim, dict_size, pre_bias=False, less_than_1=False): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size @@ -58,9 +56,7 @@ def normalize_dict_(self, less_than_1): D_norm = self.D.norm(dim=1) if less_than_1: greater_than_1_mask = D_norm > 1 - self.D[greater_than_1_mask] /= D_norm[ - greater_than_1_mask - ].unsqueeze(1) + self.D[greater_than_1_mask] /= D_norm[greater_than_1_mask].unsqueeze(1) else: self.D /= D_norm.unsqueeze(1) diff --git a/src/lczerolens/xai/hook.py b/src/lczerolens/xai/hook.py index b6a77bf..e2a32b7 100644 --- a/src/lczerolens/xai/hook.py +++ b/src/lczerolens/xai/hook.py @@ -1,5 +1,4 @@ -"""Generic hook classes. -""" +"""Generic hook classes.""" import re from abc import ABC, abstractmethod @@ -34,7 +33,7 @@ class HookMode(str, Enum): @dataclass -class HookConfig(ABC): +class HookConfig: """ Configuration for hooks. """ @@ -43,7 +42,7 @@ class HookConfig(ABC): hook_mode: HookMode = HookMode.OUTPUT module_exp: Optional[str] = None data: Optional[Dict[str, Any]] = None - data_fn: Optional[Callable] = None + data_fn: Optional[Callable[[torch.Tensor, Any], torch.Tensor]] = None class Hook(ABC): @@ -72,13 +71,9 @@ def register(self, module: torch.nn.Module): if not compiled_exp.match(name): continue if self.config.hook_type is HookType.FORWARD: - self.removable_handles.append( - module.register_forward_hook(self.forward_factory(name)) - ) + self.removable_handles.append(module.register_forward_hook(self.forward_factory(name))) elif self.config.hook_type is HookType.BACKWARD: - self.removable_handles.append( - module.register_backward_hook(self.backward_factory(name)) - ) + self.removable_handles.append(module.register_backward_hook(self.backward_factory(name))) else: raise ValueError(f"Unknown hook type: {self.config.hook_type}") return self.removable_handles @@ -90,7 +85,6 @@ def remove(self): def clear(self): """Clears the storage and removes the hook.""" self.storage.clear() - self.removable_handles.clear() @abstractmethod def forward_factory(self, name: str): @@ -106,6 +100,9 @@ def backward_factory(self, name: str): """ pass + def _get_data(self, name): + return self.config.data.get(name) if self.config.data is not None else None + class CacheHook(Hook): """ @@ -116,21 +113,19 @@ def forward_factory(self, name: str): if self.config.hook_mode is HookMode.INPUT: def hook(module, input, output): - self.storage[name] = input.detach().cpu() + self.storage[name] = input elif self.config.hook_mode is HookMode.OUTPUT: def hook(module, input, output): - self.storage[name] = output.detach().cpu() + self.storage[name] = output else: - raise ValueError(f"Unknown hook mode: {self.config.hook_mode}") + raise ValueError(f"Unknown cache mode: {self.config.hook_mode}") return hook def backward_factory(self, name: str): - raise NotImplementedError( - "Backward hook not implemented for CacheHook" - ) + raise NotImplementedError("Backward hook not implemented for CacheHook") class MeasureHook(Hook): @@ -142,34 +137,21 @@ def forward_factory(self, name: str): if self.config.hook_mode is HookMode.INPUT: def hook(module, input, output): - if self.config.data is not None: - measure_data = self.config.data[name] - else: - measure_data = None - self.storage[name] = self.config.data_fn( - input.detach(), measure_data=measure_data - ) + measure_data = self._get_data(name) + self.storage[name] = 
self.config.data_fn(input, measure_data=measure_data, name=name)

         elif self.config.hook_mode is HookMode.OUTPUT:

             def hook(module, input, output):
-                if self.config.data is not None:
-                    measure_data = self.config.data[name]
-                else:
-                    measure_data = None
-                self.storage[name] = self.config.data_fn(
-                    output.detach(), measure_data=measure_data
-                )
+                measure_data = self._get_data(name)
+                self.storage[name] = self.config.data_fn(output, measure_data=measure_data, name=name)

         else:
-            raise ValueError(f"Unknown hook mode: {self.config.hook_mode}")
-
+            raise ValueError(f"Unknown measure mode: {self.config.hook_mode}")
         return hook

     def backward_factory(self, name: str):
-        raise NotImplementedError(
-            "Backward hook not implemented for MeasureHook"
-        )
+        raise NotImplementedError("Backward hook not implemented for MeasureHook")


 class ModifyHook(Hook):
@@ -179,26 +161,23 @@ class ModifyHook(Hook):

     def forward_factory(self, name: str):
         if self.config.hook_mode is HookMode.INPUT:
-            raise NotImplementedError(
-                "Input hook not implemented for ModifyHook"
-            )
+
+            def hook(module, input, output):
+                modify_data = self._get_data(name)
+                # A forward hook cannot replace the module input, so the
+                # modified value is only recorded in the storage.
+                self.storage[name] = self.config.data_fn(input, modify_data=modify_data, name=name)
         elif self.config.hook_mode is HookMode.OUTPUT:

             def hook(module, input, output):
-                if self.config.data is not None:
-                    modify_data = self.config.data[name]
-                else:
-                    modify_data = None
-                output = self.config.data_fn(output, modify_data=modify_data)
+                modify_data = self._get_data(name)
+                output = self.config.data_fn(output, modify_data=modify_data, name=name)
+                self.storage[name] = output
                 return output
         else:
-            raise ValueError(f"Unknown hook mode: {self.config.hook_mode}")
+            raise ValueError(f"Unknown modify mode: {self.config.hook_mode}")

         return hook

     def backward_factory(self, name: str):
-        raise NotImplementedError(
-            "Backward hook not implemented for ModifyHook"
-        )
+        raise NotImplementedError("Backward hook not implemented for ModifyHook")
diff --git a/src/lczerolens/xai/lens.py b/src/lczerolens/xai/lens.py
index 1b3334c..9889cfe 100644
--- a/src/lczerolens/xai/lens.py
+++ b/src/lczerolens/xai/lens.py
@@ -1,5 +1,4 @@
-"""Generic lens class.
-"""
+"""Generic lens class."""

 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Type

@@ -7,7 +6,7 @@
 import chess
 from torch.utils.data import Dataset

-from lczerolens.game.wrapper import ModelWrapper
+from lczerolens.model.wrapper import ModelWrapper


 class Lens(ABC):
diff --git a/src/lczerolens/xai/lenses/__init__.py b/src/lczerolens/xai/lenses/__init__.py
index eefce87..4929695 100644
--- a/src/lczerolens/xai/lenses/__init__.py
+++ b/src/lczerolens/xai/lenses/__init__.py
@@ -8,3 +8,12 @@
 from .patching import PatchingLens
 from .policy import PolicyLens
 from .probing import ProbingLens
+
+__all__ = [
+    "ActivationLens",
+    "CrpLens",
+    "LrpLens",
+    "PatchingLens",
+    "PolicyLens",
+    "ProbingLens",
+]
diff --git a/src/lczerolens/xai/lenses/activation.py b/src/lczerolens/xai/lenses/activation.py
index dbec89d..4a48341 100644
--- a/src/lczerolens/xai/lenses/activation.py
+++ b/src/lczerolens/xai/lenses/activation.py
@@ -1,5 +1,4 @@
-"""Activation lens for XAI.
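For reference, a small sketch of how these hooks are meant to be driven, using only names from this diff; the module regex and the model file are illustrative:

import chess
from lczerolens import ModelWrapper
from lczerolens.xai.hook import CacheHook, HookConfig

wrapper = ModelWrapper.from_path("maia-1100.onnx")  # hypothetical local file
cache = CacheHook(HookConfig(module_exp=r".*relu.*"))  # cache outputs of matching modules
cache.register(wrapper.model)
wrapper.predict([chess.Board()])
print(list(cache.storage.keys()))  # cached activations, keyed by module name
cache.remove()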
-""" +"""Activation lens for XAI.""" from typing import Any, Callable, Dict, Optional @@ -7,7 +6,7 @@ import torch from torch.utils.data import DataLoader, Dataset -from lczerolens.game.wrapper import ModelWrapper +from lczerolens.model.wrapper import ModelWrapper from lczerolens.xai.hook import CacheHook, HookConfig from lczerolens.xai.lens import Lens @@ -49,9 +48,7 @@ def analyse_dataset( """Cache the activations for a given model and dataset.""" if save_to is not None: raise NotImplementedError("Saving to file is not implemented.") - dataloader = DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn - ) + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) self.cache_hook.clear() self.cache_hook.register(wrapper.model) batched_activations: Dict[str, Any] = {} @@ -62,8 +59,6 @@ def analyse_dataset( if key not in batched_activations: batched_activations[key] = value else: - batched_activations[key] = torch.cat( - (batched_activations[key], value), dim=0 - ) + batched_activations[key] = torch.cat((batched_activations[key], value), dim=0) self.cache_hook.clear() return batched_activations diff --git a/src/lczerolens/xai/lenses/crp.py b/src/lczerolens/xai/lenses/crp.py index bde9858..2dd3b90 100644 --- a/src/lczerolens/xai/lenses/crp.py +++ b/src/lczerolens/xai/lenses/crp.py @@ -1,5 +1,4 @@ -"""Compute CRP heatmap for a given model and input. -""" +"""Compute CRP heatmap for a given model and input.""" from typing import Any, Callable, List, Optional @@ -9,7 +8,7 @@ from crp.helper import get_layer_names from torch.utils.data import Dataset -from lczerolens.game.wrapper import ModelWrapper +from lczerolens.model.wrapper import ModelWrapper from lczerolens.xai.lens import Lens from .lrp import LrpLens @@ -40,9 +39,7 @@ def analyse_board( composite = kwargs.get("composite", None) if mode == "latent_relevances": - return self._compute_latent_relevances( - [board], wrapper, layer_names=layer_names, composite=composite - ) + return self._compute_latent_relevances([board], wrapper, layer_names=layer_names, composite=composite) elif mode == "max_ref": raise NotImplementedError else: @@ -67,9 +64,7 @@ def _compute_latent_relevances( composite: Optional[Any] = None, ) -> torch.Tensor: if layer_names is None: - layer_names = layer_names = get_layer_names( - wrapper, [torch.nn.Identity] - ) + layer_names = layer_names = get_layer_names(wrapper, [torch.nn.Identity]) if composite is None: composite = LrpLens.make_default_composite() diff --git a/src/lczerolens/xai/lenses/lrp.py b/src/lczerolens/xai/lenses/lrp.py index 961f1eb..553aafb 100644 --- a/src/lczerolens/xai/lenses/lrp.py +++ b/src/lczerolens/xai/lenses/lrp.py @@ -1,5 +1,4 @@ -"""Compute LRP heatmap for a given model and input. 
-""" +"""Compute LRP heatmap for a given model and input.""" from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional @@ -13,7 +12,7 @@ from zennit.rules import Epsilon, Pass, ZPlus from zennit.types import Activation -from lczerolens.game.wrapper import ModelWrapper +from lczerolens.model.wrapper import ModelWrapper from lczerolens.xai.helpers import lrp as lrp_helpers from lczerolens.xai.lens import Lens @@ -92,9 +91,7 @@ def analyse_dataset( replace_onnx2torch = kwargs.get("replace_onnx2torch", True) linearise_softmax = kwargs.get("linearise_softmax", False) init_rel_fn = kwargs.get("init_rel_fn", None) - dataloader = DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn - ) + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) relevances = {} for batch in dataloader: inidices, boards = batch @@ -125,9 +122,7 @@ def _compute_lrp_relevance( Compute LRP heatmap for a given model and input. """ - with self.context( - wrapper, composite, replace_onnx2torch, linearise_softmax - ) as modified_model: + with self.context(wrapper, composite, replace_onnx2torch, linearise_softmax) as modified_model: output, input_tensor = modified_model.predict( boards, with_grad=True, @@ -135,18 +130,10 @@ def _compute_lrp_relevance( return_input=True, ) if target is None: - output.backward( - gradient=( - output if init_rel_fn is None else init_rel_fn(output) - ) - ) + output.backward(gradient=(output if init_rel_fn is None else init_rel_fn(output))) else: output[target].backward( - gradient=( - output[target] - if init_rel_fn is None - else init_rel_fn(output[target]) - ) + gradient=(output[target] if init_rel_fn is None else init_rel_fn(output[target])) ) return input_tensor.grad @@ -183,9 +170,7 @@ def context( new_module_mapping[name] = torch.nn.Identity() old_module_mapping[name] = module if replace_onnx2torch: - if isinstance( - module, onnx2torch.node_converters.OnnxBinaryMathOperation - ): + if isinstance(module, onnx2torch.node_converters.OnnxBinaryMathOperation): if module.math_op_function is torch.add: new_module_mapping[name] = lrp_helpers.AddEpsilon() old_module_mapping[name] = module @@ -195,9 +180,7 @@ def context( elif isinstance(module, onnx2torch.node_converters.OnnxMatMul): new_module_mapping[name] = lrp_helpers.MatMulEpsilon() old_module_mapping[name] = module - elif isinstance( - module, onnx2torch.node_converters.OnnxFunction - ): + elif isinstance(module, onnx2torch.node_converters.OnnxFunction): if module.function is torch.tanh: new_module_mapping[name] = torch.nn.Tanh() old_module_mapping[name] = module diff --git a/src/lczerolens/xai/lenses/patching.py b/src/lczerolens/xai/lenses/patching.py index eeda0b1..298b4fc 100644 --- a/src/lczerolens/xai/lenses/patching.py +++ b/src/lczerolens/xai/lenses/patching.py @@ -1,5 +1,4 @@ -"""Patching lens for XAI. 
-""" +"""Patching lens for XAI.""" from typing import Callable, Dict, Optional @@ -7,7 +6,7 @@ import torch from torch.utils.data import Dataset -from lczerolens.game.wrapper import ModelWrapper +from lczerolens.model.wrapper import ModelWrapper from lczerolens.xai.hook import HookConfig, ModifyHook from lczerolens.xai.lens import Lens @@ -22,9 +21,7 @@ def __init__(self, patching_dict: Dict[str, Callable]): self.patching_dict = patching_dict self.modify_hooks = {} for module_name, patch in self.patching_dict.items(): - self.modify_hooks[module_name] = ModifyHook( - HookConfig(module_exp=module_name, data_fn=patch) - ) + self.modify_hooks[module_name] = ModifyHook(HookConfig(module_exp=module_name, data_fn=patch)) def is_compatible(self, wrapper: ModelWrapper) -> bool: """ @@ -63,9 +60,7 @@ def analyse_dataset( """ if save_to is not None: raise NotImplementedError("Saving to file is not implemented.") - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn - ) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) batched_outs = None for modify_hook in self.modify_hooks.values(): modify_hook.clear() diff --git a/src/lczerolens/xai/lenses/policy.py b/src/lczerolens/xai/lenses/policy.py index 36de8ec..886db8c 100644 --- a/src/lczerolens/xai/lenses/policy.py +++ b/src/lczerolens/xai/lenses/policy.py @@ -1,5 +1,4 @@ -"""PolicyLens class for wrapping the LCZero models. -""" +"""PolicyLens class for wrapping the LCZero models.""" from typing import Any, Callable, Dict, Optional @@ -7,8 +6,11 @@ import torch from torch.utils.data import DataLoader, Dataset -from lczerolens.game.wrapper import ModelWrapper, PolicyFlow -from lczerolens.utils.constants import INVERTED_FROM_INDEX, INVERTED_TO_INDEX +from lczerolens.encodings.constants import ( + INVERTED_FROM_INDEX, + INVERTED_TO_INDEX, +) +from lczerolens.model.wrapper import ModelWrapper, PolicyFlow from lczerolens.xai.lens import Lens DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -75,10 +77,6 @@ def aggregate_policy( filtered_policy = policy for square_index in range(64): square = chess.SQUARE_NAMES[square_index] - pickup_agg[square_index] = filtered_policy[ - INVERTED_FROM_INDEX[square] - ].sum() - dropoff_agg[square_index] = filtered_policy[ - INVERTED_TO_INDEX[square] - ].sum() + pickup_agg[square_index] = filtered_policy[INVERTED_FROM_INDEX[square]].sum() + dropoff_agg[square_index] = filtered_policy[INVERTED_TO_INDEX[square]].sum() return pickup_agg, dropoff_agg diff --git a/src/lczerolens/xai/lenses/probing.py b/src/lczerolens/xai/lenses/probing.py index 80e3297..cda337b 100644 --- a/src/lczerolens/xai/lenses/probing.py +++ b/src/lczerolens/xai/lenses/probing.py @@ -1,5 +1,4 @@ -"""Probing lens for XAI. 
-""" +"""Probing lens for XAI.""" from typing import Any, Callable, Dict, Optional @@ -7,7 +6,7 @@ import torch from torch.utils.data import Dataset -from lczerolens.game.wrapper import ModelWrapper +from lczerolens.model.wrapper import ModelWrapper from lczerolens.xai.hook import HookConfig, MeasureHook from lczerolens.xai.lens import Lens from lczerolens.xai.probe import Probe @@ -23,9 +22,7 @@ def __init__(self, probe_dict: Dict[str, Probe]): self.probe_dict = probe_dict self.measure_hooks = {} for module_name, probe in self.probe_dict.items(): - self.measure_hooks[module_name] = MeasureHook( - HookConfig(module_exp=module_name, data_fn=probe.predict) - ) + self.measure_hooks[module_name] = MeasureHook(HookConfig(module_exp=module_name, data_fn=probe.predict)) def is_compatible(self, wrapper: ModelWrapper) -> bool: """ @@ -66,9 +63,7 @@ def analyse_dataset( """ if save_to is not None: raise NotImplementedError("Saving to file is not implemented.") - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, collate_fn=collate_fn - ) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) batched_measures: Dict[str, Any] = {} for measure_hook in self.measure_hooks.values(): measure_hook.clear() @@ -78,9 +73,7 @@ def analyse_dataset( wrapper.predict(boards) for module_name, measure_hook in self.measure_hooks.items(): if module_name not in batched_measures: - batched_measures[module_name] = measure_hook.storage[ - module_name - ] + batched_measures[module_name] = measure_hook.storage[module_name] else: batched_measures[module_name] = torch.cat( ( diff --git a/src/lczerolens/xai/probe.py b/src/lczerolens/xai/probe.py index ed3bd89..b8ce625 100644 --- a/src/lczerolens/xai/probe.py +++ b/src/lczerolens/xai/probe.py @@ -1,5 +1,4 @@ -"""Module to implement generic probing. 
-""" +"""Module to implement generic probing.""" from abc import ABC, abstractmethod from typing import Any @@ -54,9 +53,7 @@ def train( mean_label = labels.mean(dim=1, keepdim=True) scaled_activations = activations - mean_activation scaled_labels = labels - mean_label - cav = einops.einsum( - scaled_activations, scaled_labels, "b a, b d -> a d" - ) + cav = einops.einsum(scaled_activations, scaled_labels, "b a, b d -> a d") self._h = cav / (cav.norm(dim=0, keepdim=True) + EPS) def predict(self, activations: torch.Tensor, **kwargs): diff --git a/tests/conftest.py b/tests/conftest.py index 811b556..40218b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ from lczerolens import GameDataset, ModelWrapper from lczerolens._native_builder import NativeBuilder -from lczerolens.utils import lczero as lczero_utils +from lczerolens.model import lczero as lczero_utils @pytest.fixture(scope="session") @@ -19,9 +19,7 @@ def tiny_lczero_backend(): @pytest.fixture(scope="session") def tiny_ensure_network(): - lczero_utils.convert_to_onnx( - "assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx" - ) + lczero_utils.convert_to_onnx("assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx") yield @@ -45,9 +43,7 @@ def tiny_senet_ort(tiny_ensure_network): @pytest.fixture(scope="class") def maia_ensure_network(): - lczero_utils.convert_to_onnx( - "assets/maia-1100.pb.gz", "assets/maia-1100.onnx" - ) + lczero_utils.convert_to_onnx("assets/maia-1100.pb.gz", "assets/maia-1100.onnx") yield @@ -80,26 +76,17 @@ def winner_ensure_network(): @pytest.fixture(scope="class") def winner_wrapper(winner_ensure_network): - wrapper = ModelWrapper.from_path( - "assets/384x30-2022_0108_1903_17_608.onnx" - ) - yield wrapper + yield ModelWrapper.from_path("assets/384x30-2022_0108_1903_17_608.onnx") @pytest.fixture(scope="class") def winner_senet(winner_ensure_network): - senet = NativeBuilder.build_from_path( - "assets/384x30-2022_0108_1903_17_608.onnx" - ) - yield senet + yield NativeBuilder.build_from_path("assets/384x30-2022_0108_1903_17_608.onnx") @pytest.fixture(scope="class") def winner_senet_ort(winner_ensure_network): - senet_ort = ort.InferenceSession( - "assets/384x30-2022_0108_1903_17_608.onnx" - ) - yield senet_ort + yield ort.InferenceSession("assets/384x30-2022_0108_1903_17_608.onnx") @pytest.fixture(scope="session") diff --git a/tests/unit/_native_builder/test_senet.py b/tests/unit/_native_builder/test_senet.py index da3b17e..bcc871f 100644 --- a/tests/unit/_native_builder/test_senet.py +++ b/tests/unit/_native_builder/test_senet.py @@ -5,7 +5,7 @@ import chess import torch -from lczerolens import board_utils +from lczerolens import board_encodings class TestTinySeNet: @@ -15,97 +15,69 @@ def test_senet_prediction(self, tiny_senet_ort, tiny_senet): """ board = chess.Board() - out = tiny_senet(board_utils.board_to_input_tensor(board).unsqueeze(0)) + out = tiny_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] value = out["value"] onnx_policy, onnx_value = tiny_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_value = torch.tensor(onnx_value) assert torch.allclose(policy, onnx_policy, atol=1e-4) assert torch.allclose(value, onnx_value, atol=1e-4) - def test_senet_prediction_random( - self, tiny_senet_ort, tiny_senet, random_move_board_list - ): + def 
test_senet_prediction_random(self, tiny_senet_ort, tiny_senet, random_move_board_list): """ Test that the wrapper prediction works. """ move_list, board_list = random_move_board_list for i, board in enumerate(board_list): - out = tiny_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = tiny_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] value = out["value"] onnx_policy, onnx_value = tiny_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_value = torch.tensor(onnx_value) assert torch.allclose(policy, onnx_policy, atol=1e-4) assert torch.allclose(value, onnx_value, atol=1e-4) - def test_senet_prediction_repetition( - self, tiny_senet_ort, tiny_senet, repetition_move_board_list - ): + def test_senet_prediction_repetition(self, tiny_senet_ort, tiny_senet, repetition_move_board_list): """ Test that the wrapper prediction works. """ move_list, board_list = repetition_move_board_list for i, board in enumerate(board_list): - out = tiny_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = tiny_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] value = out["value"] onnx_policy, onnx_value = tiny_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_value = torch.tensor(onnx_value) assert torch.allclose(policy, onnx_policy, atol=1e-4) assert torch.allclose(value, onnx_value, atol=1e-4) - def test_senet_prediction_long( - self, tiny_senet_ort, tiny_senet, long_move_board_list - ): + def test_senet_prediction_long(self, tiny_senet_ort, tiny_senet, long_move_board_list): """ Test that the wrapper prediction works. 
""" move_list, board_list = long_move_board_list for i, board in enumerate(board_list): - out = tiny_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = tiny_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] value = out["value"] onnx_policy, onnx_value = tiny_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_value = torch.tensor(onnx_value) @@ -120,97 +92,69 @@ def test_senet_prediction(self, maia_senet_ort, maia_senet): """ board = chess.Board() - out = maia_senet(board_utils.board_to_input_tensor(board).unsqueeze(0)) + out = maia_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] onnx_policy, onnx_wdl = maia_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) assert torch.allclose(policy, onnx_policy, atol=1e-4) assert torch.allclose(wdl, onnx_wdl, atol=1e-4) - def test_senet_prediction_random( - self, maia_senet_ort, maia_senet, random_move_board_list - ): + def test_senet_prediction_random(self, maia_senet_ort, maia_senet, random_move_board_list): """ Test that the wrapper prediction works. """ move_list, board_list = random_move_board_list for i, board in enumerate(board_list): - out = maia_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = maia_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] onnx_policy, onnx_wdl = maia_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) assert torch.allclose(policy, onnx_policy, atol=1e-4) assert torch.allclose(wdl, onnx_wdl, atol=1e-4) - def test_senet_prediction_repetition( - self, maia_senet_ort, maia_senet, repetition_move_board_list - ): + def test_senet_prediction_repetition(self, maia_senet_ort, maia_senet, repetition_move_board_list): """ Test that the wrapper prediction works. """ move_list, board_list = repetition_move_board_list for i, board in enumerate(board_list): - out = maia_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = maia_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] onnx_policy, onnx_wdl = maia_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) assert torch.allclose(policy, onnx_policy, atol=1e-4) assert torch.allclose(wdl, onnx_wdl, atol=1e-4) - def test_senet_prediction_long( - self, maia_senet_ort, maia_senet, long_move_board_list - ): + def test_senet_prediction_long(self, maia_senet_ort, maia_senet, long_move_board_list): """ Test that the wrapper prediction works. 
""" move_list, board_list = long_move_board_list for i, board in enumerate(board_list): - out = maia_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = maia_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] onnx_policy, onnx_wdl = maia_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) @@ -225,19 +169,13 @@ def test_senet_prediction(self, winner_senet_ort, winner_senet): """ board = chess.Board() - out = winner_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = winner_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] mlh = out["mlh"] onnx_policy, onnx_wdl, onnx_mlh = winner_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) @@ -246,28 +184,20 @@ def test_senet_prediction(self, winner_senet_ort, winner_senet): assert torch.allclose(wdl, onnx_wdl, atol=1e-4) assert torch.allclose(mlh, onnx_mlh, atol=1e-4) - def test_senet_prediction_random( - self, winner_senet_ort, winner_senet, random_move_board_list - ): + def test_senet_prediction_random(self, winner_senet_ort, winner_senet, random_move_board_list): """ Test that the wrapper prediction works. """ move_list, board_list = random_move_board_list for i, board in enumerate(board_list): - out = winner_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = winner_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] mlh = out["mlh"] onnx_policy, onnx_wdl, onnx_mlh = winner_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) @@ -276,28 +206,20 @@ def test_senet_prediction_random( assert torch.allclose(wdl, onnx_wdl, atol=1e-4) assert torch.allclose(mlh, onnx_mlh, atol=1e-4) - def test_senet_prediction_repetition( - self, winner_senet_ort, winner_senet, repetition_move_board_list - ): + def test_senet_prediction_repetition(self, winner_senet_ort, winner_senet, repetition_move_board_list): """ Test that the wrapper prediction works. 
""" move_list, board_list = repetition_move_board_list for i, board in enumerate(board_list): - out = winner_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = winner_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] mlh = out["mlh"] onnx_policy, onnx_wdl, onnx_mlh = winner_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) @@ -306,28 +228,20 @@ def test_senet_prediction_repetition( assert torch.allclose(wdl, onnx_wdl, atol=1e-4) assert torch.allclose(mlh, onnx_mlh, atol=1e-4) - def test_senet_prediction_long( - self, winner_senet_ort, winner_senet, long_move_board_list - ): + def test_senet_prediction_long(self, winner_senet_ort, winner_senet, long_move_board_list): """ Test that the wrapper prediction works. """ move_list, board_list = long_move_board_list for i, board in enumerate(board_list): - out = winner_senet( - board_utils.board_to_input_tensor(board).unsqueeze(0) - ) + out = winner_senet(board_encodings.board_to_input_tensor(board).unsqueeze(0)) policy = out["policy"] wdl = out["wdl"] mlh = out["mlh"] onnx_policy, onnx_wdl, onnx_mlh = winner_senet_ort.run( None, - { - "/input/planes": board_utils.board_to_input_tensor(board) - .unsqueeze(0) - .numpy() - }, + {"/input/planes": board_encodings.board_to_input_tensor(board).unsqueeze(0).numpy()}, ) onnx_policy = torch.tensor(onnx_policy) onnx_wdl = torch.tensor(onnx_wdl) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5178ced..540e225 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -32,9 +32,7 @@ def repetition_move_board_list(): move = chess.Move.from_uci(uci_move) move_list.append(move) board.push(move) - board_list.append( - board.copy(stack=True) - ) # Full stack is needed for repetition detection + board_list.append(board.copy(stack=True)) # Full stack is needed for repetition detection return move_list, board_list diff --git a/tests/unit/game/test_play.py b/tests/unit/game/test_play.py new file mode 100644 index 0000000..e094104 --- /dev/null +++ b/tests/unit/game/test_play.py @@ -0,0 +1,64 @@ +"""Wrapper tests.""" + +import chess + +from lczerolens.game import WrapperSampler, SelfPlay, PolicySampler, BatchedPolicySampler + + +class TestWrapperSampler: + def test_get_utility_tiny(self, tiny_wrapper): + """Test get_utility method.""" + board = chess.Board() + sampler = WrapperSampler(wrapper=tiny_wrapper) + utility, _, _ = sampler.get_utility(board) + assert utility.shape[0] == 20 + + def test_get_utility_winner(self, winner_wrapper): + """Test get_utility method.""" + board = chess.Board() + sampler = WrapperSampler(wrapper=winner_wrapper) + utility, _, _ = sampler.get_utility(board) + assert utility.shape[0] == 20 + + def test_policy_sampler_tiny(self, tiny_wrapper): + """Test policy_sampler method.""" + board = chess.Board() + sampler = PolicySampler(wrapper=tiny_wrapper) + utility, _, _ = sampler.get_utility(board) + assert utility.shape[0] == 20 + + +class TestSelfPlay: + def test_play(self, tiny_wrapper, winner_wrapper): + """Test play method.""" + board = chess.Board() + white = WrapperSampler(wrapper=tiny_wrapper) + black = WrapperSampler(wrapper=winner_wrapper) + self_play = SelfPlay(white=white, black=black) + logs = [] + + def report_fn(log, to_play): + logs.append((log, 
to_play))
+
+        game, board = self_play.play(board=board, max_moves=10, report_fn=report_fn)
+
+        assert len(game) == len(logs) == 10
+
+
+class TestBatchedPolicySampler:
+    def test_batched_policy_sampler_ag(self, tiny_wrapper):
+        """Test batched_policy_sampler method."""
+        boards = [chess.Board() for _ in range(10)]
+
+        sampler_ag = BatchedPolicySampler(wrapper=tiny_wrapper, use_argmax=True)
+        moves = list(sampler_ag.get_next_moves(boards))
+        assert len(moves) == 10
+        assert all(move == moves[0] for move in moves)
+
+    def test_batched_policy_sampler_no_ag(self, tiny_wrapper):
+        """Test batched_policy_sampler method."""
+        boards = [chess.Board() for _ in range(10)]
+
+        sampler_no_ag = BatchedPolicySampler(wrapper=tiny_wrapper, use_argmax=False)
+        moves = sampler_no_ag.get_next_moves(boards)
+        assert len(list(moves)) == 10
diff --git a/tests/unit/game/test_wrapper.py b/tests/unit/game/test_wrapper.py
index 467e032..e495cca 100644
--- a/tests/unit/game/test_wrapper.py
+++ b/tests/unit/game/test_wrapper.py
@@ -1,13 +1,12 @@
-"""Wrapper tests.
-"""
+"""Wrapper tests."""

 import chess
 import pytest
 import torch
 from lczero.backends import GameState

-from lczerolens.game import MlhFlow, PolicyFlow, ValueFlow, WdlFlow
-from lczerolens.utils import lczero as lczero_utils
+from lczerolens.model import MlhFlow, PolicyFlow, ValueFlow, WdlFlow
+from lczerolens.model import lczero as lczero_utils


 class TestWrapper:
@@ -22,63 +21,43 @@ def test_wrapper_prediction(self, tiny_lczero_backend, tiny_wrapper):
         policy = out["policy"]
         value = out["value"]
         lczero_game = GameState()
-        lczero_policy, lczero_value = lczero_utils.prediction_from_backend(
-            tiny_lczero_backend, lczero_game
-        )
+        lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
         assert torch.allclose(policy, lczero_policy, atol=1e-4)
         assert torch.allclose(value, lczero_value, atol=1e-4)

-    def test_wrapper_prediction_random(
-        self, tiny_lczero_backend, tiny_wrapper, random_move_board_list
-    ):
+    def test_wrapper_prediction_random(self, tiny_lczero_backend, tiny_wrapper, random_move_board_list):
         """Test that the wrapper prediction works."""
         move_list, board_list = random_move_board_list
         for i, board in enumerate(board_list):
             (out,) = tiny_wrapper.predict(board)
             policy = out["policy"]
             value = out["value"]
-            lczero_game = GameState(
-                moves=[move.uci() for move in move_list[:i]]
-            )
-            lczero_policy, lczero_value = lczero_utils.prediction_from_backend(
-                tiny_lczero_backend, lczero_game
-            )
+            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
+            lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
             assert torch.allclose(policy, lczero_policy, atol=1e-4)
             assert torch.allclose(value, lczero_value, atol=1e-4)

-    def test_wrapper_prediction_repetition(
-        self, tiny_lczero_backend, tiny_wrapper, repetition_move_board_list
-    ):
+    def test_wrapper_prediction_repetition(self, tiny_lczero_backend, tiny_wrapper, repetition_move_board_list):
         """Test that the wrapper prediction works."""
         move_list, board_list = repetition_move_board_list
         for i, board in enumerate(board_list):
             (out,) = tiny_wrapper.predict(board)
             policy = out["policy"]
             value = out["value"]
-            lczero_game = GameState(
-                moves=[move.uci() for move in move_list[:i]]
-            )
-            lczero_policy, lczero_value = lczero_utils.prediction_from_backend(
-                tiny_lczero_backend, lczero_game
-            )
+            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
+            lczero_policy, lczero_value =
lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) assert torch.allclose(policy, lczero_policy, atol=1e-4) assert torch.allclose(value, lczero_value, atol=1e-4) - def test_wrapper_prediction_long( - self, tiny_lczero_backend, tiny_wrapper, long_move_board_list - ): + def test_wrapper_prediction_long(self, tiny_lczero_backend, tiny_wrapper, long_move_board_list): """Test that the wrapper prediction works.""" move_list, board_list = long_move_board_list for i, board in enumerate(board_list): (out,) = tiny_wrapper.predict(board) policy = out["policy"] value = out["value"] - lczero_game = GameState( - moves=[move.uci() for move in move_list[:i]] - ) - lczero_policy, lczero_value = lczero_utils.prediction_from_backend( - tiny_lczero_backend, lczero_game - ) + lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) + lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) assert torch.allclose(policy, lczero_policy, atol=1e-4) assert torch.allclose(value, lczero_value, atol=1e-4) diff --git a/tests/unit/utils/test_board.py b/tests/unit/utils/test_board.py index 42c558e..07e38b4 100644 --- a/tests/unit/utils/test_board.py +++ b/tests/unit/utils/test_board.py @@ -4,111 +4,85 @@ from lczero.backends import GameState -from lczerolens import board_utils -from lczerolens.utils import lczero as lczero_utils +from lczerolens import board_encodings +from lczerolens.model import lczero as lczero_utils class TestWithBackend: - def test_board_to_config_tensor( - self, random_move_board_list, tiny_lczero_backend - ): + def test_board_to_config_tensor(self, random_move_board_list, tiny_lczero_backend): """ Test that the board to tensor function works. """ move_list, board_list = random_move_board_list for i, board in enumerate(board_list): - board_tensor = board_utils.board_to_config_tensor(board) + board_tensor = board_encodings.board_to_config_tensor(board) uci_moves = [move.uci() for move in move_list[:i]] lczero_game = GameState(moves=uci_moves) - lczero_input_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game, planes=13 - ) + lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13) assert (board_tensor == lczero_input_tensor[:13]).all() - def test_board_to_input_tensor( - self, random_move_board_list, tiny_lczero_backend - ): + def test_board_to_input_tensor(self, random_move_board_list, tiny_lczero_backend): """ Test that the board to tensor function works. """ move_list, board_list = random_move_board_list for i, board in enumerate(board_list): - board_tensor = board_utils.board_to_input_tensor(board) + board_tensor = board_encodings.board_to_input_tensor(board) uci_moves = [move.uci() for move in move_list[:i]] lczero_game = GameState(moves=uci_moves) - lczero_input_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game - ) + lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) # assert (board_tensor == lczero_input_tensor).all() for plane in range(112): - assert ( - board_tensor[plane] == lczero_input_tensor[plane] - ).all() + assert (board_tensor[plane] == lczero_input_tensor[plane]).all() class TestRepetition: - def test_board_to_config_tensor( - self, repetition_move_board_list, tiny_lczero_backend - ): + def test_board_to_config_tensor(self, repetition_move_board_list, tiny_lczero_backend): """ Test that the board to tensor function works. 
""" move_list, board_list = repetition_move_board_list for i, board in enumerate(board_list): uci_moves = [move.uci() for move in move_list[:i]] - board_tensor = board_utils.board_to_config_tensor(board) + board_tensor = board_encodings.board_to_config_tensor(board) lczero_game = GameState(moves=uci_moves) - lczero_input_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game, planes=13 - ) + lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13) assert (board_tensor == lczero_input_tensor[:13]).all() - def test_board_to_input_tensor( - self, repetition_move_board_list, tiny_lczero_backend - ): + def test_board_to_input_tensor(self, repetition_move_board_list, tiny_lczero_backend): """ Test that the board to tensor function works. """ move_list, board_list = repetition_move_board_list for i, board in enumerate(board_list): uci_moves = [move.uci() for move in move_list[:i]] - board_tensor = board_utils.board_to_input_tensor(board) + board_tensor = board_encodings.board_to_input_tensor(board) lczero_game = GameState(moves=uci_moves) - lczero_input_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game - ) + lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) assert (board_tensor == lczero_input_tensor).all() class TestLong: - def test_board_to_config_tensor( - self, long_move_board_list, tiny_lczero_backend - ): + def test_board_to_config_tensor(self, long_move_board_list, tiny_lczero_backend): """ Test that the board to tensor function works. """ move_list, board_list = long_move_board_list for i, board in enumerate(board_list): uci_moves = [move.uci() for move in move_list[:i]] - board_tensor = board_utils.board_to_config_tensor(board) + board_tensor = board_encodings.board_to_config_tensor(board) lczero_game = GameState(moves=uci_moves) - lczero_input_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game, planes=13 - ) + lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13) assert (board_tensor == lczero_input_tensor[:13]).all() - def test_board_to_input_tensor( - self, long_move_board_list, tiny_lczero_backend - ): + def test_board_to_input_tensor(self, long_move_board_list, tiny_lczero_backend): """ Test that the board to tensor function works. """ move_list, board_list = long_move_board_list for i, board in enumerate(board_list): uci_moves = [move.uci() for move in move_list[:i]] - board_tensor = board_utils.board_to_input_tensor(board) + board_tensor = board_encodings.board_to_input_tensor(board) lczero_game = GameState(moves=uci_moves) - lczero_input_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game - ) + lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) assert (board_tensor == lczero_input_tensor).all() diff --git a/tests/unit/utils/test_lczero.py b/tests/unit/utils/test_lczero.py index 38f6453..36a51a1 100644 --- a/tests/unit/utils/test_lczero.py +++ b/tests/unit/utils/test_lczero.py @@ -5,7 +5,7 @@ import torch from lczero.backends import GameState -from lczerolens.utils import lczero as lczero_utils +from lczerolens.model import lczero as lczero_utils class TestExecution: @@ -21,9 +21,7 @@ def test_convertnet(self): """ Test that the convert_to_onnx function works. 
""" - conversion = lczero_utils.convert_to_onnx( - "assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx" - ) + conversion = lczero_utils.convert_to_onnx("assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx") assert isinstance(conversion, str) assert "INPUT_CLASSICAL_112_PLANE" in conversion @@ -40,9 +38,7 @@ def test_board_from_backend(self, tiny_lczero_backend): Test that the board from backend function works. """ lczero_game = GameState() - lczero_board_tensor = lczero_utils.board_from_backend( - tiny_lczero_backend, lczero_game - ) + lczero_board_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) assert lczero_board_tensor.shape == (112, 8, 8) def test_prediction_from_backend(self, tiny_lczero_backend): @@ -50,18 +46,10 @@ def test_prediction_from_backend(self, tiny_lczero_backend): Test that the prediction from backend function works. """ lczero_game = GameState() - lczero_policy, lczero_value = lczero_utils.prediction_from_backend( - tiny_lczero_backend, lczero_game - ) + lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) assert lczero_policy.shape == (1858,) assert (lczero_value >= -1) and (lczero_value <= 1) - lczero_policy_softmax, _ = lczero_utils.prediction_from_backend( - tiny_lczero_backend, lczero_game, softmax=True - ) + lczero_policy_softmax, _ = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game, softmax=True) assert lczero_policy_softmax.shape == (1858,) - assert (lczero_policy_softmax >= 0).all() and ( - lczero_policy_softmax <= 1 - ).all() - assert torch.softmax(lczero_policy, dim=0).allclose( - lczero_policy_softmax, atol=1e-4 - ) + assert (lczero_policy_softmax >= 0).all() and (lczero_policy_softmax <= 1).all() + assert torch.softmax(lczero_policy, dim=0).allclose(lczero_policy_softmax, atol=1e-4) diff --git a/tests/unit/utils/test_move.py b/tests/unit/utils/test_move.py index dfa5d44..94a932f 100644 --- a/tests/unit/utils/test_move.py +++ b/tests/unit/utils/test_move.py @@ -5,8 +5,8 @@ import chess from lczero.backends import GameState -from lczerolens import move_utils -from lczerolens.utils import lczero as lczero_utils +from lczerolens import move_encodings +from lczerolens.model import lczero as lczero_utils class TestStability: @@ -16,10 +16,8 @@ def test_encode_decode(self, random_move_board_list): """ us, them = chess.WHITE, chess.BLACK for move, board in zip(*random_move_board_list): - encoded_move = move_utils.encode_move(move, (us, them)) - decoded_move = move_utils.decode_move( - encoded_move, (us, them), board - ) + encoded_move = move_encodings.encode_move(move, (us, them)) + decoded_move = move_encodings.decode_move(encoded_move, (us, them), board) assert move == decoded_move us, them = them, us @@ -31,9 +29,7 @@ def test_encode_decode_random(self, random_move_board_list): """ move_list, board_list = random_move_board_list for i, board in enumerate(board_list): - lczero_game = GameState( - moves=[move.uci() for move in move_list[:i]] - ) + lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) legal_moves = [move.uci() for move in board.legal_moves] ( lczero_legal_moves, @@ -42,8 +38,7 @@ def test_encode_decode_random(self, random_move_board_list): assert len(legal_moves) == len(lczero_legal_moves) assert set(legal_moves) == set(lczero_legal_moves) policy_indices = [ - move_utils.encode_move(move, (board.turn, not board.turn)) - for move in board.legal_moves + move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves 
] assert len(lczero_policy_indices) == len(policy_indices) assert set(lczero_policy_indices) == set(policy_indices) @@ -54,9 +49,7 @@ def test_encode_decode_long(self, long_move_board_list): """ move_list, board_list = long_move_board_list for i, board in enumerate(board_list): - lczero_game = GameState( - moves=[move.uci() for move in move_list[:i]] - ) + lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) legal_moves = [move.uci() for move in board.legal_moves] ( lczero_legal_moves, @@ -65,8 +58,7 @@ def test_encode_decode_long(self, long_move_board_list): assert len(legal_moves) == len(lczero_legal_moves) assert set(legal_moves) == set(lczero_legal_moves) policy_indices = [ - move_utils.encode_move(move, (board.turn, not board.turn)) - for move in board.legal_moves + move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves ] assert len(lczero_policy_indices) == len(policy_indices) assert set(lczero_policy_indices) == set(policy_indices) diff --git a/tests/unit/xai/test_activation.py b/tests/unit/xai/test_activation.py index 62550e1..a5fe97a 100644 --- a/tests/unit/xai/test_activation.py +++ b/tests/unit/xai/test_activation.py @@ -1,5 +1,4 @@ -"""Activation lens tests. -""" +"""Activation lens tests.""" from lczerolens import Lens from lczerolens.xai import ActivationLens diff --git a/tests/unit/xai/test_concept.py b/tests/unit/xai/test_concept.py index e9b941c..c23ee15 100644 --- a/tests/unit/xai/test_concept.py +++ b/tests/unit/xai/test_concept.py @@ -35,38 +35,18 @@ def test_compute_label(self): Test the compute_label method. """ concept = AndBinaryConcept(HasPieceConcept("p"), HasPieceConcept("n")) - assert ( - concept.compute_label(chess.Board("8/8/8/8/8/8/8/8 w - - 0 1")) - == 0 - ) - assert ( - concept.compute_label(chess.Board("8/p7/8/8/8/8/8/8 w - - 0 1")) - == 0 - ) - assert ( - concept.compute_label(chess.Board("8/pn6/8/8/8/8/8/8 w - - 0 1")) - == 1 - ) + assert concept.compute_label(chess.Board("8/8/8/8/8/8/8/8 w - - 0 1")) == 0 + assert concept.compute_label(chess.Board("8/p7/8/8/8/8/8/8 w - - 0 1")) == 0 + assert concept.compute_label(chess.Board("8/pn6/8/8/8/8/8/8 w - - 0 1")) == 1 def test_relative_threat(self): """ Test the relative threat concept. """ - concept = HasThreatConcept( - "p", relative=True - ) # Is an enemy pawn threatened? - assert ( - concept.compute_label(chess.Board("8/8/8/8/8/8/8/8 w - - 0 1")) - == 0 - ) - assert ( - concept.compute_label(chess.Board("R7/8/8/8/8/8/p7/8 w - - 0 1")) - == 1 - ) - assert ( - concept.compute_label(chess.Board("R7/8/8/8/8/8/p7/8 b - - 0 1")) - == 0 - ) + concept = HasThreatConcept("p", relative=True) # Is an enemy pawn threatened? + assert concept.compute_label(chess.Board("8/8/8/8/8/8/8/8 w - - 0 1")) == 0 + assert concept.compute_label(chess.Board("R7/8/8/8/8/8/p7/8 w - - 0 1")) == 1 + assert concept.compute_label(chess.Board("R7/8/8/8/8/8/p7/8 b - - 0 1")) == 0 class TestDataset: diff --git a/tests/unit/xai/test_crp.py b/tests/unit/xai/test_crp.py index 2322557..dbb3b0b 100644 --- a/tests/unit/xai/test_crp.py +++ b/tests/unit/xai/test_crp.py @@ -1,5 +1,4 @@ -"""CRP lens tests. -""" +"""CRP lens tests.""" from lczerolens import Lens from lczerolens.xai import CrpLens diff --git a/tests/unit/xai/test_lrp.py b/tests/unit/xai/test_lrp.py index fd72d25..f58ef6f 100644 --- a/tests/unit/xai/test_lrp.py +++ b/tests/unit/xai/test_lrp.py @@ -1,5 +1,4 @@ -"""LRP lens tests. -""" +"""LRP lens tests.""" import chess import torch