Tensor game runner and env #149

Open · wants to merge 76 commits into base: main
Commits
2979751
tests are passing
alexandrasouly Jan 23, 2023
02844d7
tensor game
alexandrasouly Jan 25, 2023
f57b28a
new payoff matrix
alexandrasouly Jan 25, 2023
7844f25
moved to tensor repo
alexandrasouly Jan 25, 2023
ca09efe
works, before fixing wandb
alexandrasouly Jan 29, 2023
ba36084
why this not work i do not know
alexandrasouly Jan 29, 2023
47c9004
fixed wandb
alexandrasouly Jan 29, 2023
9781add
tried using strategies
alexandrasouly Jan 29, 2023
2615e66
added tft variant and tests
alexandrasouly Jan 31, 2023
dbffc3a
new strategies
alexandrasouly Jan 31, 2023
3de0abc
experiments
alexandrasouly Feb 8, 2023
efefe56
adding 3ppl eval
Aidandos Feb 8, 2023
5d03ca2
fixed bug
Aidandos Feb 8, 2023
2a0bf28
changing wandb group
Aidandos Feb 8, 2023
374304a
fixing small bug
Aidandos Feb 8, 2023
7fbd58b
change all configs to include num_inner_steps
Aidandos Feb 8, 2023
ab42d4d
no annealing
Aidandos Feb 8, 2023
db52275
correct payoff
alexandrasouly Feb 8, 2023
742ff83
conflict
alexandrasouly Feb 8, 2023
3e80315
sort of works
alexandrasouly Feb 21, 2023
85788d8
results are still dodge
alexandrasouly Feb 21, 2023
799c136
half of merge conflicts
alexandrasouly Feb 22, 2023
177fbee
things run again
alexandrasouly Feb 23, 2023
a80ee95
formatting
alexandrasouly Feb 23, 2023
d6c19ae
refactored coop logging func and formatted
alexandrasouly Feb 23, 2023
b556cfe
updated
akbir Feb 23, 2023
1e2a379
updated
akbir Feb 23, 2023
7f2144f
time update
akbir Feb 23, 2023
161b011
gs_v_tabular
alexandrasouly Feb 26, 2023
b6215b3
conflict
alexandrasouly Feb 26, 2023
29ebbb3
before moving stuff around for debugging
alexandrasouly Feb 26, 2023
26ef6f3
rng bug fixed
alexandrasouly Feb 28, 2023
006540f
rng bug fixed, ppo_mem runnable
alexandrasouly Mar 12, 2023
3f0659f
started n player
alexandrasouly Mar 17, 2023
a7a33e4
pulled from main
alexandrasouly Mar 17, 2023
4f76571
testing shaping on cluster
Mar 17, 2023
baebae1
testing shaping yaml
Mar 17, 2023
a96fda7
fixed rngs in runners based on main
alexandrasouly Mar 18, 2023
e64c825
evo new params and more opps
alexandrasouly Mar 18, 2023
e6ce23a
added 3player_ipd experiments
alexandrasouly Mar 18, 2023
e142a87
conflict
alexandrasouly Mar 18, 2023
45078bf
conflict
alexandrasouly Mar 18, 2023
4c289fd
tabular fixed
alexandrasouly Mar 18, 2023
149e0ed
Merge branch 'alex_learning' of github.com:akbir/pax into alex_learning
alexandrasouly Mar 18, 2023
66ebabe
eval works
alexandrasouly Mar 26, 2023
151ed79
n player tests pass
alexandrasouly Mar 27, 2023
93a7729
started adding n player evo
alexandrasouly Mar 28, 2023
f0ef2ce
n player evo runner works
alexandrasouly Mar 28, 2023
af72da0
3pl runs again
alexandrasouly Mar 28, 2023
0efde8c
fixed 2player agent setup
alexandrasouly Mar 28, 2023
179b465
wandb runs for n_evo but state vis is not working
alexandrasouly Mar 28, 2023
eff177a
wandb logging fixed
alexandrasouly Mar 28, 2023
22c793d
cant believe 8player game actually ran"
alexandrasouly Mar 28, 2023
0a8bebe
n player eval works
alexandrasouly Apr 3, 2023
d8d810f
format
alexandrasouly Apr 3, 2023
1e517b2
nplayer ppo config files
alexandrasouly Apr 3, 2023
8f29b7d
Merge branch 'alex_learning' of github.com:akbir/pax into alex_learning
alexandrasouly Apr 3, 2023
1e67c74
fix up config
alexandrasouly Apr 4, 2023
483da97
Merge branch 'alex_learning' of github.com:akbir/pax into alex_learning
alexandrasouly Apr 4, 2023
5be995a
nplayer rl runner works
alexandrasouly Apr 4, 2023
9903d05
Merge branch 'alex_learning' of github.com:akbir/pax into alex_learning
alexandrasouly Apr 4, 2023
c98ee58
payoff from paper
alexandrasouly Apr 17, 2023
6a5d5d9
Merge branch 'alex_learning' of github.com:akbir/pax into alex_learning
alexandrasouly Apr 17, 2023
d1f9dbb
stag hunt and snowdrift payoffs
alexandrasouly Apr 18, 2023
896bcd8
Merge branch 'alex_learning' of https://github.com/akbir/pax into ale…
alexandrasouly Apr 18, 2023
79d9869
added tc
alexandrasouly Apr 18, 2023
f88c7a7
added runtime debug scripts
alexandrasouly Apr 29, 2023
25c26b3
fix scripts and add runner log
alexandrasouly Apr 29, 2023
dfcb7aa
n player scripts and baseline strategies
alexandrasouly Apr 29, 2023
91f295f
typo
alexandrasouly Apr 29, 2023
0f5dcb8
added global welfare and grouped visitations
alexandrasouly May 9, 2023
12996ee
added more shaper 5pl configs
alexandrasouly May 9, 2023
640aa75
fix global welfare
alexandrasouly May 9, 2023
242862d
fix global welfare2
alexandrasouly May 9, 2023
094250c
naive scripts
alexandrasouly May 28, 2023
fc2624d
tft scripts
alexandrasouly May 28, 2023
26 changes: 13 additions & 13 deletions .pre-commit-config.yaml
@@ -1,17 +1,17 @@
 repos:
-  - repo: https://github.com/ambv/black
-    rev: 22.6.0
+  - repo: https://github.com/psf/black
+    rev: 22.12.0
     hooks:
       - id: black
         language_version: python3.9
-  - repo: https://github.com/pycqa/flake8
-    rev: '3.9.2'
-    hooks:
-      - id: flake8
-        additional_dependencies: [flake8-bugbear]
-        args: [
-          "--show-source",
-          "--ignore=E203,E266,E501,W503,F403,F401,B008,E712",
-          "--max-line-length=100",
-          "--max-complexity=18",
-          "--select=B,C,E,F,W,T4,B9"]
+#  - repo: https://github.com/pycqa/flake8
+#    rev: '3.9.2'
+#    hooks:
+#      - id: flake8
+#        additional_dependencies: [flake8-bugbear]
+#        args: [
+#          "--show-source",
+#          "--ignore=E203,E266,E501,W503,F403,F401,B008,E712",
+#          "--max-line-length=100",
+#          "--max-complexity=18",
+#          "--select=B,C,E,F,W,T4,B9"]
7,845 changes: 7,845 additions & 0 deletions experiment.log

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions mess.ipynb
@@ -0,0 +1,54 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "unsupported operand type(s) for +: 'int' and 'str'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mjax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mjnp\u001b[39;00m\n\u001b[1;32m 2\u001b[0m num_players \u001b[39m=\u001b[39m \u001b[39m4\u001b[39m\n\u001b[0;32m----> 4\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39msum\u001b[39;49m(\u001b[39mbin\u001b[39;49m(\u001b[39m2\u001b[39;49m\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49m(num_players \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m)\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m)))\n\u001b[1;32m 5\u001b[0m \u001b[39mprint\u001b[39m(jnp\u001b[39m.\u001b[39mbitwise_and(\u001b[39m2\u001b[39m, \u001b[39m2\u001b[39m\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m(num_players \u001b[39m-\u001b[39m \u001b[39m1\u001b[39m)\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m))\n\u001b[1;32m 7\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mnp\u001b[39;00m\n",
"\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for +: 'int' and 'str'"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"num_players = 4\n",
"\n",
"print(sum(bin(2**(num_players - 1)-1)))\n",
"print(jnp.bitwise_and(2, 2**(num_players - 1)-1))\n",
"\n",
"import numpy as np\n",
"np.binary_repr(2**(num_players - 1)-1)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
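The notebook's `TypeError` comes from summing a string: `bin()` returns a string such as `'0b111'`, and `sum()` starts from the integer `0`, so the first `0 + '0'` fails. A minimal sketch of what the cell appears to be after (counting set bits in a bitmask of `num_players - 1` opponents; `mask` is my name for the intermediate value, not one from the notebook):

```python
import numpy as np

num_players = 4
mask = 2 ** (num_players - 1) - 1  # 0b111, as constructed in the notebook

# bin(mask) is the string '0b111', so sum(bin(mask)) raises
# TypeError: unsupported operand type(s) for +: 'int' and 'str'.
# Counting the '1' characters gives the number of set bits instead:
set_bits = bin(mask).count("1")
print(set_bits)  # prints 3

# np.binary_repr gives the same digits without the '0b' prefix:
print(np.binary_repr(mask))  # prints 111
```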
3 changes: 3 additions & 0 deletions outputs/2023-01-18/19-35-11/experiment.log
@@ -0,0 +1,3 @@
[2023-01-18 19:35:11,465][root][INFO] - => Global setup ...
[2023-01-18 19:35:11,465][root][INFO] - => Done in 287.056 us
[2023-01-18 19:35:11,465][root][INFO] -
3 changes: 2 additions & 1 deletion pax/agents/agent.py
@@ -1,8 +1,9 @@
 from typing import Tuple

-from pax.utils import MemoryState, TrainingState
 import jax.numpy as jnp
+
+from pax.utils import MemoryState, TrainingState


 class AgentInterface:
     """Interface for agents to interact with runners and environemnts.
4 changes: 3 additions & 1 deletion pax/agents/hyper/ppo.py
@@ -329,7 +329,9 @@ def model_update_epoch(
         return new_state, new_mem, metrics

     @jax.jit
-    def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
+    def make_initial_state(
+        key: Any, hidden: jnp.ndarray
+    ) -> Tuple[TrainingState, MemoryState]:
         """Initialises the training state (parameters and optimiser state)."""
         key, subkey = jax.random.split(key)
         dummy_obs = jnp.zeros(shape=obs_spec)
2 changes: 1 addition & 1 deletion pax/agents/naive_exact.py
@@ -2,8 +2,8 @@

 import jax
 import jax.numpy as jnp
-from pax.agents.agent import AgentInterface

+from pax.agents.agent import AgentInterface
 from pax.envs.infinite_matrix_game import EnvParams as InfiniteMatrixGameParams
 from pax.utils import MemoryState

8 changes: 5 additions & 3 deletions pax/agents/ppo/ppo.py
@@ -10,10 +10,10 @@
 from pax import utils
 from pax.agents.agent import AgentInterface
 from pax.agents.ppo.networks import (
-    make_ipditm_network,
-    make_sarl_network,
     make_coingame_network,
     make_ipd_network,
+    make_ipditm_network,
+    make_sarl_network,
 )
 from pax.utils import Logger, MemoryState, TrainingState, get_advantages

@@ -336,7 +336,9 @@ def model_update_epoch(

         return new_state, new_memory, metrics

-    def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState:
+    def make_initial_state(
+        key: Any, hidden: jnp.ndarray
+    ) -> Tuple[TrainingState, MemoryState]:
         """Initialises the training state (parameters and optimiser state)."""
         key, subkey = jax.random.split(key)

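The second hunk above only widens the return annotation: `make_initial_state` returns a `(TrainingState, MemoryState)` pair, and the old annotation claimed just `TrainingState`. A minimal self-contained sketch of the corrected signature, using stand-in `NamedTuple`s rather than the real pax classes (which hold network params, optimiser state, and recurrent memory):

```python
from typing import Any, NamedTuple, Tuple

class TrainingState(NamedTuple):
    # Stand-in for pax.utils.TrainingState
    step: int

class MemoryState(NamedTuple):
    # Stand-in for pax.utils.MemoryState
    hidden: Any

def make_initial_state(key: Any, hidden: Any) -> Tuple[TrainingState, MemoryState]:
    # The function builds and returns both pieces, which is why the diff
    # changes the annotation from TrainingState to the Tuple form.
    return TrainingState(step=0), MemoryState(hidden=hidden)

state, mem = make_initial_state(key=None, hidden=0.0)
```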
9 changes: 9 additions & 0 deletions pax/agents/ppo/ppo_gru.py
@@ -517,6 +517,15 @@ def make_gru_agent(
         network, initial_hidden_state = make_GRU_ipd_network(
             action_spec, agent_args.hidden_size
         )
+    elif args.env_id == "iterated_tensor_game":
+        network, initial_hidden_state = make_GRU_ipd_network(
+            action_spec, agent_args.hidden_size
+        )
+
+    elif args.env_id == "iterated_nplayer_tensor_game":
+        network, initial_hidden_state = make_GRU_ipd_network(
+            action_spec, agent_args.hidden_size
+        )

     elif args.env_id == "InTheMatrix":
         network, initial_hidden_state = make_GRU_ipditm_network(
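Both branches added here call the same `make_GRU_ipd_network` constructor, so a single membership test would cover them. This is a hypothetical consolidation sketch, not code from the PR; the constructor is stubbed out (the real one returns a network and its initial hidden state):

```python
def make_GRU_ipd_network(action_spec, hidden_size):
    # Stub standing in for pax's real constructor.
    return f"gru({action_spec},{hidden_size})", "hidden0"

def select_network(env_id, action_spec, hidden_size):
    # The two env ids below are the ones added in this diff; since they
    # share a constructor, one membership check replaces both elif branches.
    if env_id in ("iterated_tensor_game", "iterated_nplayer_tensor_game"):
        return make_GRU_ipd_network(action_spec, hidden_size)
    raise ValueError(f"no GRU network registered for env_id={env_id!r}")

network, initial_hidden_state = select_network(
    "iterated_nplayer_tensor_game", action_spec=2, hidden_size=16
)
```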
6 changes: 3 additions & 3 deletions pax/agents/strategies.py
@@ -4,8 +4,8 @@

 import jax.numpy as jnp
 import jax.random
-from pax.agents.agent import AgentInterface

+from pax.agents.agent import AgentInterface
 from pax.utils import Logger, MemoryState, TrainingState

 # states are [CC, CD, DC, DD, START]
@@ -381,7 +381,7 @@ def _policy(

     def _reciprocity(self, obs: jnp.ndarray, *args) -> jnp.ndarray:
         # now either 0, 1, 2, 3
-        batch_size, _ = obs.shape
+        # batch_size, _ = obs.shape
         obs = obs.argmax(axis=-1)
         # if 0 | 2 | 4 -> C
         # if 1 | 3 -> D
@@ -488,7 +488,7 @@ def make_initial_state(self, _unused, *args) -> TrainingState:


 class Stay(AgentInterface):
-    def __init__(self, num_actions: int, num_envs: int):
+    def __init__(self, num_actions: int, num_envs: int, num_players: int = 2):
         self.make_initial_state = initial_state_fun(num_envs)
         self._state, self._mem = self.make_initial_state(None, None)
         self._logger = Logger()
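The `Stay` change gives the new `num_players` parameter a default of 2, so existing two-player call sites keep working while the n-player runners can pass their own count. A minimal sketch of this backwards-compatible pattern, with stand-in attribute assignments rather than the pax implementation:

```python
class Stay:
    # Sketch only: defaulting the new parameter keeps old call sites valid.
    def __init__(self, num_actions: int, num_envs: int, num_players: int = 2):
        self.num_actions = num_actions
        self.num_envs = num_envs
        self.num_players = num_players

old_style = Stay(num_actions=2, num_envs=8)                # two-player call site
n_player = Stay(num_actions=2, num_envs=8, num_players=5)  # n-player runner
```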