While developing test_pp_cp and discussing it with fegin, we realized
that the freqs_cis buffers are not handled correctly in torchtitan when
pipelining is enabled.
CP needs to modify the freqs_cis buffer to account for sharding on the
seq dim, but the previous titan code implemented this incorrectly:
`model.freqs_cis` was passed to CP for sharding, but pipelining does not
use `model` at all. It uses the per-stage models contained in the
`model_parts` list. The fix is to tell the CP context about the
freqs_cis buffer of each model in `model_parts`.
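As a rough illustration of the idea (not the actual torchtitan change),
the sketch below gathers one freqs_cis buffer per stage-model instead of
reading it from a single `model`. The helper name `collect_cp_buffers`,
the stand-in stage objects, and the sequence-dim values (1 for the batch
tensors, 0 for freqs_cis) are assumptions made for this example only.

```python
from types import SimpleNamespace


def collect_cp_buffers(inputs, labels, model_parts):
    """Hypothetical helper: list every buffer CP should shard, with its seq dim.

    Assumes inputs/labels are laid out as (batch, seq) and each freqs_cis
    buffer is laid out as (seq, ...); adjust the dims if that does not hold.
    """
    # One freqs_cis per pipeline stage, rather than the single model.freqs_cis.
    buffers = [inputs, labels] + [part.freqs_cis for part in model_parts]
    seq_dims = [1, 1] + [0] * len(model_parts)
    return buffers, seq_dims


# Stand-in stage models; in practice these would be the per-stage modules.
model_parts = [SimpleNamespace(freqs_cis=[f"freqs_stage{i}"]) for i in range(2)]
buffers, seq_dims = collect_cp_buffers("inputs", "labels", model_parts)
```

The two returned lists are index-aligned, so whatever creates the CP
context can pair each buffer with the dim it should be sharded on.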
Alternatively, we could tie the freqs_cis buffers of all pp stages
together by explicitly doing so after calling init_weights per
pp-stage. However, this is of limited value, so we skip it.
ghstack-source-id: 7aa3935
Pull Request resolved: #792