
Commit f546777

core for collection
1 parent 63700b3 commit f546777

13 files changed: +1472 -1 lines changed

.gitignore

+172
@@ -0,0 +1,172 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

test.py
mlc_run.sh
.ml-job-preset.yml
wandb
run.sh
wandb_sweep.yaml
**/*.hdf5
**/*.DS_Store
*.ipynb
*.hdf5
*.gz

README.md

+18 -1
@@ -1 +1,18 @@
-# xland-minigrid-datasets

# XLand-100B: A Large-Scale Multi-Task Dataset for In-Context Reinforcement Learning

Official code for the 'XLand-100B: A Large-Scale Multi-Task Dataset for In-Context Reinforcement Learning' paper. We provide the utilities used to collect the datasets, as well as the code used for the experiments with the baselines, namely AD and DPT. As these parts are semantically unrelated, they are kept in separate directories for simplicity (in the cleanrl style).

Both XLand-100B and XLand-Trivial-20B are hosted on a public S3 bucket and are freely available to everyone under the CC BY-SA 4.0 licence. See the README in each directory for instructions.

## Downloading the datasets

We advise starting with the Trivial dataset for debugging due to its smaller size and faster download time. Both datasets have an identical structure. For additional details, we refer to the paper.

The datasets can be downloaded with the `curl` utility (or any other, such as `wget`) as follows:
```commandline
# XLand-Trivial-20B, approx 60GB size
curl -L -o xland-trivial-20b.hdf5 https://sc.link/A4rEW

# XLand-100B, approx 325GB size
curl -L -o xland-100b.hdf5 https://sc.link/MoCvZ
```
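
After downloading, you can quickly check that the file is intact and inspect its structure. The snippet below is a minimal sketch, assuming the layout produced by `collection/combine.py` (one numbered HDF5 group per learning history, each with `states`, `actions`, `rewards`, `dones` and `expert_actions` datasets); the file name is just the one used in the `curl` example above.
```python
import h5py

# path from the download example above; adjust if you saved the file elsewhere
with h5py.File("xland-trivial-20b.hdf5", "r") as df:
    print("number of learning histories:", len(df.keys()))

    # histories are stored as groups keyed by string indices: "0", "1", ...
    history = df["0"]
    print("attributes:", dict(history.attrs))
    for name in ("states", "actions", "rewards", "dones", "expert_actions"):
        print(name, history[name].shape, history[name].dtype)
```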

collection/.gitkeep

Whitespace-only changes.

collection/README.md

+86
@@ -0,0 +1,86 @@
# Dataset collection

Here we provide the code used to collect the datasets. We adapted the single-task recurrent PPO implementation from the original XLand-MiniGrid baselines. We used wandb sweeps to pretrain the base agent and to collect individual learning histories at scale on multiple GPUs. We then combined all the individual histories into a single dataset with `combine.py`.

If you notice any discrepancies with the paper, don't hesitate to open an issue and report it!

## Pretraining

Pretraining is simple. We provide the pretraining config in `configs/pretrain_base.yaml`. To start:
```commandline
python training/train.py \
    --config_path='configs/pretrain_base.yaml' \
    --checkpoint_path='path-for-the-final-checkpoint' \
    --wandb_logging=True
```
We used pretraining only for the main dataset (tasks from the medium benchmark).

## Collecting

We used wandb sweeps for collection. We provide base configs for the trivial and medium datasets in `configs/trivial_base.yaml` and `configs/medium_base.yaml`, respectively.

### Trivial

First, create a wandb sweep config:
```yaml
# trivial_wandb.yaml
entity: <your-entity>
project: xminigrid-datasets
program: training/train.py
method: grid
parameters:
  config_path:
    value: "configs/trivial_base.yaml"
  group:
    value: "xland-minigrid-datasets-trivial-v0"
  dataset_path:
    value: <path-to-your-dir-for-data>
  dataset_num_histories:
    value: 32
  ruleset_id:
    min: 0
    max: 10000
    distribution: int_uniform
```
Next, register the sweep with `wandb sweep trivial_wandb.yaml` to get the sweep ID. To start collection, run `wandb agent <sweep-id>`, as shown below.
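
For reference, a typical launch then looks like the following. This is only a sketch: the multi-GPU fan-out via `CUDA_VISIBLE_DEVICES` is one possible way to run several agents in parallel, and the `<...>` placeholders must be replaced with your own entity and the sweep ID that wandb prints.
```commandline
# register the sweep; wandb prints an ID of the form <entity>/<project>/<sweep-id>
wandb sweep trivial_wandb.yaml

# launch one agent per GPU; each agent pulls ruleset_id values from the grid
CUDA_VISIBLE_DEVICES=0 wandb agent <your-entity>/xminigrid-datasets/<sweep-id> &
CUDA_VISIBLE_DEVICES=1 wandb agent <your-entity>/xminigrid-datasets/<sweep-id> &
```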

### Medium

Likewise, create a config:
```yaml
# medium_wandb.yaml
entity: <your-entity>
project: xminigrid-datasets
program: training/train.py
method: grid
parameters:
  config_path:
    value: "configs/medium_base.yaml"
  group:
    value: "xland-minigrid-datasets-medium-v0"
  dataset_path:
    value: <path-to-your-dir-for-data>
  pretrained_checkpoint_path:
    value: <path-to-your-pre-trained-checkpoint>
  dataset_num_histories:
    value: 32
  ruleset_id:
    min: 0
    max: 30000
    distribution: int_uniform
```
Unlike trivial, you must additionally specify the path to the pre-trained checkpoint (you can use `None` to train from scratch). After that, register the sweep with `wandb sweep medium_wandb.yaml` to get the sweep ID and start collection with `wandb agent <sweep-id>`, just as for the trivial sweep.

## Combining

We used a simple `combine.py` script to combine all individual learning histories into one dataset. As described in the paper, we already tuned the hdf5 chunk size that worked best in our experiments; however, you can customise it by changing the hardcoded values in the code.

For example, here we filter out all runs with a final return below 0.3:
```commandline
python combine.py \
    --wandb-entity=your-entity \
    --wandb-sweep=your-collection-sweep \
    --data-path=your-data-path \
    --combined-path=your-combined-path \
    --final-return-thrs=0.3
```
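
Before combining, it can also be useful to sanity-check a single collected history. The snippet below is a minimal sketch, assuming the per-run files are gzip-compressed HDF5 (which is how `combine.py` below reads them); the file name is hypothetical.
```python
import gzip

import h5py

# hypothetical file name; individual histories are stored as gzip-compressed hdf5 files
with gzip.open("path-to-your-dir-for-data/history-ruleset-42.gz", "rb") as gf:
    with h5py.File(gf, "r") as df:
        print("ruleset id:", df.attrs["ruleset-id"])
        for name in ("states", "actions", "rewards", "dones", "expert_actions"):
            print(name, df[name].shape, df[name].dtype)
```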

collection/combine.py

+112
@@ -0,0 +1,112 @@
import argparse
import glob
import gzip
import os

import h5py
import wandb
from tqdm.auto import tqdm


def get_run(sweep_runs, ruleset_id):
    wandb_run = [r for r in sweep_runs if r.config["ruleset_id"] == ruleset_id]
    assert len(wandb_run) == 1
    return wandb_run[0]


def extract_id(filename):
    return int(os.path.basename(filename).split("-")[-1].split(".")[0])


def main(args):
    print("Processing sweep runs...")
    api = wandb.Api()
    all_runs = api.runs(args.wandb_entity)
    # keep only runs that belong to the collection sweep
    dataset_runs = [r for r in tqdm(all_runs) if hasattr(r.sweep, "id") and r.sweep.id == args.wandb_sweep]

    print("Combining...")
    files = glob.glob(os.path.join(args.data_path, "*.gz"))
    # sort individual history files by the id encoded in their file name
    files = sorted(files, key=lambda f: extract_id(f))

    with h5py.File(args.combined_path, "w", rdcc_nbytes=5e9, rdcc_nslots=20000) as new_df:
        idx = 0
        for file in tqdm(files):
            try:
                with gzip.open(file, "rb") as gf:
                    with h5py.File(gf, "r") as df:
                        # checking that agent achieved return >= thrs, else skip
                        wandb_run = get_run(dataset_runs, df.attrs["ruleset-id"])

                        if "final_return" not in wandb_run.summary:
                            print(f"Corrupted run {file}, skipping...")
                            continue

                        if wandb_run.summary["final_return"] < args.final_return_thrs:
                            continue

                        assert str(idx) not in new_df.keys(), "key already exists"
                        g = new_df.create_group(str(idx))
                        g.attrs.update(df.attrs)

                        g.create_dataset(
                            "states",
                            shape=df["states"].shape,
                            dtype=df["states"].dtype,
                            data=df["states"][:],
                            compression="gzip",
                            compression_opts=6,
                            chunks=(1, 4096, 5, 5),
                        )
                        g.create_dataset(
                            "actions",
                            shape=df["actions"].shape,
                            dtype=df["actions"].dtype,
                            data=df["actions"][:],
                            compression="gzip",
                            compression_opts=6,
                            chunks=(1, 4096),
                        )
                        g.create_dataset(
                            "rewards",
                            shape=df["rewards"].shape,
                            dtype=df["rewards"].dtype,
                            data=df["rewards"][:],
                            compression="gzip",
                            compression_opts=6,
                            chunks=(1, 4096),
                        )
                        g.create_dataset(
                            "dones",
                            shape=df["dones"].shape,
                            dtype=df["dones"].dtype,
                            data=df["dones"][:],
                            compression="gzip",
                            compression_opts=6,
                            chunks=(1, 4096),
                        )
                        g.create_dataset(
                            "expert_actions",
                            shape=df["expert_actions"].shape,
                            dtype=df["expert_actions"].dtype,
                            data=df["expert_actions"][:],
                            compression="gzip",
                            compression_opts=6,
                            chunks=(1, 4096),
                        )

            except OSError:
                print(f"Corrupted file {file}, skipping...")
                continue

            # increment only for histories that were actually written
            idx = idx + 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb-entity", type=str)
    parser.add_argument("--wandb-sweep", type=str)
    parser.add_argument("--final-return-thrs", type=float, default=0.3)
    parser.add_argument("--data-path", type=str)
    parser.add_argument("--combined-path", type=str)
    args = parser.parse_args()
    main(args)
