feat: RolloutGenerationMixin + per-prompt prefetch#694
Open
shuangwu wants to merge 5 commits into
Open
Conversation
4995882 to
48fb71f
Compare
Collaborator
Author
|
The job did not fail on a Python exception in the test code; it timed out after 2 hours |
d7f5487 to
cac10e6
Compare
20fa017 to
317abc0
Compare
Add Slurm CI tooling and harden replica teardown so test_process_flow no longer hangs for hours: fail-fast Redis stream polls during shutdown, controller finalize on heartbeat-death reap (with a startup guard so SFT controllers do not self-destruct before workers register), bounded HTTP timeouts, teardown progress logging, and a bounded test-side wait with faulthandler diagnostics.
Reworks the UCXX payload transport for correctness and
simplicity, and removes a deprecated trainer-side mixin.
**Wire protocol (single chunk per slot).** One connection, one
``send([slot])``, one ``recv(status)``, one ``recv(payload)``.
The handler that performs the slot's ``READY -> READING``
transition is the unique owner of its
``READING -> {FREE | READY}`` finalisation for that read
attempt. No refcounting, no shared cross-handler state, no
multi-chunk fan-out.
**Two structural invariants** make orphan handlers harmless:
1. *Defensive slot-state transitions.*
``SharedRingBuffer.mark_consumed`` and ``release_reading``
no-op (with a warning) on unexpected slot states instead of
clobbering them. A late finalise from an orphan handler --
e.g. one still mid-``send`` after the client rotated to a
different port -- cannot silently undo a writer's recycle.
2. *Single owner per read attempt.* Each handler calls exactly
one of ``mark_consumed`` (success) or ``release_reading``
(failure) in its ``finally`` block.
**Three-layer client retry stack.**
* ``UCXXClient.read``: per-call port rotation with on-timeout
fallback, plus a per-port skip-list so ports that have just
produced ``UCXXConnectionResetError`` /
``UCXXMessageTruncated`` / timeout are quarantined for
``_PORT_QUARANTINE_SEC`` and skipped on subsequent calls
while at least one healthy port remains.
* ``UCXXDataPackerMixin._read_one``: per-slot fresh-call retry,
with ``StaleSlotError`` flagged non-retryable so a recycled
slot is not re-fetched in the same batch.
* ``UCXXDataPackerMixin._ucxx_dp_fetch_all``: multi-round batch
retry (``_MAX_FETCH_ROUNDS``).
**Audit-pass fixes.**
* Stop server-thread busy-spin: shutdown watchdog now sleeps
for 50 ms instead of yielding via ``asyncio.sleep(0)``.
* ``SharedRingBuffer.get_ready_count`` actually counts
``READY`` slots instead of returning the saturating
``entry_count`` header.
* ``UCXXClient`` preemptively closes pooled endpoints older
than ``_POOL_ENDPOINT_MAX_AGE_S`` on checkout, replacing
them with fresh connections so a client never tries to
reuse an endpoint the server has already idle-evicted.
* Demote the per-request ``[UCXXBuffer] req slot=...`` log
from INFO to DEBUG.
**Cleanup.** Removes the deprecated ``UCXXTrainerMixin``
(superseded by ``UCXXDataPackerMixin``) and the
``ucxx_n_chunks`` config knob. ``read_raw`` still tolerates
re-entry from ``READING`` -- not for chunking, but to let a
fresh handler serve a slot when the original is still
mid-``send``; the defensive guards above make the duplicate
finalise harmless.
**Docs.** ``ucxx/__init__.py`` and ``rl-gym/docs/UCXX.md``
rewritten with the slot-lifecycle diagram, the three-layer
retry table, port skip-list semantics, and the
preemptive-eviction self-healing window.
Tests: 55 unit tests in ``tests/test_ucxx_transport.py`` and
``tests/test_ucxx_data_packer_mixin.py`` covering the
defensive primitive guards, ``get_ready_count`` correctness,
stale-pooled-endpoint replacement, and the three-layer retry
behaviour.
Separate training completion from shutdown so rollouts can drain cleanly. Defer TrainingComplete until policy replicas finish in-flight work, bounded-drain P2R/teardown streams, and kill full subprocess trees in GPU integration tests so controller/torchrun zombies do not poison later CI cases.
…klog/staleness logging Defer TrainingComplete in finish_draining when buffered rollouts can still fill remaining training steps, fixing custom_rollout computed_cnt CI failure.
…prep overlap) Opt-in RolloutGenerationMixin for RolloutBase subclasses that structures rollout_generation via five hooks (_preflight_check, _prepare_sample, _collate_batch, _generate, _postprocess) and overlaps per-prompt preprocessing with in-flight generation when config.rollout.prefetch_rollout is set. The gym example is the first in-tree adopter. _single_producer_mode reads prefetch_rollout live (not snapshotted) so subclasses can flip it from post_init_hook. Two overlap paths: - Fetch overlap (single-process, world_size == 1): the dedicated rank-0 _prefetch_loop hides the controller get_next_prompt round-trip and stages prep. Requires world_size == 1 because its HTTP fetch ends in a broadcast/scatter every rank must join. - Prep overlap (multi-rank, world_size > 1): main_loop fetches one batch ahead in lockstep (request_new_prompts max_pending_batches=2) and hands the freshly-fetched batch to _submit_prefetch_setup, so per-prompt _prepare_sample runs on the mixin's rank-local bg setup thread while the batch in front of it generates. No collective beyond the existing fetch, so it is safe under DP/TP/PP/CP parallelism. Previously prefetch was silently disabled for multi-rank workers. Shared helpers _bind_prefetch_context_once / _submit_prefetch_setup are used by both the rank-0 fetch loop and the multi-rank main_loop. The legacy enqueue_prefetch_payloads hook remains supported as a deprecation shim. Tests: tests/test_rollout_generation_mixin.py, tests/test_rollout_prefetch_loop_integration.py, and TestGymBackendStructured / TestGymBackendPrefetch in tests/test_gym_example.py.
317abc0 to
176151a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Branch:
feat/rollout-generation-mixin(currently atc9b2183)Base:
upstream/main(c35f517), but stacked on PR #689 and the UCXX PR (see below)Relationship:
feat/gym-example-trainer) — the gym example is the first in-tree adopter of the mixin. Two commits from feat: gym example trainer #689 (7706326,95b35db) appear in this branch'supstream/main..HEADuntil feat: gym example trainer #689 merges; they drop naturally on rebase.feat/ucxx-single-chunk-rotation) — the UCXX single-chunk commit (c9b2183here,acec19aon the standalone branch — same content, different SHA due to rebase) is included so this branch is self-sufficient for bench runs. Drops on rebase once the UCXX PR merges.96941ac) once both prerequisite PRs land.Suggested merge order: PR #689 → UCXX PR → this PR.
Summary
Adds
RolloutGenerationMixinforRolloutBasesubclasses that structuresrollout_generationvia five hooks and overlaps per-prompt preprocessing with in-flight generation whenconfig.rollout.prefetch_rolloutis set. Bundles in eight independent controller-rollout robustness fixes debugged alongside the mixin during the May-2026 bench cycle.The two concerns are combined because they share hunks in
cosmos_rl/rollout/worker/rollout_control.py(prefetch-loop instrumentation: Trace A/B/C, branch counters,_prompt_queuepeeks) and were bench-validated together. Splitting them on a hunk basis produced an awkward "subset" commit that wasn't a clean cherry-pick; combining them keeps the provenance honest.Background
Why a mixin
Today every
RolloutBasesubclass implements its ownrollout_generationloop end-to-end, duplicating the prompt-pull → preflight → generate → postprocess scaffolding. This makes it hard to add cross-cutting features (prefetch, instrumentation, retry policy) without N-way drift.RolloutGenerationMixinreifies the loop in five hooks (_preflight_check,_prepare_sample,_collate_batch,_generate,_postprocess) so subclasses contribute the workload-specific pieces and the mixin owns the orchestration.Why prefetch
On the disaggregated single-prompt path the rollout worker spends N ms per prompt on preprocessing (tokenisation, image decode, prompt rewrites) before reaching
generate(). Whengenerate()itself is also CPU-bound between kernel launches, the prep cost serialises the loop and inflates per-iteration latency. Prefetch overlaps the next prompt's preprocessing with the current generation, on the same worker thread, using a small bounded queue. Opt-in viaconfig.rollout.prefetch_rolloutso subclasses that don't want it pay no cost.Why eight robustness fixes
The May-2026 bench cycle surfaced eight independent controller-rollout correctness / stability issues, each reproducible and small enough to land alongside the mixin work:
samples_on_the_flyover-counting drove the controller's soft throttle to engage spuriously.prompt_consume_endbranch indefinitely.unregistercould block worker exit.Each is described in the commit body and is self-contained (could be reverted independently).
Changes
Commit
96941ac— RolloutGenerationMixin + prefetch + robustnessPrefetch mixin (
cosmos_rl/rollout/worker/rollout_generation_mixin.py+ plumbing):_preflight_check,_prepare_sample,_collate_batch,_generate,_postprocess.config.rollout.prefetch_rollout; reads the flag live (not snapshotted in__init__) so subclasses can flip it frompost_init_hook._prompt_queuefor overlap.Controller-rollout robustness fixes:
samples_on_the_flyaccounting — decrement onfilter_outdated_rolloutsdiscards; decrementtrain_ackby the actual per-step dispatch count viadispatched_rollouts_by_step(was sized bytrain_batch_per_replica * arrived_replicas, equal on normal steps but zero onis_fake_last_cmd); include in-flight samples inrecompute_total_steps.shutdown_signalonprompt_consume_end()(mirrors the async path; previously a crashed controller left workers hot-spinning the consume_end branch indefinitely); 50 ms backoff after weight-version rejection (was busy-spinning until the next P→R broadcast).unregister/post_heartbeat/ daemon joins with timeouts so a wedged controller can't prevent worker exit; skipBuildMeshCommandwhen this replica is draining (defence-in-depth for the controller-side filter below).RolloutStatusManager.unregister's BuildMesh recipient set to live survivors only (not status.ended) and require>= 2survivors before triggering rebuild.PR_SET_PDEATHSIGon the heartbeat subprocess so a parent SIGSEGV doesn't leave it orphaned and lying to the controller; opt-inCOSMOS_SHUTDOWN_ON_NO_POLICY_REPLICASto exit the controller when all policy replicas are dead (off by default to preserve dynamic-replica deployments).Commit
c9b2183— UCXX single-chunk + client rotationSame content as
feat/ucxx-single-chunk-rotation'sacec19a; see that PR. Included here so this branch is self-sufficient for bench runs. Drops on rebase once that PR merges.Validation
Local CPU test suite (driver:
pytest -q --tb=shorton branch tipc9b2183)tests/test_rollout_generation_mixin.py(16 tests incl.test_prefetch_throughput_synthetic,test_subclass_preflight_check_propagates)tests/test_rollout_prefetch_loop_integration.pytests/test_put_rollouts.pytests/test_gym_example.py::TestGymBackendStructured+::TestGymBackendPrefetchtests/test_ucxx_*(UCXX commit)Bench validation
rl-gym companion branch
feat/rollout-generation-mixincarries theSimpleRolloutWorker/ModularRolloutWorkeradopters plus the bench matrix. May-2026 cycle's headline numbers are paste-ready inrl-gym/docs/BENCHMARKS.md. Five disaggregated reruns confirmed each robustness fix eliminates its corresponding wedge mode.Caveats
test_put_rollouts.pyis the only CPU-runnable test that imports any of the eight changed controller/dispatcher modules. The dispatch-accounting / soft-throttle /RolloutStatusManager.unregisterfilter changes are validated by the May-2026 bench cycle (see commit body). Multi-rank integration tests (test_policy_overfit,test_policy_to_rollout) gate this on the CI side.Failed to import VLLM Rolloutlog line — benign, pre-existing, caused by the local venv not havingvllminstalled. Not introduced by this PR.Known follow-up (not blocking)
Bench data revealed a separate NCCL P2R recv backlog stalls the inference stream on the rollout side when the trainer is fast. This is a downstream effect that limits prefetch's measured upside in the May-2026 matrix to 5–7% (vs. the historical 17%), not a correctness issue. Investigation is complete; fix lives outside this stack and will ship as a separate PR. Cross-references:
rl-gym/docs/PREFETCH.mdcarries a note that prefetch's observed tail-latency cost is downstream of this bug, not intrinsic to prefetch.generate_ms=50026.17on job28243035(vs. typical ~2,500 ms).cosmos_rl_in_flight.md.Test plan
test_policy_overfit,test_policy_to_rollout,test_context_parallel(multi-GPU NCCL paths that exercise the controller-rollout protocol).prefetch_rolloutflag read (deliberately not snapshotted in__init__so subclasses can flip frompost_init_hook— verify this is the contract you want).c9b2183).feat/rollout-generation-mixincompanion branch.