Skip to content

Commit

Permalink
Diverse fixes (#14)
Browse files Browse the repository at this point in the history
* fix on device training

* fixed val metrics

* fixed parse bool

* new training parameters

* copy typo

* best param and new eval

* resampling error

* resampling typo

* resampling typo

* bias typo

* device typo

* table format

* np hist

* bug hist

* new defaults

* debug

* fixed hist

* tests activated features

* bug test

* debug

* bigger datasets

* no reload

* to cpu

* new asset link
  • Loading branch information
Xmaster6y authored Mar 7, 2024
1 parent e0e447a commit 83742dd
Show file tree
Hide file tree
Showing 13 changed files with 522 additions and 317 deletions.
18 changes: 18 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,23 @@
"console": "integratedTerminal",
"justMyCode": false
}
,
{
"name": "Script simple sae",
"type": "debugpy",
"request": "launch",
"module": "scripts.simple_sae",
"console": "integratedTerminal",
"justMyCode": false
}
,
{
"name": "Debug",
"type": "debugpy",
"request": "launch",
"module": "ignored.debug",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
2 changes: 1 addition & 1 deletion assets/resolve-test-assets.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ poetry run gdown 1Ssl4JanqzQn3p-RoHRDk_aApykl-SukE -O assets/tinygyal-8.pb.gz
poetry run gdown 1WzBQV_zn5NnfsG0K8kOion0pvWxXhgKM -O assets/384x30-2022_0108_1903_17_608.pb.gz
poetry run gdown 1erxB3tULDURjpPhiPWVGr6X986Q8uE6U -O assets/maia-1100.pb.gz
poetry run gdown 1YqqANK-wuZIOmMweuK_oCU7vfPN7G_Z6 -O assets/t1-smolgen-512x15x8h-distilled-swa-3395000.pb.gz
poetry run gdown 1Gw40ENElDrpqx9p0uuAnO0jieX1Clw89 -O assets/test_stockfish_10.jsonl
poetry run gdown 15-eGN7Hz2NM6aEMRaQrbW3ScxxQpAqa5 -O assets/test_stockfish_10.jsonl
199 changes: 87 additions & 112 deletions demo/attention_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@

from demo import constants, utils, visualisation

cache = None
boards = None
board_index = 0


def list_models():
"""
Expand All @@ -38,89 +34,94 @@ def compute_cache(
attention_layer,
attention_head,
square,
quantity,
func,
trick,
aggregate,
state_board_index,
state_boards,
state_cache,
):
global cache
global boards
if model_name == "":
gr.Warning("No model selected.")
return None, None, None
return None, None, None, state_boards, state_cache

try:
board = chess.Board(board_fen)
except ValueError:
board = chess.Board()
gr.Warning("Invalid FEN, using starting position.")
boards = [board.copy()]
state_boards = [board.copy()]
if action_seq:
try:
if action_seq.startswith("1."):
for action in action_seq.split():
if action.endswith("."):
continue
board.push_san(action)
boards.append(board.copy())
state_boards.append(board.copy())
else:
for action in action_seq.split():
board.push_uci(action)
boards.append(board.copy())
state_boards.append(board.copy())
except ValueError:
gr.Warning(f"Invalid action {action} stopping before it.")
try:
wrapper, lens = utils.get_wrapper_lens_from_state(
model_name, "attention"
model_name,
"activation",
lens_name="attention",
module_exp=r"encoder\d+/mha/QK/softmax",
)
except ValueError:
gr.Warning("Could not load model.")
return None, None, None
cache = []
for board in boards:
attention_cache = copy.deepcopy(lens.compute_heatmap(board, wrapper))
cache.append(attention_cache)
return make_plot(
attention_layer,
attention_head,
square,
quantity,
func,
trick,
aggregate,
return None, None, None, state_boards, state_cache
state_cache = []
for board in state_boards:
attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper))
state_cache.append(attention_cache)
return (
*make_plot(
attention_layer,
attention_head,
square,
state_board_index,
state_boards,
state_cache,
),
state_boards,
state_cache,
)


def make_plot(
attention_layer, attention_head, square, quantity, func, trick, aggregate
attention_layer,
attention_head,
square,
state_board_index,
state_boards,
state_cache,
):
global cache
global boards
global board_index

if cache is None:
gr.Warning("Cache not computed!")
return None, None
if state_cache == []:
gr.Warning("No cache available.")
return None, None, None

board = boards[board_index]
num_attention_layers = len(cache[board_index])
board = state_boards[state_board_index]
num_attention_layers = len(state_cache[state_board_index])
if attention_layer > num_attention_layers:
gr.Warning(
f"Attention layer {attention_layer} does not exist, "
f"using layer {num_attention_layers} instead."
)
attention_layer = num_attention_layers

key = f"{attention_layer-1}-{quantity}-{func}"
key = f"encoder{attention_layer-1}/mha/QK/softmax"
try:
attention_tensor = cache[board_index][key]
attention_tensor = state_cache[state_board_index][key]
except KeyError:
gr.Warning(f"Combination {key} does not exist.")
return None, None, None
if attention_head > attention_tensor.shape[1]:
gr.Warning(
f"Attention head {attention_head} does not exist, "
f"using head {attention_tensor.shape[1]} instead."
f"using head {attention_tensor.shape[1]+1} instead."
)
attention_head = attention_tensor.shape[1]
try:
Expand All @@ -132,15 +133,7 @@ def make_plot(
if board.turn == chess.BLACK:
square_index = chess.square_mirror(square_index)

if trick == "revert":
square_index = 63 - square_index

if aggregate == "Row":
heatmap = attention_tensor[0, attention_head - 1, square_index, :]
elif aggregate == "Column":
heatmap = attention_tensor[0, attention_head - 1, :, square_index]
else:
heatmap = attention_tensor[0, attention_head - 1]
heatmap = attention_tensor[0, attention_head - 1, square_index]
if board.turn == chess.BLACK:
heatmap = heatmap.view(8, 8).flip(0).view(64)
svg_board, fig = visualisation.render_heatmap(
Expand All @@ -155,37 +148,49 @@ def previous_board(
attention_layer,
attention_head,
square,
from_to,
color_flip,
trick,
aggregate,
state_board_index,
state_boards,
state_cache,
):
global board_index
board_index -= 1
if board_index < 0:
state_board_index -= 1
if state_board_index < 0:
gr.Warning("Already at first board.")
board_index = 0
return make_plot(
attention_layer, attention_head, square, from_to, color_flip
state_board_index = 0
return (
*make_plot(
attention_layer,
attention_head,
square,
state_board_index,
state_boards,
state_cache,
),
state_board_index,
)


def next_board(
attention_layer,
attention_head,
square,
from_to,
color_flip,
trick,
aggregate,
state_board_index,
state_boards,
state_cache,
):
global board_index
board_index += 1
if board_index >= len(boards):
state_board_index += 1
if state_board_index >= len(state_boards):
gr.Warning("Already at last board.")
board_index = len(boards) - 1
return make_plot(
attention_layer, attention_head, square, from_to, color_flip
state_board_index = len(state_boards) - 1
return (
*make_plot(
attention_layer,
attention_head,
square,
state_board_index,
state_boards,
state_cache,
),
state_board_index,
)


Expand Down Expand Up @@ -254,38 +259,6 @@ def next_board(
value="a1",
scale=1,
)
quantity = gr.Dropdown(
label="Quantity",
choices=["QK", "Q", "K", "out", "QKV"],
value="QK",
scale=2,
)
aggregate = gr.Dropdown(
label="Aggregate",
choices=["Row", "Column", "None"],
value="Row",
scale=2,
)
func = gr.Dropdown(
label="Function",
choices=[
"softmax",
"transpose",
"matmul",
"scale",
],
value="softmax",
scale=2,
)
trick = gr.Dropdown(
label="Trick",
choices=[
"none",
"revert",
],
value="none",
scale=2,
)
with gr.Row():
previous_board_button = gr.Button("Previous board")
next_board_button = gr.Button("Next board")
Expand All @@ -298,32 +271,34 @@ def next_board(
with gr.Column():
image = gr.Image(label="Board")

state_board_index = gr.State(0)
state_boards = gr.State([])
state_cache = gr.State([])
base_inputs = [
attention_layer,
attention_head,
square,
quantity,
func,
trick,
aggregate,
state_board_index,
state_boards,
state_cache,
]
outputs = [image, current_board_fen, colorbar]

compute_cache_button.click(
compute_cache,
inputs=[board_fen, action_seq, model_name] + base_inputs,
outputs=outputs,
outputs=outputs + [state_boards, state_cache],
)

previous_board_button.click(
previous_board, inputs=base_inputs, outputs=outputs
previous_board,
inputs=base_inputs,
outputs=outputs + [state_board_index],
)
next_board_button.click(
next_board, inputs=base_inputs, outputs=outputs + [state_board_index]
)
next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)

attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
square.submit(make_plot, inputs=base_inputs, outputs=outputs)
quantity.change(make_plot, inputs=base_inputs, outputs=outputs)
func.change(make_plot, inputs=base_inputs, outputs=outputs)
trick.change(make_plot, inputs=base_inputs, outputs=outputs)
aggregate.change(make_plot, inputs=base_inputs, outputs=outputs)
3 changes: 2 additions & 1 deletion demo/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
wrappers: Dict[str, ModelWrapper] = {}

lenses: Dict[str, Dict[str, Lens]] = {
"attention": {},
"activation": {},
"lrp": {},
"crp": {},
"policy": {},
"probing": {},
"patching": {},
}
2 changes: 1 addition & 1 deletion demo/statistics_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
dataset = GameDataset("assets/test_stockfish_10.jsonl")
check_concept = HasThreatConcept("K", relative=True)
unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
unique_check_dataset.concept = check_concept
unique_check_dataset.set_concept(check_concept)


def list_models():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ src_paths = ["src", "tests", "scripts", "docs", "demo"]

[tool.poetry]
name = "lczerolens"
version = "0.1.2"
version = "0.1.3"
description = "Interpretability for LeelaChessZero networks."
readme = "README.md"
license = "MIT"
Expand Down
Loading

0 comments on commit 83742dd

Please sign in to comment.