Add in-house MoE, loss parallel #159

Merged on Feb 10, 2025 · 235 commits

Commits
70c3a5e
Add `MoERouter`
epwalsh Jan 30, 2025
5f3994e
Add `MoEMLP`
epwalsh Jan 30, 2025
b3d1fa6
Merge branch 'main' into epwalsh/moe-in-house
epwalsh Jan 30, 2025
122af0a
update Docker build
epwalsh Jan 30, 2025
937baf5
update MLP test
epwalsh Jan 30, 2025
0b127d2
add a test with expert parallel
epwalsh Jan 30, 2025
9746d9a
clean up test
epwalsh Jan 31, 2025
1dba833
add launch script to quickly run tests
epwalsh Jan 31, 2025
90921db
fix dtype
epwalsh Jan 31, 2025
027bc5b
try no host networking
epwalsh Jan 31, 2025
68a6b0f
make host-networking configurable
epwalsh Jan 31, 2025
149e47f
add config class
epwalsh Jan 31, 2025
1dc5c12
clean up for running tests
epwalsh Jan 31, 2025
72939aa
improve script
epwalsh Jan 31, 2025
c7a3890
setup distributed
epwalsh Jan 31, 2025
3e80cc4
fix
epwalsh Jan 31, 2025
1f208fa
fix
epwalsh Jan 31, 2025
da045ea
Add parallel MLP implementation
epwalsh Jan 31, 2025
0d792d7
add MoE base
epwalsh Jan 31, 2025
a284892
Merge branch 'epwalsh/moe-in-house' into v2-moe
epwalsh Jan 31, 2025
27b72ee
Merge branch 'v2' into v2-moe
epwalsh Jan 31, 2025
d1f4984
integrate new MoE code
epwalsh Feb 1, 2025
b0155a8
fix init
epwalsh Feb 1, 2025
cacfaa8
improve how we get MoE losses
epwalsh Feb 1, 2025
055af0d
fixes
epwalsh Feb 1, 2025
0f9b518
clean up
epwalsh Feb 1, 2025
d258b42
fix
epwalsh Feb 1, 2025
02fa200
Add router test
epwalsh Feb 1, 2025
31a972f
fix config
epwalsh Feb 1, 2025
d26755d
fixes
epwalsh Feb 1, 2025
dd91090
fix?
epwalsh Feb 1, 2025
05026cd
fix loss
epwalsh Feb 1, 2025
925bf24
Add test with expert parallelism
epwalsh Feb 1, 2025
38b674a
lol, fix
epwalsh Feb 1, 2025
daff437
fix dtypes?
epwalsh Feb 1, 2025
dfde49d
fix some typos
epwalsh Feb 1, 2025
b8bee1e
check that loss is finite
epwalsh Feb 1, 2025
1f869de
compute active params
epwalsh Feb 2, 2025
41ccfe4
Allow expert parallelism
epwalsh Feb 2, 2025
70a1275
test size of experts
epwalsh Feb 2, 2025
953049c
don't require grouped gemm for MoEMLP test
epwalsh Feb 2, 2025
4709940
move losses to their own module
epwalsh Feb 2, 2025
f9f79b6
remove megablocks from build
epwalsh Feb 2, 2025
674850d
update build deps
epwalsh Feb 2, 2025
36ab5b1
fix
epwalsh Feb 2, 2025
b94fa69
update stable image
epwalsh Feb 2, 2025
0479f74
update nightly build
epwalsh Feb 2, 2025
59867e8
pin grouped gemm to commit
epwalsh Feb 2, 2025
3a7bc84
build with CUTLASS again
epwalsh Feb 2, 2025
c54d882
update images used
epwalsh Feb 2, 2025
0c7baed
fix expert parallelism, implement sequence parallelism
epwalsh Feb 3, 2025
dfb118a
fix test
epwalsh Feb 3, 2025
baef2f2
fix
epwalsh Feb 3, 2025
d352a2b
fix
epwalsh Feb 3, 2025
83d28af
Start on regular MoE
epwalsh Feb 3, 2025
62dfdc4
finish?
epwalsh Feb 3, 2025
ff743e3
refactor
epwalsh Feb 3, 2025
aa7da18
fix
epwalsh Feb 3, 2025
76c5ec2
fix
epwalsh Feb 3, 2025
8fb887d
fix
epwalsh Feb 3, 2025
6951bb1
clean up
epwalsh Feb 3, 2025
c8be5dc
fix
epwalsh Feb 3, 2025
774b77f
fix
epwalsh Feb 4, 2025
7ecebd2
add parallel test for default
epwalsh Feb 4, 2025
6841a6a
fix?
epwalsh Feb 4, 2025
62f5496
clean up
epwalsh Feb 4, 2025
45482d8
clean up
epwalsh Feb 4, 2025
4f38c01
improve test
epwalsh Feb 4, 2025
3a1b06c
fix
epwalsh Feb 4, 2025
f085171
fix
epwalsh Feb 4, 2025
f077870
fix
epwalsh Feb 4, 2025
3fda274
fix
epwalsh Feb 4, 2025
3c3544f
fix
epwalsh Feb 4, 2025
29a7fcc
fix
epwalsh Feb 4, 2025
54ae3ac
add more
epwalsh Feb 4, 2025
4c3d0a9
add
epwalsh Feb 4, 2025
9fb3d52
init weights
epwalsh Feb 4, 2025
dc5c13f
add extra repr to MLP class
epwalsh Feb 4, 2025
97bf618
debugging
epwalsh Feb 4, 2025
5073d50
debug
epwalsh Feb 4, 2025
242649a
fix
epwalsh Feb 4, 2025
0c94a00
clean up
epwalsh Feb 4, 2025
a6cb7ef
debug
epwalsh Feb 4, 2025
66f4afd
debugging
epwalsh Feb 4, 2025
9582c3c
more debug
epwalsh Feb 4, 2025
461072a
lol
epwalsh Feb 4, 2025
d44000c
expert indices
epwalsh Feb 4, 2025
acf03be
try w/ uniform assignment
epwalsh Feb 4, 2025
1ee38fa
fix
epwalsh Feb 4, 2025
fd945a5
debug
epwalsh Feb 4, 2025
e13fedc
small bz
epwalsh Feb 4, 2025
87afd27
more
epwalsh Feb 5, 2025
2ceba52
more
epwalsh Feb 5, 2025
7db7018
more debug
epwalsh Feb 5, 2025
157b383
try this
epwalsh Feb 5, 2025
b111659
try this
epwalsh Feb 5, 2025
0cc1c78
try this
epwalsh Feb 5, 2025
fea70b7
cache
epwalsh Feb 5, 2025
06784d0
fix
epwalsh Feb 5, 2025
9a2cdad
cache
epwalsh Feb 5, 2025
a6b7b93
clean up
epwalsh Feb 5, 2025
0024a8c
add tests for ops
epwalsh Feb 5, 2025
bc61bfa
fix?
epwalsh Feb 5, 2025
74c4e11
fix?
epwalsh Feb 5, 2025
069a77d
add another test
epwalsh Feb 5, 2025
650f030
test with shared
epwalsh Feb 5, 2025
c875ab7
fix
epwalsh Feb 5, 2025
b925af2
check losses
epwalsh Feb 5, 2025
48d7e9a
fix
epwalsh Feb 5, 2025
d47c29d
fix?
epwalsh Feb 5, 2025
0d1ea1f
clean up
epwalsh Feb 5, 2025
4d621d8
clean up
epwalsh Feb 5, 2025
5db8850
comments
epwalsh Feb 5, 2025
3fd2c64
add batched histc
epwalsh Feb 5, 2025
91365f3
stuff
epwalsh Feb 6, 2025
ba2f6ad
add config builder for MoE
epwalsh Feb 6, 2025
729e6cc
Merge branch 'v2' into v2-moe
epwalsh Feb 6, 2025
b315a9c
Merge branch 'v2' into v2-moe
epwalsh Feb 6, 2025
108fd69
Add SmallMoE config
epwalsh Feb 6, 2025
8cd7ed1
use replicate with EP
epwalsh Feb 6, 2025
10e845c
fix?
epwalsh Feb 6, 2025
c130305
dumb
epwalsh Feb 6, 2025
2c6bbc7
fix test?
epwalsh Feb 6, 2025
9bbbcad
okay, let's try this
epwalsh Feb 7, 2025
b134bc9
fix
epwalsh Feb 7, 2025
1474827
require HSDP for expert parallelism
epwalsh Feb 7, 2025
e7e726e
fix dtype
epwalsh Feb 7, 2025
64e384d
fewer active experts
epwalsh Feb 7, 2025
2a9df66
idk
epwalsh Feb 7, 2025
557699e
try this
epwalsh Feb 7, 2025
6972b22
fix?
epwalsh Feb 7, 2025
af6a774
clean up
epwalsh Feb 7, 2025
75dc2bb
custom op
epwalsh Feb 7, 2025
c7b248a
try again
epwalsh Feb 7, 2025
4fb0cfd
try this
epwalsh Feb 7, 2025
24a41a0
pre-cast to int
epwalsh Feb 7, 2025
095f389
debugging
epwalsh Feb 7, 2025
fc1c248
try not flattening
epwalsh Feb 7, 2025
f9d70b0
revert
epwalsh Feb 7, 2025
b1a423c
try this
epwalsh Feb 7, 2025
5c62a67
let's try this
epwalsh Feb 7, 2025
dc0e960
fix
epwalsh Feb 7, 2025
5f47723
fix
epwalsh Feb 7, 2025
45ec6d8
fix
epwalsh Feb 7, 2025
0936393
clean up
epwalsh Feb 7, 2025
4e48bb2
logging
epwalsh Feb 7, 2025
d63e372
fix
epwalsh Feb 7, 2025
99a7f09
clean up
epwalsh Feb 7, 2025
2bcc079
try with replicate
epwalsh Feb 7, 2025
b2c39b8
clean up
epwalsh Feb 7, 2025
d988ff2
back to sharding
epwalsh Feb 7, 2025
b441c45
clean up
epwalsh Feb 7, 2025
18a0073
try dropless
epwalsh Feb 7, 2025
7d7f7b5
revert change to dropless
epwalsh Feb 7, 2025
281785d
fixes
epwalsh Feb 7, 2025
4c6d48e
try again
epwalsh Feb 7, 2025
4a0c0e2
try this
epwalsh Feb 7, 2025
c66bede
idk
epwalsh Feb 7, 2025
a2b8a1b
fix?
epwalsh Feb 7, 2025
ea2bb5c
idk
epwalsh Feb 7, 2025
f8d5940
debugging
epwalsh Feb 7, 2025
406d5ed
make input local
epwalsh Feb 7, 2025
5928f0e
debug
epwalsh Feb 7, 2025
3c65c0c
debug
epwalsh Feb 7, 2025
0c4b132
debug
epwalsh Feb 7, 2025
a829d73
try this
epwalsh Feb 7, 2025
f247d89
debug
epwalsh Feb 7, 2025
78722ca
try again
epwalsh Feb 7, 2025
06d88e3
assert
epwalsh Feb 7, 2025
9ceb128
fix that
epwalsh Feb 7, 2025
cead175
debug
epwalsh Feb 7, 2025
794bc03
try this
epwalsh Feb 7, 2025
fb415b6
maybe fix
epwalsh Feb 7, 2025
76c322a
remove inplace op
epwalsh Feb 7, 2025
d9d9d33
clean up
epwalsh Feb 7, 2025
0ea34ed
clean up tensor parallel
epwalsh Feb 8, 2025
ae24c2a
fix
epwalsh Feb 8, 2025
da04166
try this?
epwalsh Feb 8, 2025
728121e
fix
epwalsh Feb 8, 2025
d613fea
ooops
epwalsh Feb 8, 2025
dc4a8a8
debug
epwalsh Feb 8, 2025
e19ace5
fix
epwalsh Feb 8, 2025
94d3178
extra safety
epwalsh Feb 8, 2025
f3de9c6
fix router test
epwalsh Feb 8, 2025
9c9f5e7
implement loss parallel
epwalsh Feb 8, 2025
7d7249e
do in pipeline too
epwalsh Feb 8, 2025
3994f41
start test for CE loss
epwalsh Feb 8, 2025
5deafea
clean up
epwalsh Feb 8, 2025
9550286
fix
epwalsh Feb 8, 2025
2c7219f
clean up
epwalsh Feb 8, 2025
cfcc2a1
fix?
epwalsh Feb 8, 2025
041bbcd
clean up
epwalsh Feb 8, 2025
3cab69c
add case for none reduction
epwalsh Feb 8, 2025
30245a4
ok try again
epwalsh Feb 9, 2025
cfaf057
fix
epwalsh Feb 9, 2025
761c258
try this
epwalsh Feb 9, 2025
ae8b27d
fix?
epwalsh Feb 9, 2025
75eee3a
fix
epwalsh Feb 9, 2025
a035ee2
fix
epwalsh Feb 9, 2025
b1e2db4
fix
epwalsh Feb 9, 2025
4828414
fix
epwalsh Feb 9, 2025
524876c
debug
epwalsh Feb 9, 2025
00a7b6e
fix?
epwalsh Feb 9, 2025
7509318
check for gradients
epwalsh Feb 9, 2025
31e6615
update train modules
epwalsh Feb 9, 2025
55d60f5
fix
epwalsh Feb 9, 2025
8c9cac1
fix?
epwalsh Feb 9, 2025
db7e42d
idk
epwalsh Feb 9, 2025
747a7ad
make loss parallel configurable
epwalsh Feb 9, 2025
eeb310a
add long context config
epwalsh Feb 9, 2025
2537864
increase context length
epwalsh Feb 9, 2025
e8db8ba
try 64
epwalsh Feb 9, 2025
ae58452
rename some things
epwalsh Feb 10, 2025
91679ef
back to non moe
epwalsh Feb 10, 2025
9e426d5
fix?
epwalsh Feb 10, 2025
5ea11ea
fix?
epwalsh Feb 10, 2025
92538ca
fixed eval sequence length when TP enabled
epwalsh Feb 10, 2025
2bfafde
fix eval batch size
epwalsh Feb 10, 2025
2b2d838
revert the FSL
epwalsh Feb 10, 2025
7b9488f
upgrade olmo-eval
epwalsh Feb 10, 2025
0c7cadc
update test
epwalsh Feb 10, 2025
95ca989
fix?
epwalsh Feb 10, 2025
be2b15b
check loss first
epwalsh Feb 10, 2025
cd818fd
compare local tensor
epwalsh Feb 10, 2025
0ec327c
clean up
epwalsh Feb 10, 2025
d87a5d3
try this
epwalsh Feb 10, 2025
2ea5307
fix?
epwalsh Feb 10, 2025
c55d733
fix
epwalsh Feb 10, 2025
f4ec2e1
exclude default evals from long context config
epwalsh Feb 10, 2025
77d3d71
okay try without that
epwalsh Feb 10, 2025
c84adef
revert
epwalsh Feb 10, 2025
53c1f6c
update install instructions
epwalsh Feb 10, 2025
45fe06c
try with host-networking
epwalsh Feb 10, 2025
9e656d7
use multiple GPUs for MoE test
epwalsh Feb 10, 2025
06d8140
update long context defaults
epwalsh Feb 10, 2025
9 changes: 5 additions & 4 deletions .github/workflows/main.yml
@@ -111,7 +111,7 @@ jobs:
matrix:
task:
- name: Test (GPU)
image: olmo-core-tch251cu124
image: olmo-core-tch260cu124
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
@@ -120,15 +120,15 @@ jobs:
src/test/

- name: Test checkpoint (GPU)
image: olmo-core-tch251cu124
image: olmo-core-tch260cu124
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
src/test/distributed/checkpoint*

- name: Test MoE (GPU)
image: olmo-core-tch251cu124
gpus: 1
image: olmo-core-tch260cu124
gpus: 2
run: |
pytest -v --color=yes --durations=3 -m gpu \
src/test/nn/moe*
@@ -182,6 +182,7 @@ jobs:
preemptible: true
resources:
gpuCount: ${{ matrix.task.gpus }}
hostNetworking: true
constraints:
cluster:
# H100 clusters
14 changes: 7 additions & 7 deletions Makefile
@@ -1,14 +1,14 @@
CUDA_VERSION = "12.4"
TORCH_CUDA_VERSION = $(shell echo $(CUDA_VERSION) | tr -d .)
TORCH_VERSION = "2.5.1"
TORCH_VERSION = "2.6.0"
TORCH_VERSION_SHORT = $(shell echo $(TORCH_VERSION) | tr -d .)
# NOTE: when upgrading the nightly version you also need to upgrade the torch version specification
# in 'pyproject.toml' to include that nightly version.
TORCH_NIGHTLY_VERSION = "2.6.0.dev20241209"
TORCH_NIGHTLY_VERSION = "2.7.0.dev20250202"
TORCH_NIGHTLY_VERSION_SHORT = $(shell echo $(TORCH_NIGHTLY_VERSION) | tr -d .)
TORCHAO_VERSION = "0.6.1"
MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"
FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
TORCHAO_VERSION = "0.8.0"
GROUPED_GEMM_VERSION = "grouped_gemm @ git+https://[email protected]/tgale96/grouped_gemm.git@main"
FLASH_ATTN_WHEEL = https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

VERSION = $(shell python src/olmo_core/version.py)
VERSION_SHORT = $(shell python src/olmo_core/version.py short)
@@ -55,7 +55,7 @@ stable-image :
--build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \
--build-arg TORCH_VERSION=$(TORCH_VERSION) \
--build-arg FLASH_ATTN_WHEEL=$(FLASH_ATTN_WHEEL) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--target stable \
--progress plain \
@@ -70,7 +70,7 @@ nightly-image :
--build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \
--build-arg TORCH_VERSION=$(TORCH_VERSION) \
--build-arg FLASH_ATTN_WHEEL=$(FLASH_ATTN_WHEEL) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg GROUPED_GEMM_VERSION=$(GROUPED_GEMM_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg TORCH_NIGHTLY_VERSION=$(TORCH_NIGHTLY_VERSION) \
--target nightly \
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ pip install ai2-olmo-core
There are a number of optional dependencies that must be installed to use certain functionality as well, including:
- [flash-attn](https://github.com/Dao-AILab/flash-attention) for flash attention and certain other fused operations.
- [torchao](https://github.com/pytorch/ao) for float8 training.
- [megablocks](https://github.com/databricks/megablocks) for mixture-of-experts (MoE) models.
- [grouped_gemm](https://github.com/tgale96/grouped_gemm) for dropless mixture-of-experts (MoE) models. You may need to compile from source until [PR #21](https://github.com/tgale96/grouped_gemm/pull/21) is released (post v0.1.6).

The published [Docker images](https://github.com/orgs/allenai/packages?repo_name=OLMo-core) contain all core and optional dependencies, and are regularly tested on our in-house H100 clusters.
But there are several things to keep in mind if you intend to use these images:
2 changes: 1 addition & 1 deletion docs/source/overview/installation.rst
@@ -12,4 +12,4 @@ There are a number of optional dependencies that must be installed to use certai

- `flash-attn <https://github.com/Dao-AILab/flash-attention>`_ for flash attention and certain other fused operations.
- `torchao <https://github.com/pytorch/ao>`_ for float8 training (see :mod:`olmo_core.float8`).
- `megablocks <https://github.com/databricks/megablocks>`_ for mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`).
- `grouped_gemm <https://github.com/tgale96/grouped_gemm>`_ for dropless mixture-of-experts (MoE) models (see :mod:`olmo_core.nn.moe`).
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"omegaconf",
"safetensors",
"importlib_resources",
"ai2-olmo-eval==0.5.0",
"ai2-olmo-eval==0.6.1",
]

[project.urls]
@@ -169,4 +169,5 @@ filterwarnings = [
'ignore::DeprecationWarning:pkg_resources',
'ignore::DeprecationWarning:google\.rpc',
'ignore::FutureWarning:torch\.distributed\.checkpoint\.default_planner',
'ignore::UserWarning:torch\.distributed\.checkpoint\.state_dict_saver',
]
25 changes: 13 additions & 12 deletions src/Dockerfile
@@ -1,7 +1,7 @@
# NOTE: make sure CUDA_VERSION and TORCH_CUDA_VERSION always match, except for punctuation
ARG CUDA_VERSION="12.4"
ARG TORCH_CUDA_VERSION="124"
ARG TORCH_VERSION="2.5.1"
ARG TORCH_VERSION="2.6.0

#########################################################################
# Build image
@@ -24,22 +24,23 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Install/upgrade Python build dependencies.
RUN pip install --upgrade --no-cache-dir pip wheel packaging "setuptools<70.0.0" ninja

# Build megablocks, grouped-gemm, stanford-stk
# Build grouped-gemm.
# NOTE: right now we need to build with CUTLASS so we can pass batch sizes on GPU.
# See https://github.com/tgale96/grouped_gemm/pull/21
ENV TORCH_CUDA_ARCH_LIST="8.0 9.0"
ENV GROUPED_GEMM_CUTLASS="1"
ARG MEGABLOCKS_VERSION="megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"
RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}"
ARG GROUPED_GEMM_VERSION="grouped_gemm @ git+https://[email protected]/tgale96/grouped_gemm.git@main"
RUN pip wheel --no-build-isolation --no-cache-dir "${GROUPED_GEMM_VERSION}"

# Build flash-attn.
ARG FLASH_ATTN_WHEEL=https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
ARG FLASH_ATTN_WHEEL=https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
RUN wget ${FLASH_ATTN_WHEEL}

# Only keep the target wheels and dependencies with CUDA extensions.
RUN echo "Built wheels:" \
&& ls -lh . \
&& ls -1 | grep -Ev 'megablocks|grouped_gemm|stanford_stk|flash_attn' | xargs rm \
&& echo "Final wheels:" \
&& ls -lh .
RUN echo "Built wheels:" && ls -lh .
# && ls -1 | grep -Ev 'grouped_gemm|flash_attn' | xargs rm \
# && echo "Final wheels:" \
# && ls -lh .

#########################################################################
# Stable image
@@ -73,7 +74,7 @@ RUN pip install --upgrade --no-cache-dir pip wheel packaging

# Install torchao.
ARG TORCH_CUDA_VERSION
ARG TORCHAO_VERSION="0.6.1"
ARG TORCHAO_VERSION="0.8.0"
RUN pip install --no-cache-dir \
--extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \
torchao==${TORCHAO_VERSION}
@@ -100,7 +101,7 @@ WORKDIR /app/olmo-core
FROM stable as nightly

ARG TORCH_CUDA_VERSION
ARG TORCH_NIGHTLY_VERSION="2.6.0.dev20241209"
ARG TORCH_NIGHTLY_VERSION="2.7.0.dev20250202"
RUN pip install --no-cache-dir --pre \
--index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} \
torch==${TORCH_NIGHTLY_VERSION}+cu${TORCH_CUDA_VERSION}
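
The CUTLASS note in the Dockerfile above is the reason grouped-gemm is built from source: with GROUPED_GEMM_CUTLASS=1 (and the fix in grouped_gemm PR #21), the per-expert batch sizes can be passed as a GPU tensor instead of having to live on the CPU. A rough sketch of the call pattern, assuming the gg.ops.gmm(a, b, batch_sizes) entry point described in the grouped_gemm README; the shapes and names here are illustrative, not OLMo-core code:

import torch
import grouped_gemm as gg  # assumes a CUTLASS-enabled build

# Tokens already permuted so that each expert's tokens are contiguous.
num_experts, tokens, d_in, d_out = 4, 1024, 256, 512
a = torch.randn(tokens, d_in, dtype=torch.bfloat16, device="cuda")              # routed token activations
b = torch.randn(num_experts, d_in, d_out, dtype=torch.bfloat16, device="cuda")  # one weight matrix per expert
batch_sizes = torch.tensor([256, 256, 256, 256])  # tokens per expert; stock grouped_gemm expects this on CPU
# With the CUTLASS build (grouped_gemm PR #21) the counts can instead stay on the GPU:
# batch_sizes = batch_sizes.cuda()
out = gg.ops.gmm(a, b, batch_sizes)  # -> (tokens, d_out)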
1 change: 1 addition & 0 deletions src/examples/llama/train.py
@@ -57,6 +57,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
tokenizer_config = TokenizerConfig.gpt2()

model_config = TransformerConfig.llama2_271M(
# model_config = TransformerConfig.smallmoe(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
)

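
For context on the commented-out line above: a hypothetical sketch of switching the example script to the new MoE config, assuming the TransformerConfig.smallmoe() builder added in this PR takes the same vocab_size argument as llama2_271M() (the import paths below are assumptions, not taken from this diff):

from olmo_core.data import TokenizerConfig              # assumed import path
from olmo_core.nn.transformer import TransformerConfig  # assumed import path

tokenizer_config = TokenizerConfig.gpt2()

# Swap the dense Llama config for the small MoE config added in this PR.
model_config = TransformerConfig.smallmoe(
    vocab_size=tokenizer_config.padded_vocab_size(),  # padded to a multiple of 128
)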
4 changes: 3 additions & 1 deletion src/olmo_core/data/utils.py
@@ -19,6 +19,7 @@

import numpy as np
import torch
import torch.nn.functional as F

from olmo_core.aliases import PathOrStr
from olmo_core.io import add_cached_path_clients, get_bytes_range, is_url, resource_path
@@ -467,4 +468,5 @@ def get_labels(batch: Dict[str, Any], label_ignore_index: int = -100) -> torch.T
labels.masked_fill_(attention_mask == 0.0, label_ignore_index)
if instance_mask is not None:
labels.masked_fill_(~instance_mask.unsqueeze(-1), value=label_ignore_index)
return labels[..., 1:].contiguous()
# Shift and pad.
return F.pad(labels[..., 1:], (0, 1, 0, 0), value=label_ignore_index)
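
To make the new return statement concrete, here is a small illustrative sketch of just the final shift-and-pad step (not the full get_labels): the labels are the inputs shifted left by one position, and the last position is filled with label_ignore_index, so the result keeps the same sequence length as the input rather than being one token shorter as before.

import torch
import torch.nn.functional as F

label_ignore_index = -100
labels = torch.tensor([[5, 6, 7, 8]])  # pretend token IDs for one sequence

old_style = labels[..., 1:].contiguous()
# tensor([[6, 7, 8]])        -- one token shorter than the input

new_style = F.pad(labels[..., 1:], (0, 1, 0, 0), value=label_ignore_index)
# tensor([[6, 7, 8, -100]])  -- same length as the input; the final,
# unpredictable position is excluded from the loss via the ignore index
print(old_style)
print(new_style)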