Skip to content

Commit cd2d89f

Browse files
authored
wip add new proposed message structure (#1904)
* wip add new proposed message structure * tokenization * wip * wip transform builder * wip make the chat dataset loadable * wip chatml + llama 3 new chat objects * chore: lint * chore: lint * fix tokenization * remove dacite dependency since we're using pydantic now * fix handling when already correctly split in messages * make sure to remove chat features from tokenized ds * move chat to be a input transform for messages * make sure llama3 has the bos token * remove non-working special token code * fix messages strat loader
1 parent 1834cdc commit cd2d89f

File tree

23 files changed

+1285
-15
lines changed

23 files changed

+1285
-15
lines changed

requirements_env.txt

+315
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
accelerate==0.34.1
2+
addict==2.4.0
3+
aiofiles==23.2.1
4+
aiohttp==3.9.0
5+
aiosignal==1.3.1
6+
aiostream==0.5.2
7+
alembic==1.13.1
8+
annotated-types==0.6.0
9+
annoy==1.17.3
10+
ansible==6.7.0
11+
ansible-core==2.13.13
12+
ansible-vault==2.1.0
13+
anyio==3.7.1
14+
appdirs==1.4.4
15+
art==6.0
16+
asgiref==3.7.2
17+
async-timeout==4.0.2
18+
attrdict==2.0.1
19+
attrs==22.2.0
20+
awscli==1.32.75
21+
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
22+
backoff==2.2.1
23+
base58==2.1.1
24+
beartype==0.17.2
25+
bitnet==0.2.1
26+
bitsandbytes==0.42.0
27+
bittensor==6.7.0
28+
black==23.7.0
29+
blinker==1.7.0
30+
boto3==1.34.75
31+
botocore==1.34.75
32+
cachetools==5.3.3
33+
cachy==0.1.1
34+
certifi==2023.7.22
35+
cffi==1.16.0
36+
cfgv==3.3.1
37+
chai-guanaco==1.2.4
38+
charset-normalizer==3.2.0
39+
cleo==0.6.8
40+
click==8.1.7
41+
cloudpickle==2.0.0
42+
cohere==4.11.2
43+
colorama==0.4.4
44+
coloredlogs==15.0.1
45+
CoLT5-attention==0.10.20
46+
contextlib2==21.6.0
47+
contourpy==1.2.0
48+
cryptography==41.0.3
49+
cycler==0.12.1
50+
cytoolz==0.12.3
51+
databricks-cli==0.18.0
52+
dataclasses-json==0.5.7
53+
datasets==2.11.0
54+
ddt==1.6.0
55+
decorator==5.1.1
56+
deepspeed==0.15.0
57+
# Editable Git install with no remote (dialogpt==0.1)
58+
-e /Users/wing/Projects/ml/dialogpt/src
59+
dill==0.3.6
60+
distlib==0.3.6
61+
docker==7.0.0
62+
docker-pycreds==0.4.0
63+
docstring-parser==0.15
64+
docutils==0.16
65+
ecdsa==0.18.0
66+
einops==0.7.0
67+
einops-exts==0.0.4
68+
einx==0.1.3
69+
entrypoints==0.4
70+
eth-hash==0.6.0
71+
eth-keys==0.5.0
72+
eth-typing==4.0.0
73+
eth-utils==2.3.1
74+
evaluate==0.4.0
75+
exceptiongroup==1.1.1
76+
fastapi==0.109.2
77+
fastcore==1.5.29
78+
ffmpy==0.4.0
79+
filelock==3.12.2
80+
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
81+
fire==0.5.0
82+
first==2.0.2
83+
flake8==7.0.0
84+
Flask==3.0.1
85+
fonttools==4.47.2
86+
frozendict==2.4.1
87+
frozenlist==1.3.3
88+
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
89+
fsspec==2023.6.0
90+
fuzzywuzzy==0.18.0
91+
gitdb==4.0.10
92+
GitPython==3.1.31
93+
google-pasta==0.2.0
94+
gradio==4.42.0
95+
gradio_client==1.3.0
96+
greenlet==2.0.2
97+
grpclib==0.4.7
98+
gunicorn==21.2.0
99+
h11==0.14.0
100+
h2==4.1.0
101+
hpack==4.0.0
102+
httpcore==0.17.3
103+
httpx==0.24.1
104+
huggingface-hub==0.23.4
105+
humanfriendly==10.0
106+
hyperframe==6.0.1
107+
identify==2.5.24
108+
idna==3.4
109+
immutables==0.20
110+
importlib-metadata==6.7.0
111+
importlib-resources==6.1.1
112+
inflection==0.5.1
113+
iniconfig==2.0.0
114+
itsdangerous==2.1.2
115+
Jinja2==3.1.2
116+
jmespath==1.0.1
117+
joblib==1.3.2
118+
jsonlines==3.1.0
119+
jsonschema==2.6.0
120+
kiwisolver==1.4.5
121+
langchain==0.0.144
122+
Levenshtein==0.24.0
123+
libcst==1.1.0
124+
liger-kernel==0.0.0
125+
lion-pytorch==0.1.2
126+
llama-cpp-python==0.1.36
127+
llvmlite==0.40.1
128+
local-attention==1.9.0
129+
loguru==0.7.0
130+
Mako==1.3.2
131+
Markdown==3.5.2
132+
markdown-it-py==3.0.0
133+
markdown2==2.4.10
134+
MarkupSafe==2.1.2
135+
marshmallow==3.19.0
136+
marshmallow-enum==1.5.1
137+
matplotlib==3.8.2
138+
mccabe==0.7.0
139+
mdurl==0.1.2
140+
MEGABYTE-pytorch==0.0.7
141+
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
142+
mlflow==2.10.0
143+
modal==0.62.77
144+
more-itertools==10.2.0
145+
mpmath==1.2.1
146+
msgpack==1.0.7
147+
msgpack-numpy-opentensor==0.5.0
148+
multidict==6.0.4
149+
multiprocess==0.70.14
150+
munch==2.5.0
151+
mypy==1.3.0
152+
mypy-extensions==1.0.0
153+
nest-asyncio==1.6.0
154+
netaddr==0.10.1
155+
networkx==3.0rc1
156+
nh3==0.2.14
157+
nodeenv==1.8.0
158+
nomic==2.0.2
159+
numba==0.57.1
160+
numexpr==2.8.4
161+
numpy==1.24.4
162+
oauthlib==3.2.2
163+
openai==0.27.4
164+
openapi==1.1.0
165+
openapi-schema-pydantic==1.2.4
166+
optimum==1.8.6
167+
orjson==3.10.7
168+
packaging==23.1
169+
pandas==2.0.0
170+
parameterized==0.9.0
171+
password-strength==0.0.3.post2
172+
pastel==0.1.1
173+
pathos==0.3.0
174+
pathspec==0.11.1
175+
pathtools==0.1.2
176+
peft==0.11.1
177+
pendulum==3.0.0
178+
Pillow==9.5.0
179+
pip-tools==1.11.0
180+
platformdirs==3.2.0
181+
pluggy==1.4.0
182+
poetry==0.7.1
183+
pox==0.3.2
184+
ppft==1.7.6.6
185+
pre-commit==3.3.2
186+
prettytable==3.10.0
187+
prompt-toolkit==3.0.39
188+
protobuf==3.20.2
189+
protobuf3-to-dict==0.1.5
190+
psutil==5.9.5
191+
psycopg==3.1.18
192+
PuLP==2.8.0
193+
py==1.11.0
194+
py-bip39-bindings==0.1.11
195+
py-cpuinfo==9.0.0
196+
py-ed25519-zebra-bindings==1.0.1
197+
py-sr25519-bindings==0.2.0
198+
pyarrow==11.0.0
199+
pyasn1==0.6.0
200+
pycodestyle==2.11.1
201+
pycparser==2.21
202+
pycryptodome==3.20.0
203+
pydantic==2.5.3
204+
pydantic_core==2.14.6
205+
pydub==0.25.1
206+
pyfiglet==0.8.post1
207+
pyflakes==3.2.0
208+
Pygments==2.15.1
209+
PyJWT==2.8.0
210+
pylev==1.4.0
211+
PyNaCl==1.5.0
212+
pynvml==11.5.0
213+
pyparsing==2.4.7
214+
pyrsistent==0.14.11
215+
pytest==8.0.2
216+
pytest-asyncio==0.23.4
217+
python-dateutil==2.8.2
218+
python-dotenv==1.0.1
219+
python-Levenshtein==0.24.0
220+
python-multipart==0.0.9
221+
pytz==2023.3
222+
PyYAML==6.0.1
223+
querystring-parser==1.2.4
224+
rapidfuzz==3.6.1
225+
regex==2023.6.3
226+
requests==2.31.0
227+
requests-toolbelt==0.8.0
228+
resolvelib==0.8.1
229+
responses==0.18.0
230+
retry==0.9.2
231+
rich==13.7.0
232+
rsa==4.7.2
233+
ruff==0.6.3
234+
s3transfer==0.10.1
235+
safetensors==0.4.5
236+
sagemaker==2.148.0
237+
scalecodec==1.2.7
238+
schedulefree==1.2.1
239+
schema==0.7.5
240+
scikit-learn==1.4.0
241+
scipy==1.9.3
242+
seaborn==0.13.2
243+
semantic-version==2.10.0
244+
sentencepiece==0.2.0
245+
sentry-sdk==1.19.1
246+
setproctitle==1.3.2
247+
shellingham==1.5.4
248+
shortuuid==1.0.11
249+
shtab==1.6.5
250+
sigtools==4.0.1
251+
six==1.16.0
252+
skypilot==0.4.1
253+
smdebug-rulesconfig==1.0.1
254+
smmap==5.0.0
255+
sniffio==1.3.0
256+
SQLAlchemy==1.4.47
257+
sqlparse==0.4.4
258+
starlette==0.36.3
259+
substrate-interface==1.5.2
260+
svgwrite==1.4.3
261+
sympy==1.11.1
262+
synchronicity==0.6.7
263+
tabulate==0.9.0
264+
tblib==1.7.0
265+
tenacity==8.2.2
266+
tensor-parallel==2.0.0
267+
termcolor==2.2.0
268+
text2art==0.2.0
269+
threadpoolctl==3.2.0
270+
tiktoken==0.6.0
271+
time-machine==2.14.1
272+
timm==0.9.16
273+
tokenizers==0.19.1
274+
tokenmonster==1.1.12
275+
toml==0.9.6
276+
tomli==2.0.1
277+
tomlkit==0.12.0
278+
toolz==0.12.1
279+
torch==2.2.0
280+
torchdata==0.6.1
281+
torchdiffeq==0.2.3
282+
TorchFix==0.4.0
283+
torchtext==0.15.2
284+
torchvision==0.17.0
285+
tqdm==4.66.2
286+
transformers==4.44.2
287+
trl==0.9.6
288+
typer==0.12.5
289+
types-certifi==2021.10.8.3
290+
types-requests==2.31.0.20240125
291+
types-setuptools==69.0.0.20240125
292+
types-toml==0.10.8.7
293+
typing==3.7.4.3
294+
typing-inspect==0.8.0
295+
typing_extensions==4.9.0
296+
tyro==0.5.18
297+
tzdata==2023.3
298+
unique-names-generator==1.0.2
299+
urllib3==2.2.2
300+
uvicorn==0.22.0
301+
vector_quantize_pytorch==1.14.1
302+
virtualenv==20.23.0
303+
voyager==2.0.2
304+
wandb==0.16.2
305+
watchfiles==0.21.0
306+
wavedrom==2.0.3.post3
307+
wcwidth==0.2.6
308+
websocket-client==1.7.0
309+
websockets==12.0
310+
Werkzeug==3.0.1
311+
wonderwords==2.2.0
312+
xxhash==3.2.0
313+
yarl==1.8.2
314+
zetascale==2.2.7
315+
zipp==3.15.0

src/axolotl/cli/preprocess.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
register_chatml_template,
2828
register_llama3_template,
2929
)
30+
from axolotl.utils.trainer import disable_datasets_caching
3031

3132
LOG = logging.getLogger("axolotl.cli.preprocess")
3233

@@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
7071
LOG.warning(msg)
7172
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
7273

73-
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
74-
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
75-
else:
76-
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
74+
with disable_datasets_caching():
75+
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
76+
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
77+
else:
78+
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
7779

7880
if parsed_cli_args.download:
7981
model_name = parsed_cfg.base_model

src/axolotl/core/chat/__init__.py

Whitespace-only changes.

src/axolotl/core/chat/format/__init__.py

Whitespace-only changes.
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
ChatML transformation functions for MessageContents
3+
"""
4+
from typing import Optional
5+
6+
from ..messages import MessageContents, Messages
7+
from .shared import wrap_tools
8+
9+
10+
def format_message(
    message: Messages,
    message_index: Optional[int] = None,  # pylint: disable=unused-argument
) -> Messages:
    """
    Surround a message's content with ChatML turn delimiters.

    A message whose ``is_chat_formatted`` flag is already set is returned
    untouched. Otherwise the message is modified in place:

    * a text part ``<|im_start|>{role}\\n`` with weight 0 is placed at the
      front of ``message.content``;
    * a text part ``<|im_end|>`` carrying the message's own weight is added
      at the end;
    * a text part holding a single newline with weight 0 follows it.

    The message is then handed to ``wrap_tools`` and flagged as formatted.

    Args:
        message: the message to format.
        message_index: not used by this formatter.

    Returns:
        The message returned by ``wrap_tools`` with ``is_chat_formatted``
        set, or the input message itself if it was already formatted.
    """
    if not message.is_chat_formatted:
        # opening delimiter: role header goes in front of the existing content
        role_header = MessageContents(
            type="text",
            value=f"<|im_start|>{message.role}\n",
            weight=0,
        )
        message.content.insert(0, role_header)

        # closing delimiter inherits the message weight; the newline after it
        # is always weight 0
        turn_end = MessageContents(
            type="text", value="<|im_end|>", weight=message.weight
        )
        message.content.append(turn_end)
        trailing_newline = MessageContents(type="text", value="\n", weight=0)
        message.content.append(trailing_newline)

        # NOTE(review): wrap_tools comes from .shared and is not visible here;
        # presumably it wraps tool-related content parts -- confirm there.
        message = wrap_tools(message)
        message.is_chat_formatted = True

    return message

0 commit comments

Comments
 (0)