From 34603f02d32625e6fa4ad42516a50d73c7978afd Mon Sep 17 00:00:00 2001 From: ojaffe Date: Fri, 15 Mar 2024 10:54:10 +0000 Subject: [PATCH] add CDTA eval --- evals/elsuite/cant_do_that_anymore/README.md | 69 ++++ .../cant_do_that_anymore/chess/board.py | 244 ++++++++++++++ .../cant_do_that_anymore/chess/board_test.py | 95 ++++++ .../chess/move_variants.py | 120 +++++++ .../cant_do_that_anymore/chess/notation.py | 106 ++++++ .../cant_do_that_anymore/chess/pieces.py | 263 +++++++++++++++ .../cant_do_that_anymore/chess/utils.py | 107 ++++++ .../elsuite/cant_do_that_anymore/defaults.py | 15 + evals/elsuite/cant_do_that_anymore/eval.py | 201 +++++++++++ .../scripts/dataset_creation.py | 312 +++++++++++++++++ .../scripts/diagonal_dataset_creation.py | 316 ++++++++++++++++++ .../scripts/make_plots.py | 128 +++++++ .../scripts/run_experiments.sh | 67 ++++ evals/elsuite/cant_do_that_anymore/utils.py | 250 ++++++++++++++ .../diagonal_moves_dataset.jsonl | 3 + .../gpt-3.5-turbo-0125_dataset.jsonl | 3 + .../gpt-3.5-turbo-instruct_dataset.jsonl | 3 + .../gpt-4-0125-preview_dataset.jsonl | 3 + .../gpt-4-0314_dataset.jsonl | 3 + .../special_moves_dataset.jsonl | 3 + .../registry/evals/cant_do_that_anymore.yaml | 23 ++ .../solvers/cant_do_that_anymore.yaml | 17 + pyproject.toml | 1 + 23 files changed, 2352 insertions(+) create mode 100644 evals/elsuite/cant_do_that_anymore/README.md create mode 100644 evals/elsuite/cant_do_that_anymore/chess/board.py create mode 100644 evals/elsuite/cant_do_that_anymore/chess/board_test.py create mode 100644 evals/elsuite/cant_do_that_anymore/chess/move_variants.py create mode 100644 evals/elsuite/cant_do_that_anymore/chess/notation.py create mode 100644 evals/elsuite/cant_do_that_anymore/chess/pieces.py create mode 100644 evals/elsuite/cant_do_that_anymore/chess/utils.py create mode 100644 evals/elsuite/cant_do_that_anymore/defaults.py create mode 100644 evals/elsuite/cant_do_that_anymore/eval.py create mode 100644 evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py create mode 100644 evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py create mode 100644 evals/elsuite/cant_do_that_anymore/scripts/make_plots.py create mode 100755 evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh create mode 100644 evals/elsuite/cant_do_that_anymore/utils.py create mode 100644 evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl create mode 100644 evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl create mode 100644 evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl create mode 100644 evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl create mode 100644 evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl create mode 100644 evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl create mode 100644 evals/registry/evals/cant_do_that_anymore.yaml create mode 100644 evals/registry/solvers/cant_do_that_anymore.yaml diff --git a/evals/elsuite/cant_do_that_anymore/README.md b/evals/elsuite/cant_do_that_anymore/README.md new file mode 100644 index 0000000000..e143f278b9 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/README.md @@ -0,0 +1,69 @@ +# Can't Do That Anymore + +We evaluate how well models can adapt to new rules of an environment, by applying novel reasoning to a task rather than following biases seen during their training. We task models to play a variant of chess and evaluate whether they can avoid making moves that are ordinarily legal, but are illegal in our variant which has slightly different rules. In our variant of chess, bishops move as knights do. + +## Usage + +Run with: + +``` +oaieval cant_do_that_anymore +``` + +We suggest using `generation/direct/gpt-3.5-turbo` or `generation/direct/gpt-4-turbo-preview` as default choices for `` + +For more examples of running this eval, see `scripts/run_experiments.sh` + +## Dataset + +For each model we evaluate, we construct a dataset where every sample contains a board position and the next move that was played, which is legal for the board position under the normal rules of chess, but illegal under the rules of our variant (i.e. the next move is a bishop moving diagonally). We call these types of moves *special moves*. We additionally filter to only include special moves that the model would have predicted under temperature=0 with the normal rules. We can use this to evaluate if models will change their predictions when given the variant rules, despite normally strongly predicting the move under the normal rules. + +Each model's dataset is automatically found and loaded upon running the eval. If a dataset doesn't exist for a particular solver, one will automatically be constructed for it. + +## Evaluation Process + +Samples from the dataset are evaluated one-by-one. Each sample contains a board position and the special move (next move). We prompt models to predict the next best move given the board position, separately under both the normal rules of chess and our variant's rules. We then measure whether the model predicted the special move from the sample under both rule settings. If the model was perfectly following the given rules, we'd expect it to never predict the special move under the variant's rules. + +To see how we prompt models under each rule setting, see `defaults.py`. + +## Metrics + +The below are the key metrics of this eval: + +| Metric | Interpretation | +| --- | --- | +| `variant_impact_factor` | The relative decrease in special move predictions when under the variant's rules, relative to the special move predictions under the normal rules. Lower is better, perfect score is -1. +| `delta` | The absolute decrease in predicting the special move when under the variant's rules, relative to the models predictions under the normal rules. Lower is better. +| `predicted_move_proportion` | The proportion of examples where the model predicted the special move under the normal rules. +| `predicted_move_in_variant_proportion` | The proportion of examples where the model predicted the special move under the variant's rules. +| `avg_num_previous_moves` | Average number of previous moves leading up to the board positions across all samples. +| `std_num_previous_moves` | Standard deviation of the number of previous moves leading up to the board positions across all samples. + +## Variants + +| Variant | Notes | +| --- | --- | +| Default: `cant_do_that_anymore.all` | Default setting. Each dataset has 1000 samples. | +| `cant_do_that_anymore.all_small` | A smaller version of the default setting. Each dataset has 100 samples. | +| `cant_do_that_anymore.all_diagonal` | In this variant, we measure the proportion of samples (board positions) where the model will attempt to move a bishop diagonally. | + +## Custom Solvers + +We use two custom solvers for the base models we evaluate: `chess/generation/direct/gpt-3.5-turbo-instruct` and `chess/generation/direct/gpt-4-base`. These only generate up to four tokens, which prevents the base models from simulating the entire game. + +## Token Usage Estimates + +Below is a rough estimate of the total number of tokens used by the default variant: + +| Solver | Input Tokens | Output Tokens | Total Tokens | +| --- | --- | --- | --- | +| generation/direct/gpt-3.5-turbo | 375,000 | 10,000 | 385,000 | +| generation/direct/gpt-4-turbo-preview | 375,000 | 10,000 | 385,000 | + +## Version History + +- v0: Initial version released + +## Contribution statement + +Eval design, implementation, and results evaluation was primarily conducted by Oliver Jaffe with contributions from Giulio Starace, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, and Chan Jun Shern who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation. diff --git a/evals/elsuite/cant_do_that_anymore/chess/board.py b/evals/elsuite/cant_do_that_anymore/chess/board.py new file mode 100644 index 0000000000..5537b9d5f4 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/board.py @@ -0,0 +1,244 @@ +import copy +from typing import Callable, Dict, Sequence + +from evals.elsuite.cant_do_that_anymore.chess.notation import NotationParser +from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + Move, + get_other_player_id, + get_path_between_coords, + parse_piece, +) + + +class Board: + """ + Represents one board position. Is instantiated several times + by the BoardController to simulate future boards after playing + some moves. + """ + + def __init__( + self, + board_state: Sequence[Sequence[str]], + piece_id_to_instance: Dict[int, Piece], + piece_str_to_id: Dict[str, int], + piece_id_to_str: Dict[int, str], + ): + self.board_state = board_state + self.piece_id_to_instance = piece_id_to_instance + self.piece_str_to_id = piece_str_to_id + self.piece_id_to_str = piece_id_to_str + + def __str__(self) -> str: + str_board = [["" for _ in range(8)] for _ in range(8)] + + for row_idx in range(len(self.board_state)): + row = self.board_state[row_idx] + for col_idx in range(len(row)): + piece_color, piece_id = parse_piece(self.board_state, row_idx, col_idx) + + if piece_color != "E": + white_piece = piece_color == "W" + s = ( + self.piece_id_to_instance[piece_id].white_render + if white_piece + else self.piece_id_to_instance[piece_id].black_render + ) + else: + s = "\u25A1" + str_board[row_idx][col_idx] = s + + # Add letters on bottom + str_board += [["-"] * 8] + str_board += [["a", "b", "c", "d", "e", "f", "g", "h"]] + + # Add numbers on side + str_board = [["|"] + row for row in str_board] + numbers = list(range(8, 0, -1)) + [" ", " "] + str_board = [[str(numbers[idx])] + row for (idx, row) in enumerate(str_board)] + + # Render as string + str_board = "\n".join([" ".join(row) for row in str_board]) + return str_board + + def _update_board(self, move: Move): + """ + Updates board_state according to given move. This move must have previously been checked + to be legal. Edge cases for moves that: + 1) Take pieces at other positions where this piece isn't moving (en passant) + 2) Move two pieces (castling) + 3) Change the id of the piece (promotion) + """ + start_coord, target_coord = move.start_coord, move.target_coord + piece_color, piece_id = parse_piece(self.board_state, start_coord[0], start_coord[1]) + target_piece_color, target_piece_id = parse_piece( + self.board_state, target_coord[0], target_coord[1] + ) + + # En passant + if piece_id == 0 and target_piece_color == "E": + dy = target_coord[1] - start_coord[1] + target_en_passant_piece = [start_coord[0], start_coord[1] + dy] + self.board_state[target_en_passant_piece[0]][target_en_passant_piece[1]] = "E" + + # Castling + if move.castling: + path = get_path_between_coords(start_coord, target_coord) + rook_tile = path[0] + self.board_state[rook_tile[0]][rook_tile[1]] = f"{piece_color}3" + + kingside = target_coord[1] <= 4 + old_rook_tile = [start_coord[0], 0] if kingside else [start_coord[0], 7] + self.board_state[old_rook_tile[0]][old_rook_tile[1]] = "E" + + # Move piece + self.board_state[start_coord[0]][start_coord[1]] = "E" + self.board_state[target_coord[0]][target_coord[1]] = f"{piece_color}{piece_id}" + + # Promotion + if move.promotion is not None: + self.board_state[target_coord[0]][target_coord[1]] = f"{piece_color}{move.promotion}" + + def _get_player_moves(self, player_id: str, previous_moves: Sequence[Move]) -> Sequence[Move]: + """ + Returns all possible moves by pieces for a player. Doesn't filter out moves that + result in the king being placed under check + """ + moves = [] + for row_idx in range(len(self.board_state)): + row = self.board_state[row_idx] + for col_idx in range(len(row)): + piece_color, piece_id = parse_piece(self.board_state, row_idx, col_idx) + if piece_color != player_id: + continue + + piece = self.piece_id_to_instance[piece_id] + possible_piece_moves = piece.get_piece_moves( + self.board_state, player_id, [row_idx, col_idx], previous_moves + ) + moves += possible_piece_moves + + return moves + + def _is_king_in_check(self, player_id: str) -> bool: + other_player_id = get_other_player_id(player_id) + + other_player_moves = self._get_player_moves(other_player_id, []) + king_capturing_moves = self._filter_for_king_capturing_moves(other_player_moves, player_id) + return len(king_capturing_moves) != 0 + + def _filter_for_king_capturing_moves( + self, moves: Sequence[Move], king_color: str + ) -> Sequence[Move]: + king_capturing_moves = [] + for move in moves: + piece_color, piece_id = parse_piece( + self.board_state, move.target_coord[0], move.target_coord[1] + ) + if piece_color == king_color and piece_id == 5: + king_capturing_moves.append(move) + + return king_capturing_moves + + +class BoardController: + """ + Manages a single game of chess. Contains logic to find all legal + moves for a particular player and update the internal board according + to a given move. Maintains one Board obj to represent the true state of play + """ + + def __init__( + self, + board_init: Callable[..., Sequence[Sequence[str]]], + piece_id_to_instance: Dict[int, Piece], + piece_str_to_id: Dict[str, int], + piece_id_to_str: Dict[int, str], + notation_parser: NotationParser, + ): + self.board = Board(board_init(), piece_id_to_instance, piece_str_to_id, piece_id_to_str) + self.notation_parser = notation_parser + + self.previous_moves = [] + + def __str__(self) -> str: + return self.board.__str__() + + def update_board(self, move: str): + """ + Parses move, updates the internal board state, then stores the move + since knowing previous moves is necessary for En Passant and castling + """ + move = self.notation_parser._str_to_move(move, self.board.board_state) + self.board._update_board(move) + self.previous_moves.append(move) + + def get_player_legal_moves(self, player_id: str) -> Sequence[str]: + """ + Gets all legal moves for a player with the given player_id, returned in + the notation this object was initialised with + """ + legal_moves = self.board._get_player_moves(player_id, self.previous_moves) + legal_moves = self._filter_to_prevent_pinning(legal_moves, player_id) + + legal_moves = [ + self.notation_parser._move_to_str(i, self.board.board_state) for i in legal_moves + ] + return legal_moves + + def _filter_to_prevent_pinning(self, moves: Sequence[Move], player_id: str) -> Sequence[Move]: + """ + Filter out moves that would result in the king being pinned, or the king moving over a pinned + position when castling + """ + + def _is_valid_castling(move: Move) -> bool: + if self.board._is_king_in_check(player_id): + return False + + # Check that the king won't move over an attacked position + dy = (move.target_coord[1] - move.start_coord[1]) / abs( + move.target_coord[1] - move.start_coord[1] + ) + king_path = get_path_between_coords( + move.start_coord, [move.target_coord[0], move.target_coord[1] + dy] + ) + + not_pinned_along_path = [] + for coord in king_path: + simulated_board = copy.deepcopy(self.board) + simulated_board._update_board( + Move(move.start_coord, coord, promotion=None, castling=False) + ) + pinned = simulated_board._is_king_in_check(player_id) + not_pinned_along_path.append(not pinned) + + if all(not_pinned_along_path): + return True + + return False + + filtered_moves = [] + for move in moves: + if move.castling and _is_valid_castling(move): + filtered_moves.append(move) + elif not move.castling: + simulated_board = copy.deepcopy(self.board) + simulated_board._update_board(move) + if not simulated_board._is_king_in_check(player_id): + filtered_moves.append(move) + + return filtered_moves + + def _is_checkmate(self, player_id: str) -> bool: + legal_moves = self.get_player_legal_moves(player_id) + if len(legal_moves) == 0 and self.board._is_king_in_check(player_id): + return True + return False + + def _is_stalemate(self, player_id: str) -> bool: + legal_moves = self.get_player_legal_moves(player_id) + if len(legal_moves) == 0 and not self.board._is_king_in_check(player_id): + return True + return False diff --git a/evals/elsuite/cant_do_that_anymore/chess/board_test.py b/evals/elsuite/cant_do_that_anymore/chess/board_test.py new file mode 100644 index 0000000000..0d163f289c --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/board_test.py @@ -0,0 +1,95 @@ +import random +import time +from typing import Sequence + +import pytest +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.move_variants import ( + PIECE_ID_TO_INSTANCE, + PIECE_ID_TO_STR, + PIECE_STR_TO_ID, +) +from evals.elsuite.cant_do_that_anymore.chess.notation import AlgebraicNotationParser + +N_GAMES = 100 +MAX_MOVES = 1000 +VERBOSE = False +VERBOSE_SLOWDOWN = 2 + + +def default_board_init() -> Sequence[Sequence[str]]: + board = [ + ["B3", "B1", "B2", "B4", "B5", "B2", "B1", "B3"], + ["B0", "B0", "B0", "B0", "B0", "B0", "B0", "B0"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["E", "E", "E", "E", "E", "E", "E", "E"], + ["W0", "W0", "W0", "W0", "W0", "W0", "W0", "W0"], + ["W3", "W1", "W2", "W4", "W5", "W2", "W1", "W3"], + ] + return board + + +@pytest.mark.skip # avoid unit test that requires chess library +def simulate_games(): + """ + Simulates full chess games and asserts that at every position, the + set of legal moves is equivalent to the legal moves reported by the + python-chess library + + Install such library with: + pip install chess + """ + import chess + + for _ in tqdm(range(N_GAMES)): + my_controller = BoardController( + default_board_init, + PIECE_ID_TO_INSTANCE, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + their_controller = chess.Board() # python-chess equivalent + + my_player_id = "W" + for _ in range(MAX_MOVES): + our_legal_moves = sorted(my_controller.get_player_legal_moves(my_player_id)) + their_legal_moves = sorted([str(i) for i in their_controller.legal_moves]) + + if our_legal_moves != their_legal_moves: + our_additional_moves = list(set(our_legal_moves) - set(their_legal_moves)) + their_additional_moves = list(set(their_legal_moves) - set(our_legal_moves)) + print( + f""" + Inconsistent legal moves between the boards! + Our legal moves: {our_legal_moves}, + Their legal moves: {their_legal_moves}, + Moves we had they didnt: {our_additional_moves}, + Moves they had we didn't: {their_additional_moves}, + Board state:\n{my_controller.board.board_state} + """ + ) + assert False + + if len(our_legal_moves) == 0: + break + + # Pick random move + move = random.choice(our_legal_moves) + my_controller.update_board(move) + their_controller.push_san(move) + + my_player_id = "B" if my_player_id == "W" else "W" + + if VERBOSE: + print(my_controller) + print(move) + time.sleep(VERBOSE_SLOWDOWN) + + +if __name__ == "__main__": + simulate_games() diff --git a/evals/elsuite/cant_do_that_anymore/chess/move_variants.py b/evals/elsuite/cant_do_that_anymore/chess/move_variants.py new file mode 100644 index 0000000000..50f48c78e1 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/move_variants.py @@ -0,0 +1,120 @@ +# Default initialization +from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece + +# Generic type of moves +STRAIGHT_MOVES = [[0, i] for i in range(-8, 9)] + [[i, 0] for i in range(-8, 9)] +DIAGONAL_MOVES = [[i, i] for i in range(-8, 9)] + [[-i, i] for i in range(-8, 9)] + +# Piece-specific moves +PAWN_MOVES_WHITE = [ + [-1, 0], +] +PAWN_MOVES_BLACK = [ + [1, 0], +] +PAWN_CAPTURING_MOVES = [ + [1, 1], + [1, -1], +] +KNIGHT_MOVES = [ + [1, 2], + [2, 1], + [2, -1], + [1, -2], + [-1, -2], + [-2, -1], + [-2, 1], + [-1, 2], +] +BISHOP_MOVES = DIAGONAL_MOVES +ROOK_MOVES = STRAIGHT_MOVES +QUEEN_MOVES = DIAGONAL_MOVES + STRAIGHT_MOVES +KING_MOVES = [ + [0, 1], + [1, 1], + [1, 0], + [1, -1], + [0, -1], + [-1, -1], + [-1, 0], + [-1, 1], +] + +PIECE_ID_TO_INSTANCE = { + 0: Piece( + 0, + "\u265F", + "\u2659", + PAWN_MOVES_WHITE, + PAWN_MOVES_BLACK, + PAWN_CAPTURING_MOVES, + can_double_step=True, + can_en_passant=True, + captures_like_pawn=True, + can_promote=True, + ), + 1: Piece(1, "\u265E", "\u2658", KNIGHT_MOVES, KNIGHT_MOVES, can_jump_over_pieces=True), + 2: Piece( + 2, + "\u265D", + "\u2657", + BISHOP_MOVES, + BISHOP_MOVES, + ), + 3: Piece( + 3, + "\u265C", + "\u2656", + ROOK_MOVES, + ROOK_MOVES, + ), + 4: Piece( + 4, + "\u265B", + "\u2655", + QUEEN_MOVES, + QUEEN_MOVES, + ), + 5: Piece(5, "\u265A", "\u2654", KING_MOVES, KING_MOVES, can_castle=True), +} +# Bishops can move like knights in this variant. All other pieces play normally +VARIANT_PIECE_ID_TO_INSTANCE = { + 0: Piece( + 0, + "\u265F", + "\u2659", + PAWN_MOVES_WHITE, + PAWN_MOVES_BLACK, + PAWN_CAPTURING_MOVES, + can_double_step=True, + can_en_passant=True, + captures_like_pawn=True, + can_promote=True, + ), + 1: Piece(1, "\u265E", "\u2658", KNIGHT_MOVES, KNIGHT_MOVES, can_jump_over_pieces=True), + 2: Piece( + 2, + "\u265D", + "\u2657", + KNIGHT_MOVES, + KNIGHT_MOVES, + can_jump_over_pieces=True, + ), + 3: Piece( + 3, + "\u265C", + "\u2656", + ROOK_MOVES, + ROOK_MOVES, + ), + 4: Piece( + 4, + "\u265B", + "\u2655", + QUEEN_MOVES, + QUEEN_MOVES, + ), + 5: Piece(5, "\u265A", "\u2654", KING_MOVES, KING_MOVES, can_castle=True), +} +PIECE_STR_TO_ID = {"p": 0, "n": 1, "b": 2, "r": 3, "q": 4, "k": 5} +PIECE_ID_TO_STR = {0: "p", 1: "n", 2: "b", 3: "r", 4: "q", 5: "k"} diff --git a/evals/elsuite/cant_do_that_anymore/chess/notation.py b/evals/elsuite/cant_do_that_anymore/chess/notation.py new file mode 100644 index 0000000000..3d7b113b51 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/notation.py @@ -0,0 +1,106 @@ +import re +from abc import abstractmethod +from typing import Sequence + +from evals.elsuite.cant_do_that_anymore.chess.utils import Move, parse_piece + +letters = ["a", "b", "c", "d", "e", "f", "g", "h"] +letter_to_num = {i: idx for (idx, i) in enumerate(letters)} +num_to_letter = {idx: i for (idx, i) in enumerate(letters)} + + +def row_idx_swap(n: int) -> int: + return 8 - n + + +def coord_str_to_pos(s: str) -> Sequence[int]: + return [ + 8 - int(s[1]), + letter_to_num[s[0]], + ] + + +def coord_pos_to_str(s: str) -> str: + a = num_to_letter[s[1]] + b = 8 - s[0] + return f"{a}{b}".upper() + + +class NotationParser: + def __init__(self, piece_str_to_id, piece_id_to_str) -> None: + self.piece_str_to_id = piece_str_to_id + self.piece_id_to_str = piece_id_to_str + + @abstractmethod + def _str_to_move(self, s: str, board_state: Sequence[Sequence[int]], player_id: str) -> Move: + raise NotImplementedError() + + @abstractmethod + def _move_to_str(self, move: Move, board_state: Sequence[Sequence[int]], player_id: str) -> str: + raise NotImplementedError() + + +class AlgebraicNotationParser(NotationParser): + """ + Converts between coordinates of the board and algebraic notation [0]. The exact implementation + is consistent with the python-chess library + + The regex pattern matches the following groups: + (1) Letter indicating piece to be moved (unused) + (2) Row of piece to be moved + (3) Column of piece to be moved + (4) Row+column of where piece is being moved + (5) Letter indicating what piece the current piece is being promoted to + (6) Special characters indicating status of game (unused) + + [0] https://en.wikipedia.org/wiki/Algebraic_notation_(chess) + [1] https://github.com/niklasf/python-chess + """ + + pattern = re.compile(r"([a-h])([1-8])([a-h][1-8])(=?[nbrqkNBRQK])?") + + def _str_to_move(self, s: str, board_state: Sequence[Sequence[int]]) -> Move: + match = self.pattern.match(s) + if match is None: + raise ValueError( + f"Incorrect notation for move! Full start and end position must be given. Using algebraic notation, got: {s}" + ) + + # Parse start coord + start_row = row_idx_swap(int(match.group(2))) if match.group(2) is not None else None + start_col = letter_to_num[match.group(1)] if match.group(1) is not None else None + start_coord = [start_row, start_col] + + # Parse to coord + to_row = row_idx_swap(int(match.group(3)[1])) + to_col = letter_to_num[match.group(3)[0]] + to_coord = [to_row, to_col] + + # Promotions + promotion = match.group(4) + if promotion is not None: + promotion = self.piece_str_to_id[promotion] + + # Castling + castling = False + if start_row is not None and start_col is not None: + _, piece_id = parse_piece(board_state, start_row, start_col) + if piece_id == 5 and abs(start_col - to_col) == 2: + castling = True + + return Move(start_coord, to_coord, promotion, castling) + + def _move_to_str(self, move: Move, board_state: Sequence[Sequence[int]]) -> str: + out_str = "" + start_coord, target_coord = move.start_coord, move.target_coord + + start = f"{num_to_letter[start_coord[1]]}{row_idx_swap(start_coord[0])}".lower() + out_str += start + + target = f"{num_to_letter[target_coord[1]]}{row_idx_swap(target_coord[0])}".lower() + out_str += target + + if move.promotion is not None: + out_str += self.piece_id_to_str[move.promotion] + + return out_str diff --git a/evals/elsuite/cant_do_that_anymore/chess/pieces.py b/evals/elsuite/cant_do_that_anymore/chess/pieces.py new file mode 100644 index 0000000000..9692a0170c --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/pieces.py @@ -0,0 +1,263 @@ +import copy +from typing import Sequence + +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + Move, + coord_within_board, + get_other_player_id, + get_path_between_coords, + has_piece_been_moved, + move_crosses_pieces, + parse_piece, +) + + +class Piece: + def __init__( + self, + piece_id: int, + white_render: str, + black_render: str, + possible_moves_white: Sequence[Sequence[int]], + possible_moves_black: Sequence[Sequence[int]], + possible_capturing_moves: Sequence[Sequence[int]] = None, + can_double_step: bool = False, + can_en_passant: bool = False, + captures_like_pawn: bool = False, + can_promote: bool = False, + can_jump_over_pieces: bool = False, + can_castle: bool = False, + ): + self.piece_id = piece_id + self.white_render = white_render + self.black_render = black_render + self.possible_moves_white = possible_moves_white + self.possible_moves_black = possible_moves_black + self.possible_capturing_moves = possible_capturing_moves + + self.can_double_step = can_double_step + self.can_en_passant = can_en_passant + self.captures_like_pawn = captures_like_pawn + self.can_promote = can_promote + self.can_jump_over_pieces = can_jump_over_pieces + self.can_castle = can_castle + + def get_piece_moves( + self, + board_state: Sequence[Sequence[int]], + player_id: str, + start_coord: Sequence[int], + previous_moves: Sequence[Move], + ) -> Sequence[Move]: + """ + Returns a sequence representing all moves this piece can make given the current environment + and rules this piece follows + """ + if player_id == "W": + possible_transformations = copy.deepcopy(self.possible_moves_white) + forward_direction = -1 + else: + possible_transformations = copy.deepcopy(self.possible_moves_black) + forward_direction = 1 + + # Get all relative transformations piece can make + if self.can_double_step: + possible_transformations += self._get_pawn_double_step_transformations( + player_id, start_coord + ) + if self.captures_like_pawn: + possible_transformations = self._remove_illegal_pawn_capture_transformations( + board_state, player_id, start_coord, possible_transformations, forward_direction + ) + if self.can_en_passant: + possible_transformations += self._get_en_passant_transformations( + board_state, start_coord, previous_moves, forward_direction + ) + + # Find all legal moves from transformations + piece_moves = self._get_moves_from_transformations( + board_state, player_id, start_coord, possible_transformations + ) + + # Add rule-specific moves + if self.can_promote: + piece_moves = self._add_promotion_moves(piece_moves) + if self.can_castle: + piece_moves += self._get_castling_possible_moves(board_state, player_id, previous_moves) + + return piece_moves + + def _get_moves_from_transformations( + self, + board_state: Sequence[Sequence[int]], + player_id: str, + start_coord: Sequence[int], + possible_transformations: Sequence[Sequence[int]], + ) -> Sequence[Move]: + """ + Given a piece's position within a board and the set of possible relative + transformations the piece can make, convert each transformation into a `Move` + object if: + 1) Transformation results in piece being on board + 2) Transformation doesn't result in piece ending up on piece of same color + 3) Transformation doesn't "jump" over other pieces, unless this piece is + allowed to do so (e.g. knight) + """ + piece_moves = [] + for move in possible_transformations: + new_row_idx = start_coord[0] + move[0] + new_col_idx = start_coord[1] + move[1] + + if not coord_within_board(new_row_idx, new_col_idx): + continue + + target_coord = [new_row_idx, new_col_idx] + target_piece_color, target_piece_id = parse_piece( + board_state, + target_coord[0], + target_coord[1], + ) + move = Move(start_coord, target_coord, None, False) + + if target_piece_color == player_id: + continue + if not self.can_jump_over_pieces and move_crosses_pieces(board_state, move): + continue + + piece_moves.append(move) + + return piece_moves + + def _get_pawn_double_step_transformations( + self, player_id: str, start_coord: Sequence[int] + ) -> Sequence[Sequence[int]]: + if player_id == "W" and start_coord[0] == 6: + return [[-2, 0]] + elif player_id == "B" and start_coord[0] == 1: + return [[2, 0]] + return [] + + def _remove_illegal_pawn_capture_transformations( + self, + board_state: Sequence[Sequence[int]], + player_id: str, + start_coord: Sequence[int], + possible_transformations: Sequence[Sequence[int]], + forward_direction: int, + ) -> Sequence[Sequence[int]]: + """ + Prevents pawns from "capturing forward" + """ + if self.piece_id != 0: + return possible_transformations + + new_possible_transformations = [] + capturing_moves = self.possible_capturing_moves + capturing_moves = [[move[0] * forward_direction, move[1]] for move in capturing_moves] + for move in possible_transformations + capturing_moves: + new_row_idx = start_coord[0] + move[0] + new_col_idx = start_coord[1] + move[1] + + if not coord_within_board(new_row_idx, new_col_idx): + continue + + target_piece_color, target_piece_id = parse_piece(board_state, new_row_idx, new_col_idx) + + if target_piece_color == "E" and move not in capturing_moves: + new_possible_transformations.append(move) + elif target_piece_color == get_other_player_id(player_id) and move in capturing_moves: + new_possible_transformations.append(move) + + return new_possible_transformations + + def _get_en_passant_transformations( + self, + board_state: Sequence[Sequence[int]], + start_coord: Sequence[int], + previous_moves: Sequence[Move], + forward_direction: int, + ) -> Sequence[Sequence[int]]: + last_move = previous_moves[-1] if len(previous_moves) > 0 else None + if last_move is not None and self.piece_id == 0: + _, last_piece_id = parse_piece( + board_state, last_move.target_coord[0], last_move.target_coord[1] + ) + + # If last move was pawn moving two tiles + if ( + last_piece_id == 0 + and abs(last_move.start_coord[0] - last_move.target_coord[0]) == 2 + ): + + # If on same row and one column apart + dx = start_coord[1] - last_move.target_coord[1] + dy = start_coord[0] - last_move.target_coord[0] + if dy == 0 and abs(dx) == 1: + return [[forward_direction, -dx]] + return [] + + def _add_promotion_moves(self, piece_moves: Sequence[Move]) -> Sequence[Move]: + new_piece_moves = [] + for move in piece_moves: + target_coord = move.target_coord + if target_coord[0] == 0 or target_coord[0] == 7: + for promotion_piece_id in [1, 2, 3, 4]: + move_promotion = copy.deepcopy(move) + move_promotion.promotion = promotion_piece_id + new_piece_moves.append(move_promotion) + else: + new_piece_moves.append(move) + + return new_piece_moves + + def _get_castling_possible_moves( + self, board_state: Sequence[Sequence[int]], player_id: str, previous_moves: Sequence[Move] + ) -> Sequence[Move]: + castling_moves = [] + if self.piece_id != 5: + return castling_moves + + def _can_pieces_castle( + king_init_coord: Sequence[int], rook_init_coord: Sequence[int], init_rook_id: int + ) -> Sequence[Move]: + if init_rook_id != 3: + return [] + + if has_piece_been_moved(king_init_coord, previous_moves) or has_piece_been_moved( + rook_init_coord, previous_moves + ): + return [] + + king_to_rook_move = Move(king_init_coord, rook_init_coord, None, False) + if move_crosses_pieces(board_state, king_to_rook_move): + return [] + + king_to_rook_path = get_path_between_coords(king_init_coord, rook_init_coord) + move = Move(king_init_coord, king_to_rook_path[1], None, True) + return [move] + + # ASSUME board init + king_init_coord = [7, 4] if player_id == "W" else [0, 4] + _, init_king_id = parse_piece(board_state, king_init_coord[0], king_init_coord[1]) + if init_king_id != 5: + return castling_moves + + # Queenside + queenside_rook_init_coord = [7, 7] if player_id == "W" else [0, 7] + _, init_rook_id = parse_piece( + board_state, queenside_rook_init_coord[0], queenside_rook_init_coord[1] + ) + castling_moves += _can_pieces_castle( + king_init_coord, queenside_rook_init_coord, init_rook_id + ) + + # Kingside + kingside_rook_init_coord = [7, 0] if player_id == "W" else [0, 0] + _, init_rook_id = parse_piece( + board_state, kingside_rook_init_coord[0], kingside_rook_init_coord[1] + ) + castling_moves += _can_pieces_castle( + king_init_coord, kingside_rook_init_coord, init_rook_id + ) + + return castling_moves diff --git a/evals/elsuite/cant_do_that_anymore/chess/utils.py b/evals/elsuite/cant_do_that_anymore/chess/utils.py new file mode 100644 index 0000000000..a92d072037 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/chess/utils.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass +from typing import Sequence + + +@dataclass +class Move: + start_coord: Sequence[int] + target_coord: Sequence[int] + promotion: int # Either None for no promotion, or int for piece id of promotion + castling: bool + + +def get_other_player_id(this_player_id: str) -> str: + if this_player_id == "W": + return "B" + elif this_player_id == "B": + return "W" + else: + raise ValueError(f"this_player_id var must be 'W' or 'B', but is: {this_player_id}") + + +def parse_piece( + board_state: Sequence[Sequence[int]], row_idx: int, col_idx: int +) -> tuple[str, int]: + """ + Returns the color and id of the piece at the given coords. + """ + piece = board_state[row_idx][col_idx] + if piece == "E": + return "E", -1 + + color = piece[0] + id = piece[1] + return color, int(id) + + +def move_crosses_pieces(board_state: Sequence[Sequence[int]], move: Move) -> bool: + path = get_path_between_coords(move.start_coord, move.target_coord) + for (x1, y1) in path: + if board_state[x1][y1] != "E": + return True + + return False + + +def has_piece_been_moved( + piece_coord: Sequence[Sequence[int]], previous_moves: Sequence[Move] +) -> bool: + for move in previous_moves: + if move.start_coord == piece_coord: + return True + if move.target_coord == piece_coord: + return True + return False + + +def coord_within_board(row_idx: int, col_idx: int) -> bool: + if row_idx < 0 or row_idx > 7: + return False + if col_idx < 0 or col_idx > 7: + return False + + return True + + +def move_within_board(move: Move) -> bool: + target_coord = move.target_coord + return coord_within_board(target_coord[0], target_coord[1]) + + +def get_path_between_coords( + start_coord: Sequence[int], target_coord: Sequence[int] +) -> Sequence[Sequence[int]]: + # Unpack the start and end points + x1, y1 = start_coord + x2, y2 = target_coord + + # Determine the steps to take in each direction + dx = 1 if x2 > x1 else -1 if x2 < x1 else 0 + dy = 1 if y2 > y1 else -1 if y2 < y1 else 0 + + path = [(x1, y1)] + while (x1, y1) != (x2, y2): + if x1 != x2: + x1 += dx + if y1 != y2: + y1 += dy + path.append((x1, y1)) + + path = path[1:-1] + return path + + +def same_color_piece_at_move_start( + board_state: Sequence[Sequence[int]], move: Move, player_color: str +) -> bool: + start_coord = move.start_coord + piece_color, _ = parse_piece(board_state, start_coord[0], start_coord[1]) + return player_color == piece_color + + +def capturing_same_color(board_state: Sequence[Sequence[int]], move: Move) -> bool: + start_coord, target_coord = move.start_coord, move.target_coord + start_piece_color, _ = parse_piece(board_state, start_coord[0], start_coord[1]) + target_piece_color, _ = parse_piece(board_state, target_coord[0], target_coord[1]) + + return start_piece_color == target_piece_color diff --git a/evals/elsuite/cant_do_that_anymore/defaults.py b/evals/elsuite/cant_do_that_anymore/defaults.py new file mode 100644 index 0000000000..8c8645f824 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/defaults.py @@ -0,0 +1,15 @@ +TASK_DESCRIPTION = """You are playing a game of chess. Play the best move available to you. + +You must indicate which move you want to make by outputting the move using algebraic notation. Do not output anything else. +- Always write the full start and end position, for example, "a7b6" +- To castle, write the start and end position of the king, for example, "e8g8" +- Do not write "x" for capturing moves""" + +TASK_DESCRIPTION_VARIANT = """You are playing a game of a variant of chess. Play the best move available to you. + +You must indicate which move you want to make by outputting the move using algebraic notation. Do not output anything else. +- Always write the full start and end position, for example, "a7b6" +- To castle, write the start and end position of the king, for example, "e8g8" +- Do not write "x" for capturing moves + +This variant of chess does not follow the same rules as chess. All bishops now move like knights do. DO NOT move bishops diagonally, this is now considered an illegal move. All other pieces don't have their rules changed.""" diff --git a/evals/elsuite/cant_do_that_anymore/eval.py b/evals/elsuite/cant_do_that_anymore/eval.py new file mode 100644 index 0000000000..0ca6df5b0b --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/eval.py @@ -0,0 +1,201 @@ +import logging +import random +from typing import Any, Dict, Sequence, Union + +import numpy as np + +import evals.metrics +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.board_test import default_board_init +from evals.elsuite.cant_do_that_anymore.chess.move_variants import ( + PIECE_ID_TO_INSTANCE, + PIECE_ID_TO_STR, + PIECE_STR_TO_ID, + VARIANT_PIECE_ID_TO_INSTANCE, +) +from evals.elsuite.cant_do_that_anymore.chess.notation import AlgebraicNotationParser +from evals.elsuite.cant_do_that_anymore.chess.pieces import Piece +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + capturing_same_color, + move_within_board, + same_color_piece_at_move_start, +) +from evals.elsuite.cant_do_that_anymore.defaults import TASK_DESCRIPTION, TASK_DESCRIPTION_VARIANT +from evals.elsuite.cant_do_that_anymore.utils import ( + construct_messages, + get_binary_avg, + get_dataset_path, + get_diagonal_dataset_path, +) +from evals.eval import SolverEval +from evals.record import RecorderBase +from evals.solvers.solver import Solver, SolverResult +from evals.task_state import TaskState + +logger = logging.getLogger(__name__) + + +class CantDoThatAnymore(SolverEval): + def __init__( + self, + default_model_dataset: str = "gpt-3.5-turbo-0125", + remake_dataset_if_not_found: bool = True, + n_samples: int = 1000, + diagonal_variation: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.default_model_dataset = default_model_dataset + self.remake_dataset_if_not_found = remake_dataset_if_not_found + self.n_samples = n_samples + self.diagonal_variation = diagonal_variation + self.rng: random.Random = random.Random(self.seed) + + def eval_sample(self, solver: Solver, sample: Any, rng: random.Random): + previous_moves, next_filtered_moves = ( + sample["previous_moves"], + sample["next_filtered_moves"], + ) + + def construct_controller(piece_id_to_instance: Dict[int, Piece]) -> BoardController: + controller = BoardController( + default_board_init, + piece_id_to_instance, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + for move in previous_moves: + controller.update_board(move) + return controller + + default_controller = construct_controller(PIECE_ID_TO_INSTANCE) + variant_controller = construct_controller(VARIANT_PIECE_ID_TO_INSTANCE) + + # Get solver prediction. Ideally I wouldn't pass the legal_moves to the solvers, they + # should figure them out themselves, but it's necessary for the random solver + def get_solver_pred( + task_description: str, + controller: BoardController, + ) -> SolverResult: + task_state = TaskState( + task_description, + messages=construct_messages(previous_moves), + ) + return solver(task_state, **{"max_tokens": 4}) + + solver_result = get_solver_pred(TASK_DESCRIPTION, default_controller) + solver_result_variant = get_solver_pred(TASK_DESCRIPTION_VARIANT, variant_controller) + + metrics = { + "move": next_filtered_moves, + "predicted_move": solver_result.output.strip() in next_filtered_moves, + "predicted_move_in_variant": solver_result_variant.output.strip() + in next_filtered_moves, + "num_previous_moves": len(previous_moves), + "previous_moves": previous_moves, + } + + # Add violations to metrics + metrics.update( + self.get_violations( + default_controller, solver_result.output, previous_moves, "standard" + ) + ) + metrics.update( + self.get_violations( + variant_controller, solver_result_variant.output, previous_moves, "variant" + ) + ) + + evals.record.record_metrics(**metrics) + + def run(self, recorder: RecorderBase) -> dict[str, Union[float, int]]: + if self.diagonal_variation: + self.samples_jsonl = get_diagonal_dataset_path( + registry_path=self._prefix_registry_path("") + ) + else: + self.samples_jsonl = get_dataset_path( + solver=self._solver, + registry_path=self._prefix_registry_path(""), + remake_dataset_if_not_found=self.remake_dataset_if_not_found, + default_model_dataset=self.default_model_dataset, + ) + samples = self.get_samples() + samples = self.rng.sample(samples, min(self.n_samples, len(samples))) + + self.eval_all_samples(recorder, samples) + metrics = recorder.get_metrics() + + predicted_move_proportion = get_binary_avg(metrics, "predicted_move") + predicted_move_in_variant_proportion = get_binary_avg(metrics, "predicted_move_in_variant") + + avg_num_previous_moves = sum([i["num_previous_moves"] for i in metrics]) / len(metrics) + std_num_previous_moves = np.std([i["num_previous_moves"] for i in metrics]) + + delta = predicted_move_in_variant_proportion - predicted_move_proportion + variant_impact_factor = (delta / predicted_move_proportion) if predicted_move_proportion != 0 else 0 + + results = { + "variant_impact_factor": variant_impact_factor, + "delta": delta, + "predicted_move_proportion": predicted_move_proportion, + "predicted_move_in_variant_proportion": predicted_move_in_variant_proportion, + "avg_num_previous_moves": avg_num_previous_moves, + "std_num_previous_moves": std_num_previous_moves, + } + + # Add violations + violation_keys = [i for i in metrics[0].keys() if "violation" in i] + violation_results = { + f"{name}_rate": get_binary_avg(metrics, name) for name in violation_keys + } + results.update(violation_results) + + return results + + def get_violations( + self, + controller: BoardController, + solver_output: str, + previous_moves: Sequence[str], + variant_name: str, + ) -> dict: + solver_color = "W" if len(previous_moves) % 2 == 0 else "B" + + piece_moved_outside_board = False + moving_invalid_piece = False + piece_capturing_same_color = False + + violation_metrics = {} + try: + move = controller.notation_parser._str_to_move( + solver_output, controller.board.board_state + ) + + piece_moved_outside_board = not move_within_board(move) + moving_invalid_piece = not same_color_piece_at_move_start( + controller.board.board_state, move, solver_color + ) + piece_capturing_same_color = capturing_same_color(controller.board.board_state, move) + incorrect_notation = False + except (ValueError, KeyError): + incorrect_notation = True + + violation = ( + piece_moved_outside_board + or moving_invalid_piece + or piece_capturing_same_color + or incorrect_notation + ) + violation_metrics = { + f"{variant_name}_violation": violation, + f"{variant_name}_violation_moved_outside_board": piece_moved_outside_board, + f"{variant_name}_violation_moving_invalid_piece": moving_invalid_piece, + f"{variant_name}_violation_capturing_same_color": piece_capturing_same_color, + f"{variant_name}_violation_incorrect_notation": incorrect_notation, + } + return violation_metrics diff --git a/evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py b/evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py new file mode 100644 index 0000000000..e0c7a0265a --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/dataset_creation.py @@ -0,0 +1,312 @@ +import argparse +import copy +import os +import pathlib +from typing import Sequence + +import chess.pgn +import requests +import zstandard +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.utils import Move, parse_piece +from evals.elsuite.cant_do_that_anymore.utils import ( + assert_boards_consistent, + dump_sequence_to_jsonl, + initialise_boards, +) + + +def prepare_lichess_2014_dataset(out_dir: str) -> str: + """ + Downloads and extracts Lichess 2014 April dataset, returns the + path to the extracted .pgn file + """ + fname = "lichess_db_standard_rated_2014-04.pgn.zst" + raw_data_out_path = os.path.join(out_dir, fname) + if not os.path.exists(raw_data_out_path): + url = "https://database.lichess.org/standard/" + fname + r = requests.get(url) + open(raw_data_out_path, "wb").write(r.content) + + out_path = os.path.join(out_dir, "pgn_data.pgn") + if not os.path.exists(out_path): + input_file = pathlib.Path(raw_data_out_path) + with open(input_file, "rb") as compressed: + decomp = zstandard.ZstdDecompressor() + with open(out_path, "wb") as destination: + decomp.copy_stream(compressed, destination) + + return out_path + + +class MoveFilter: + def __call__( + self, + default_controller: BoardController, + variant_controller: BoardController, + move: chess.Move, + player_id: str, + ) -> bool: + raise NotImplementedError() + + +class SpecialMoveFilter(MoveFilter): + """ + Filters for moves that are: + 1) Legal under the normal rules of chess + 2) Illegal under the variant's rules (i.e. bishop is moved) + """ + + def __call__( + self, + default_controller: BoardController, + variant_controller: BoardController, + move: Move, + player_id: str, + ) -> bool: + if not is_move_illegal(default_controller, move, player_id) and is_move_illegal( + variant_controller, move, player_id + ): + return True + + return False + + +class ControlMoveFilter(MoveFilter): + """ + Finds positions where solvers should have (almost) equivalent predictions under + both sets of rules + Filters for moves that are: + 1) Legal under both the normal and variant's rules of chess + 2) Are on a board containing no bishops + 3) Are on a board where no pawns are close to promoting; neither players + pawns are in their last three rows + 4) Are on a board with more than four pieces between both players + """ + + def __call__( + self, + default_controller: BoardController, + variant_controller: BoardController, + move: Move, + player_id: str, + ) -> bool: + if is_move_illegal(default_controller, move, player_id): + return False + if is_move_illegal(variant_controller, move, player_id): + return False + + board_state = default_controller.board.board_state + num_pieces = 0 + for row_idx in range(8): + for col_idx in range(8): + _, piece_id = parse_piece(board_state, row_idx, col_idx) + if piece_id == 2: + return False + elif piece_id == 0: + if player_id == "W" and row_idx <= 2: + return False + elif player_id == "B" and row_idx >= 5: + return False + elif piece_id != -1: + num_pieces += 1 + + if num_pieces < 4: + return False + + return True + + +def is_move_illegal(controller: BoardController, move: chess.Move, player_id: str) -> bool: + legal_moves = controller.get_player_legal_moves(player_id) + if move in legal_moves: + return False + return True + + +def find_specific_moves_in_game( + game: chess.pgn.Game, + game_idx: int, + move_filter: MoveFilter, + default_controller: BoardController, + variant_controller: BoardController, + their_controller: chess.Board, + filter_if_found_previous: bool, +) -> Sequence[dict]: + """ + Given a game, finds all moves that satisfy the given filter + If filter_if_found_previous is True, only finds first move in game that + satisfies filter + """ + player_id = "W" + previous_moves = [] + filtered_moves = [] + for move in game.mainline_moves(): + move = move.uci() + + if move_filter(default_controller, variant_controller, move, player_id): + filtered_moves.append( + { + "game_idx": game_idx, + "previous_moves": copy.deepcopy(previous_moves), + "next_filtered_moves": [move], + "any_previous_move_found": len(filtered_moves) > 0, + } + ) + if filter_if_found_previous: + break + + # Ensure my implementation is correct + assert_boards_consistent(default_controller, their_controller, player_id) + + # Update boards + default_controller.update_board(move) + their_controller.push_san(move) + + variant_controller.board.board_state = default_controller.board.board_state + variant_controller.previous_moves = default_controller.previous_moves + + player_id = "B" if player_id == "W" else "W" + previous_moves.append(move) + + return filtered_moves + + +def create_dataset_of_specific_moves( + pgn_path: str, + move_filter: MoveFilter, + target_num_examples: int, + filter_if_found_previous: bool, + filter_for_unique_previous_moves: bool, + continuously_save: bool, + out_path: str, +): + """ + Iterates over games in dataset and filters move according to the given move_filter + If filter_for_unique_previous_moves is True, filter to only include moves that have + unique sets of previous moves + If continuously_save is True, saves dataset everytime it is updated + """ + pgn = open(pgn_path) + dataset = [] + unique_previous_moves = set() + + t_bar = tqdm(total=target_num_examples) + game_idx = 0 + while True: + game = chess.pgn.read_game(pgn) + if game is None: + break + + default_controller, variant_controller, their_controller = initialise_boards() + filtered_moves = find_specific_moves_in_game( + game, + game_idx, + move_filter, + default_controller, + variant_controller, + their_controller, + filter_if_found_previous, + ) + + if filter_for_unique_previous_moves: + for example in filtered_moves: + previous_moves = example["previous_moves"] + if set(previous_moves) not in unique_previous_moves: + dataset.append(example) + unique_previous_moves.add(frozenset(previous_moves)) + t_bar.update(1) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + elif len(filtered_moves) > 0: + dataset += filtered_moves + t_bar.update(len(filtered_moves)) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + game_idx += 1 + t_bar.set_description(f"Num games examined: {game_idx}") + + if len(dataset) >= target_num_examples: + break + + return dataset + + +def main(args: argparse.Namespace): + lichess_path = prepare_lichess_2014_dataset(args.out_dir) + + if args.make_special_moves: + move_filter = SpecialMoveFilter() + dataset_name = "special_moves_dataset.jsonl" + out_path = os.path.join(args.out_dir, dataset_name) + dataset = create_dataset_of_specific_moves( + lichess_path, + move_filter, + target_num_examples=args.n_moves, + filter_if_found_previous=args.filter_if_found_previous, + filter_for_unique_previous_moves=args.filter_for_unique_previous_moves, + continuously_save=args.continuously_save, + out_path=out_path, + ) + dump_sequence_to_jsonl(dataset, out_path) + + if args.make_control_moves: + move_filter = ControlMoveFilter() + dataset_name = "control_moves_dataset.jsonl" + out_path = os.path.join(args.out_dir, dataset_name) + dataset = create_dataset_of_specific_moves( + lichess_path, + move_filter, + target_num_examples=args.n_moves, + filter_if_found_previous=args.filter_if_found_previous, + filter_for_unique_previous_moves=args.filter_for_unique_previous_moves, + continuously_save=args.continuously_save, + out_path=out_path, + ) + dump_sequence_to_jsonl(dataset, out_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument("--n_moves", type=int, default=5000) + parser.add_argument( + "--out_dir", type=str, default="./evals/registry/data/cant_do_that_anymore/" + ) + parser.add_argument( + "--make_special_moves", + action="store_true", + help="Whether to search and build a dataset of special moves", + default=False, + ) + parser.add_argument( + "--make_control_moves", + action="store_true", + help="Whether to search and build a dataset of control moves", + default=False, + ) + parser.add_argument( + "--filter_if_found_previous", + action="store_true", + help="Whether to filter out moves that have had previous moves that satisfy the filtering condition.", + default=False, + ) + parser.add_argument( + "--filter_for_unique_previous_moves", + action="store_true", + help="Whether to only search for moves with unique previous moves (up to such position at the move)", + default=False, + ) + parser.add_argument( + "--continuously_save", + action="store_true", + help="Whether to save the dataset everytime a new example has been found", + default=False, + ) + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py b/evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py new file mode 100644 index 0000000000..491acf3c95 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/diagonal_dataset_creation.py @@ -0,0 +1,316 @@ +import argparse +import copy +import os +import random +from typing import Optional, Sequence + +from stockfish import Stockfish +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.move_variants import DIAGONAL_MOVES +from evals.elsuite.cant_do_that_anymore.chess.utils import ( + Move, + coord_within_board, + move_crosses_pieces, + parse_piece, +) +from evals.elsuite.cant_do_that_anymore.utils import dump_sequence_to_jsonl, initialise_boards + +# NOTE change threads, hash depending on hardware +# https://pypi.org/project/stockfish/ +STOCKFIAH_MOVES_CONSIDERED = 5 +STOCKFISH_DEPTH = 18 +STOCKFISH_PARAMS = { + "Debug Log File": "", + "Contempt": 0, + "Min Split Depth": 0, + "Threads": 8, + "Ponder": "false", + "Hash": 4096, + "MultiPV": 1, + "Skill Level": 10, + "Move Overhead": 10, + "Minimum Thinking Time": 20, + "Slow Mover": 100, + "UCI_Chess960": "true", + "UCI_LimitStrength": "false", + "UCI_Elo": 1500, +} + + +def get_stockfish_move(stockfish: Stockfish, num_moves_to_consider: int) -> str: + """ + Gets the next move predicted by stockfish. Gets top n predictions and + selects randomly weighted by each move's centipawn value + Filters out bishop promotions, since our variant shouldn't have bishops + """ + # Get top moves, filter out bad ones + top_moves = stockfish.get_top_moves(num_moves_to_consider) + + # Filter out bishop promotions + top_moves = [i for i in top_moves if not i["Move"].endswith("b")] + + # If stockfish considers moves that it knows will lead to mate, only + # select from these moves + mates = [i for i in top_moves if i["Mate"] is not None] + if len(mates) > 0: + top_moves = mates + + # Ensures centipawn value isn't None + if all([i["Centipawn"] is None for i in top_moves]): + for move in top_moves: + move["Centipawn"] = 1 + else: + top_moves = [i for i in top_moves if i["Centipawn"] is not None] + + # Makes all centipawns positive + min_centipawn_value = min([i["Centipawn"] for i in top_moves]) + for move in top_moves: + move["Centipawn"] += abs(min_centipawn_value) + + # Normalise centipawn to a probability distribution + centipawn_sum = sum([i["Centipawn"] for i in top_moves]) + for move in top_moves: + move["prob"] = move["Centipawn"] / centipawn_sum + + # Pick move randomly + prob = random.uniform(0, 1) + selected_move = None + for move in top_moves: + prob -= move["prob"] + if prob <= 0: + selected_move = move["Move"] + break + + return selected_move + + +def parse_stockfish_move(controller: BoardController, move: str) -> str: + """ + When stockfish outputs a castling move, the move is from the kings position to the + rooks position, e.g. "e8a8" + In my framework castling is indicated by the start+end position of the king, e.g. "e8c8" + This functions converts the stockfish notation to my notation + """ + move = controller.notation_parser._str_to_move(move, controller.board.board_state) + _, piece_id = parse_piece( + controller.board.board_state, move.start_coord[0], move.start_coord[1] + ) + + # If castling move + dy = move.target_coord[1] - move.start_coord[1] + if piece_id == 5: + if dy > 2 or dy < -2: + direction = dy / abs(dy) + if direction == 1: # Kingside castling + move.target_coord = [move.target_coord[0], move.target_coord[1] - 1] + else: # Queenside castling + move.target_coord = [move.target_coord[0], move.target_coord[1] + 2] + + move = controller.notation_parser._move_to_str(move, controller.board.board_state) + return move + + +def get_bishop_diagonal_moves(controller: BoardController, player_id: str) -> Sequence[str]: + """ + Gets all possible diagonal moves that a bishop could make on a board, even if the bishop isn't + allowed to move diagonally under the board's rules + """ + # Find all bishops on board + bishop_coords = [] + board_state = controller.board.board_state + for row_idx in range(8): + for col_idx in range(8): + piece_color, piece_id = parse_piece(board_state, row_idx, col_idx) + if piece_color == player_id and piece_id == 2: + bishop_coords.append([row_idx, col_idx]) + + # Find all possible diagonal movements of each bishop + bishop_diagonal_moves = [] + for row_idx, col_idx in bishop_coords: + for transformation in DIAGONAL_MOVES: + new_coord = [row_idx + transformation[0], col_idx + transformation[1]] + move = Move([row_idx, col_idx], new_coord, promotion=None, castling=False) + + # If piece doesn't move + if transformation[0] == 0 and transformation[1] == 0: + continue + # If transformation moves piece outside board + if not coord_within_board(new_coord[0], new_coord[1]): + continue + # If transformation moves onto piece of same color + piece_color, _ = parse_piece(controller.board.board_state, new_coord[0], new_coord[1]) + if piece_color == player_id: + continue + # If move crosses friendly pieces + if move_crosses_pieces(controller.board.board_state, move): + continue + + move = controller.notation_parser._move_to_str(move, controller.board.board_state) + bishop_diagonal_moves.append(move) + + return bishop_diagonal_moves + + +def find_specific_moves_in_game( + game_idx: int, + variant_controller: BoardController, + filter_if_found_previous: bool, + max_moves: int, +) -> Sequence[dict]: + """ + Simulates an individual game, using the variant's rules. Finds all possible + diagonal moves from bishops (even though moving bishops diagonally is + illegal under the variant) + If filter_if_found_previous is True, only finds the first position with possible + bishop moves + """ + stockfish = Stockfish(depth=STOCKFISH_DEPTH, parameters=STOCKFISH_PARAMS) + # HACK to have stockfish play our variant, just swap out the bishops for knights + # then later pretend the knights are bishops + stockfish.set_fen_position("rnnqknnr/pppppppp/8/8/8/8/PPPPPPPP/RNNQKNNR w KQkq - 0 1") + previous_moves = [] + player_id = "W" + + # Get ELO of each player + elos = [1350, 1000] + random.shuffle(elos) + white_elo, black_elo = elos + + bishop_diagonal_moves = [] + for _ in range(max_moves): + if player_id == "W": + stockfish.set_elo_rating(white_elo) + else: + stockfish.set_elo_rating(black_elo) + + # Find all diagonal bishop moves from this position + found_moves = get_bishop_diagonal_moves(variant_controller, player_id) + if len(found_moves) > 0: + bishop_diagonal_moves.append( + { + "game_idx": game_idx, + "previous_moves": copy.deepcopy(previous_moves), + "next_filtered_moves": found_moves, + } + ) + if filter_if_found_previous: + break + + move = get_stockfish_move(stockfish, STOCKFIAH_MOVES_CONSIDERED) + stockfish.make_moves_from_current_position([move]) + + # Parse into notation that is compatible with my framework + move = parse_stockfish_move(variant_controller, move) + variant_controller.update_board(move) + + player_id = "B" if player_id == "W" else "W" + previous_moves.append(move) + + # If checkmate or stalemate, end + if len(variant_controller.get_player_legal_moves(player_id)) == 0: + break + + return bishop_diagonal_moves + + +def create_bishop_diagonal_dataset( + target_num_examples: int, + max_moves: int, + filter_if_found_previous: bool, + filter_for_unique_previous_moves: bool, + continuously_save: bool, + out_path: Optional[str], +) -> Sequence[dict]: + """ + Simulates stockfish games and finds possible diagonal moves that could be + made by bishops. + If filter_if_found_previous is True, finds the first move that satisfies this + criteria in each game + If filter_for_unique_previous_moves is True, filters to ensure each + example has a unique set of previous moves + If continuously_save is True, saves dataset everytime it is updated + """ + dataset = [] + unique_previous_moves = set() + + t_bar = tqdm(total=target_num_examples) + game_idx = 0 + while True: + _, variant_controller, _ = initialise_boards() + filtered_moves = find_specific_moves_in_game( + game_idx, + variant_controller, + filter_if_found_previous, + max_moves, + ) + + if filter_for_unique_previous_moves: + for example in filtered_moves: + previous_moves = example["previous_moves"] + if set(previous_moves) not in unique_previous_moves: + dataset.append(example) + unique_previous_moves.add(frozenset(previous_moves)) + t_bar.update(1) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + elif len(filtered_moves) > 0: + dataset += filtered_moves + t_bar.update(len(filtered_moves)) + if continuously_save: + dump_sequence_to_jsonl(dataset, out_path) + + game_idx += 1 + t_bar.set_description(f"Num games examined: {game_idx}") + + if len(dataset) >= target_num_examples: + break + + return dataset + + +def main(args: argparse.Namespace): + dataset_name = "diagonal_moves_dataset.jsonl" + out_path = os.path.join(args.out_dir, dataset_name) + dataset = create_bishop_diagonal_dataset( + target_num_examples=args.n_moves, + max_moves=args.max_moves, + filter_if_found_previous=args.filter_if_found_previous, + filter_for_unique_previous_moves=args.filter_for_unique_previous_moves, + continuously_save=args.continuously_save, + out_path=out_path, + ) + dump_sequence_to_jsonl(dataset, out_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + parser.add_argument("--n_moves", type=int, default=5000) + parser.add_argument("--max_moves", type=int, default=50) + parser.add_argument( + "--out_dir", type=str, default="./evals/registry/data/cant_do_that_anymore/" + ) + parser.add_argument( + "--filter_if_found_previous", + action="store_true", + help="Whether to filter out moves that have had previous moves that satisfy the filtering condition", + default=False, + ) + parser.add_argument( + "--filter_for_unique_previous_moves", + action="store_true", + help="Whether to only search for moves with unique previous moves (up to such position at the move)", + default=False, + ) + parser.add_argument( + "--continuously_save", + action="store_true", + help="Whether to save the dataset everytime a new example has been found", + default=False, + ) + args = parser.parse_args() + + main(args) diff --git a/evals/elsuite/cant_do_that_anymore/scripts/make_plots.py b/evals/elsuite/cant_do_that_anymore/scripts/make_plots.py new file mode 100644 index 0000000000..bd0ea4d5cc --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/make_plots.py @@ -0,0 +1,128 @@ +import argparse +import os +from pathlib import Path +from typing import Sequence + +import pandas as pd +from matplotlib import pyplot as plt + +from evals.elsuite.cant_do_that_anymore.chess.utils import parse_piece +from evals.elsuite.cant_do_that_anymore.utils import initialise_boards +from evals.utils.log_utils import ( + extract_individual_results, + extract_spec, + get_final_results_from_dir, +) + + +def extract_results(datadir: Path) -> pd.DataFrame: + df_agg = [] # Aggregated results + df_samples = [] # Per sample results + for path, results in sorted(list(get_final_results_from_dir(datadir).items())): + spec = extract_spec(path) + solver_path = Path(spec["completion_fns"][0]) + model = solver_path.name + solver = solver_path.parent.name + # Remove root section of path, which is the eval name + solver_path = solver_path.relative_to(solver_path.parts[0]) + # Aggregated + df_agg.append( + { + "solver_path": str(solver_path), + "model": str(model), + "solver": str(solver), + **spec["run_config"]["eval_spec"]["args"], + **results, + } + ) + # Per-sample + for res in extract_individual_results(path): + df_samples.append( + { + "solver_path": str(solver_path), + "model": str(model), + "solver": str(solver), + **spec["run_config"]["eval_spec"]["args"], + **res, + } + ) + df_agg = pd.DataFrame(df_agg) + df_samples = pd.DataFrame(df_samples) + return df_agg, df_samples + + +def render_results(df: pd.DataFrame, out_dir: Path): + agg_operations = { + "predicted_move_proportion": ["mean", "sem"], + "predicted_move_in_variant_proportion": ["mean", "sem"], + } + df = df.groupby("solver_path").agg(agg_operations).reset_index() + df = df.round(2) + print(df.to_csv(index=False)) + df.to_csv(os.path.join(out_dir, "results.csv"), index=False) + + +def compute_num_previous_bishop_moves(previous_moves: Sequence[str]) -> int: + controller, _, _ = initialise_boards() + + num_previous_bishop_moves = 0 + for move in previous_moves: + start_coord = controller.notation_parser._str_to_move( + move, controller.board.board_state + ).start_coord + _, piece_id = parse_piece(controller.board.board_state, start_coord[0], start_coord[1]) + if piece_id == 2: + num_previous_bishop_moves += 1 + + controller.update_board(move) + + return num_previous_bishop_moves + + +def plot_diagonal_bishop_results(df: pd.DataFrame, out_dir: Path): + # Get number of previous bishop moves + df["num_previous_bishop_moves"] = [ + compute_num_previous_bishop_moves(i) for i in df["previous_moves"] + ] + + # Calculate headline metrics per solver, and number of previous moves + agg_operations = { + "predicted_move_in_variant": ["mean"], + } + df = df.groupby(["solver_path", "num_previous_bishop_moves"]).agg(agg_operations).reset_index() + + # Plot separately for each solver + for model, group in df.groupby("solver_path"): + plt.plot( + group["num_previous_bishop_moves"], + group["predicted_move_in_variant"], + label=model, + ) + + plt.xlabel("Num previous bishop moves") + plt.ylabel("Proportion of (illegal) predicted diagonal bishop moves") + plt.ylim([0, 1]) + plt.legend() + plt.savefig(os.path.join(out_dir, "diagonal.png")) + plt.show() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--log_dir", "-d", type=str, required=True) + parser.add_argument("--out_dir", "-o", type=str, required=True) + parser.add_argument("--diagonal_variant", action="store_true", default=False) + args = parser.parse_args() + log_dir = Path(args.log_dir) + out_dir = Path(args.out_dir) + out_dir.mkdir(exist_ok=True, parents=True) + + df_agg, df_samples = extract_results(log_dir) + render_results(df_agg, out_dir) + + if args.diagonal_variant: + plot_diagonal_bishop_results(df_samples, out_dir) + + +if __name__ == "__main__": + main() diff --git a/evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh b/evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh new file mode 100755 index 0000000000..68fe4ac5e7 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/scripts/run_experiments.sh @@ -0,0 +1,67 @@ +#!/bin/bash +logdir=./logs +outputdir=./outputs + +timestamp=$(date +%Y%m%d_%H%M%S) +logpathbase=$logdir/$timestamp/ + +mkdir -p ${logpathbase} + +declare -a SOLVERS_ZEROSHOT=( + "generation/direct/gpt-3.5-turbo" + "chess/generation/direct/gpt-3.5-turbo-instruct" + "generation/direct/gpt-4-turbo-preview" + "chess/generation/direct/gpt-4-base" +) + +# See if variant was indicated +run_diagonal_variant=1 +for arg in "$@" +do + if [[ $arg == "--no_diagonal_variant" ]]; then + run_diagonal_variant=0 + break + fi +done + +# TODO CoT solvers + +echo Running experiments and logging to $logpathbase + +for run_idx in {0..2} +do + for solver in "${SOLVERS_ZEROSHOT[@]}" + do + log_name=${solver//\//-} + oaieval $solver cant_do_that_anymore \ + --record_path ${logpathbase}run_${run_idx}_${log_name}.log \ + --extra_eval_params n_samples=1000 \ + --seed ${run_idx} + done +done + +echo Done running experiments, all logs in $logpathbase + +echo Producing plots, outputs to $outputdir +python make_plots.py --log_dir $logpathbase --out_dir $outputdir + +if [[ $run_diagonal_variant -eq 1 ]]; then + echo Running diagonal experiment and logging to $logpathbase + + for run_idx in {0..2} + do + for solver in "${SOLVERS_ZEROSHOT[@]}" + do + log_name=${solver//\//-} + oaieval $solver cant_do_that_anymore.all_diagonal \ + --record_path ${logpathbase}run_${run_idx}_${log_name}.log \ + --extra_eval_params n_samples=1000 \ + --seed ${run_idx} + done + done + + echo Done running experiments, all logs in $logpathbase + + echo Producing plots, outputs to $outputdir + python make_plots.py --log_dir $logpathbase --out_dir $outputdir --diagonal_variant +fi \ No newline at end of file diff --git a/evals/elsuite/cant_do_that_anymore/utils.py b/evals/elsuite/cant_do_that_anymore/utils.py new file mode 100644 index 0000000000..519aad8596 --- /dev/null +++ b/evals/elsuite/cant_do_that_anymore/utils.py @@ -0,0 +1,250 @@ +import json +import logging +import os +from multiprocessing.pool import ThreadPool +from typing import Sequence + +import chess +from tqdm import tqdm + +from evals.elsuite.cant_do_that_anymore.chess.board import BoardController +from evals.elsuite.cant_do_that_anymore.chess.board_test import default_board_init +from evals.elsuite.cant_do_that_anymore.chess.move_variants import ( + PIECE_ID_TO_INSTANCE, + PIECE_ID_TO_STR, + PIECE_STR_TO_ID, + VARIANT_PIECE_ID_TO_INSTANCE, +) +from evals.elsuite.cant_do_that_anymore.chess.notation import AlgebraicNotationParser +from evals.elsuite.cant_do_that_anymore.defaults import TASK_DESCRIPTION +from evals.record import DummyRecorder, RecorderBase +from evals.solvers.solver import DummySolver, Solver +from evals.task_state import Message, TaskState + +logger = logging.getLogger(__name__) + + +def construct_messages(previous_moves: Sequence[str]) -> Sequence[Message]: + """ + Creates list of Message's containing the previous chess moves. The last + Message is always from the "user" + """ + solver_is_white = len(previous_moves) % 2 == 0 + messages = [] + current_player = "assistant" if solver_is_white else "user" + for move in previous_moves: + messages.append(Message(current_player, move)) + # toggle current player + current_player = "assistant" if current_player == "user" else "user" + + return messages + + +def dump_sequence_to_jsonl(data: Sequence[dict], path: str): + with open(path, "w+") as f: + for example in data: + example = json.dumps(example) + f.write(f"{example}\n") + + +def load_sequence_from_jsonl(path: str) -> Sequence[dict]: + data = [] + with open(path, "r") as f: + for line in f: + line = json.loads(line) + data.append(line) + + return data + + +def initialise_boards() -> tuple[BoardController, BoardController, chess.Board]: + """ + Initialises local chess framework, and framework from + python-chess library + """ + default_controller = BoardController( + default_board_init, + PIECE_ID_TO_INSTANCE, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + variant_controller = BoardController( + default_board_init, + VARIANT_PIECE_ID_TO_INSTANCE, + PIECE_STR_TO_ID, + PIECE_ID_TO_STR, + AlgebraicNotationParser(PIECE_STR_TO_ID, PIECE_ID_TO_STR), + ) + their_controller = chess.Board() + + return default_controller, variant_controller, their_controller + + +def assert_boards_consistent( + controller: BoardController, their_controller: chess.Board, player_id: str +): + """ + Checks both boards have consistent states by ensuring both have same set of legal moves + """ + our_legal_moves = sorted(controller.get_player_legal_moves(player_id)) + their_legal_moves = sorted([str(i) for i in their_controller.legal_moves]) + if our_legal_moves != their_legal_moves: + our_additional_moves = list(set(our_legal_moves) - set(their_legal_moves)) + their_additional_moves = list(set(their_legal_moves) - set(our_legal_moves)) + assert False, f""" + Inconsistent legal moves between the boards! + Our legal moves: {our_legal_moves}, + Their legal moves: {their_legal_moves}, + Moves we had they didnt: {our_additional_moves}, + Moves they had we didn't: {their_additional_moves}, + Board state:\n{controller.board.board_state} + """ + + +def does_solver_predict_move( + solver: Solver, + recorder: RecorderBase, + task_description: str, + special_move: str, + previous_moves: Sequence[str], +): + task_state = TaskState( + task_description, + construct_messages(previous_moves), + ) + + with recorder.as_default_recorder(-1): + solver_result = solver(task_state, **{"max_tokens": 4}) + pred_str = solver_result.output.strip() + + if pred_str == special_move: + return True + + return False + + +def process_example(work_input: dict): + solver, recorder, example, task_description = ( + work_input["solver"], + work_input["recorder"], + work_input["example"], + work_input["task_description"], + ) + special_move, previous_moves = example["special_move"], example["previous_moves"] + + predicts_move = does_solver_predict_move( + solver, + recorder, + task_description, + special_move, + previous_moves, + ) + return predicts_move, example + + +def get_solver_predictions( + solver: Solver, + recorder: RecorderBase, + special_moves_dataset: Sequence[dict], + n_threads: int, + task_description: str, +) -> Sequence[dict]: + """ + Filter to find all special moves that the solver would have predicted under the normal + rules of chess with temp=0, then dump this dataset + """ + solver_moves_dataset = [] + work_items = [ + { + "solver": solver, + "recorder": recorder, + "example": example, + "task_description": task_description, + } + for example in special_moves_dataset + ] + + t_bar = tqdm(total=len(special_moves_dataset)) + with ThreadPool(n_threads) as pool: + iter = pool.imap_unordered(process_example, work_items) + + for result in (t_bar := tqdm(iter, total=len(work_items))): + predicts_move, example = result + if predicts_move: + solver_moves_dataset.append(example) + t_bar.set_description(f"Dataset size: {len(solver_moves_dataset)}") + + return solver_moves_dataset + + +def get_dataset_path( + solver: Solver, + registry_path: str, + remake_dataset_if_not_found: bool, + default_model_dataset: str, +) -> str: + """ + This dataset requires each evaluated model to have its own dataset. We get the exact + model being exaluated, check if a dataset exists for it, if not we generate one + """ + recorder = DummyRecorder(None) + with recorder.as_default_recorder("x"): + solver_version = solver.model_version + + # If nested solver, convert returned dict to str + if isinstance(solver_version, dict): + solver_version = json.dumps(solver_version) + + all_datasets_path = os.path.join(registry_path, "cant_do_that_anymore") + + # Check if dataset exists + solver_dataset_path = os.path.join(all_datasets_path, f"{solver_version}_dataset.jsonl") + if os.path.exists(solver_dataset_path): + return solver_dataset_path + + # Remake, or load default + if isinstance(solver, DummySolver): + return f"cant_do_that_anymore/{default_model_dataset}_dataset.jsonl" + elif remake_dataset_if_not_found: + logger.warning( + f"Generating dataset for {solver_version}! Ideally the solver should be using temperature=0 when creating the dataset, " + "otherwise generated dataset will be of a slightly different distribution" + ) + create_dataset(solver, recorder, solver_dataset_path, all_datasets_path) + return solver_dataset_path + else: + logger.warning( + f"Dataset for {solver_version} wasn't found! Using the dataset for {default_model_dataset} instead." + ) + return f"cant_do_that_anymore/{default_model_dataset}_dataset.jsonl" + + +def create_dataset( + solver: Solver, recorder: RecorderBase, solver_dataset_path: str, all_datasets_path: str +): + threads = int(os.environ.get("EVALS_THREADS", "10")) + + special_moves_dataset = load_sequence_from_jsonl( + os.path.join(all_datasets_path, "special_moves_dataset.jsonl") + ) + solver_moves_dataset = get_solver_predictions( + solver, + recorder, + special_moves_dataset, + n_threads=threads, + task_description=TASK_DESCRIPTION, + ) + dump_sequence_to_jsonl(solver_moves_dataset, solver_dataset_path) + + +def get_diagonal_dataset_path( + registry_path: str, +) -> str: + return os.path.join(registry_path, "cant_do_that_anymore/diagonal_moves_dataset.jsonl") + + +def get_binary_avg(metrics: dict, key: str) -> float: + positive_examples = [i for i in metrics if i[key]] + avg = len(positive_examples) / len(metrics) + return avg diff --git a/evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl new file mode 100644 index 0000000000..7cce7ab588 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/diagonal_moves_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:345340a9c74ae6d3ad73393b43986c37fa30ad2df8e94d147d9f63cf519e703e +size 540964 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl new file mode 100644 index 0000000000..d63a762d37 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-0125_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08d0cbf162d7b46e8931c74816f597085d5d365895e7f8c9f9b20d98be0566c8 +size 170427 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl new file mode 100644 index 0000000000..43161bec40 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-3.5-turbo-instruct_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3d9927244f61a7e00d7b4d9e5521b8ad3249be08cbf8afd3c75b30fe8f4e9a5 +size 223466 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl new file mode 100644 index 0000000000..1c693f76de --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-4-0125-preview_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80a41ce88bab1d6b9315835fa2845bb754ed52d0d7983857f255f5de0fd2fbdb +size 283930 diff --git a/evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl new file mode 100644 index 0000000000..e6dffa7d4d --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/gpt-4-0314_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5df2376c0805ea323dddec11a01d5d843edce069f86550f2a9e91efcad4f51cc +size 549365 diff --git a/evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl b/evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl new file mode 100644 index 0000000000..6f5e89e691 --- /dev/null +++ b/evals/registry/data/cant_do_that_anymore/special_moves_dataset.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:baea567fbd18be57a6fba31a8e7d05a670bfd86799397269aa9b47ab6d2f2a5b +size 3381675 diff --git a/evals/registry/evals/cant_do_that_anymore.yaml b/evals/registry/evals/cant_do_that_anymore.yaml new file mode 100644 index 0000000000..d7254a9545 --- /dev/null +++ b/evals/registry/evals/cant_do_that_anymore.yaml @@ -0,0 +1,23 @@ +cant_do_that_anymore: + id: cant_do_that_anymore.all + metrics: [variant_impact_factor, delta, predicted_move_proportion, predicted_move_in_variant_proportion, avg_num_previous_moves, std_num_previous_moves] + description: Evaluates how well models can adapt to new rules of an environment (chess) + +cant_do_that_anymore.all: + class: evals.elsuite.cant_do_that_anymore.eval:CantDoThatAnymore + args: + default_model_dataset: "gpt-3.5-turbo-0125" + n_samples: 1000 + +cant_do_that_anymore.all_small: + class: evals.elsuite.cant_do_that_anymore.eval:CantDoThatAnymore + args: + default_model_dataset: "gpt-3.5-turbo-0125" + n_samples: 100 + +cant_do_that_anymore.all_diagonal: + class: evals.elsuite.cant_do_that_anymore.eval:CantDoThatAnymore + args: + default_model_dataset: "gpt-3.5-turbo-0125" + n_samples: 1000 + diagonal_variation: True diff --git a/evals/registry/solvers/cant_do_that_anymore.yaml b/evals/registry/solvers/cant_do_that_anymore.yaml new file mode 100644 index 0000000000..951dd066bf --- /dev/null +++ b/evals/registry/solvers/cant_do_that_anymore.yaml @@ -0,0 +1,17 @@ +chess/generation/direct/gpt-3.5-turbo-instruct: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-3.5-turbo-instruct + extra_options: + temperature: 1 + max_tokens: 4 + +chess/generation/direct/gpt-4-base: + class: evals.solvers.openai_solver:OpenAISolver + args: + completion_fn_options: + model: gpt-4-base + extra_options: + temperature: 1 + max_tokens: 4 diff --git a/pyproject.toml b/pyproject.toml index 2b226b4ef0..4c4e6cbfa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "jiwer", "seaborn", "statsmodels", + "chess", ] [project.urls]