Skip to content

feat: RolloutGenerationMixin + per-prompt prefetch#694

Open
shuangwu wants to merge 5 commits into
nvidia-cosmos:mainfrom
shuangwu:feat/rollout-generation-mixin
Open

feat: RolloutGenerationMixin + per-prompt prefetch#694
shuangwu wants to merge 5 commits into
nvidia-cosmos:mainfrom
shuangwu:feat/rollout-generation-mixin

Conversation

@shuangwu

Copy link
Copy Markdown
Collaborator

Branch: feat/rollout-generation-mixin (currently at c9b2183)
Base: upstream/main (c35f517), but stacked on PR #689 and the UCXX PR (see below)
Relationship:

  • Stacked on PR feat: gym example trainer #689 (feat/gym-example-trainer) — the gym example is the first in-tree adopter of the mixin. Two commits from feat: gym example trainer #689 (7706326, 95b35db) appear in this branch's upstream/main..HEAD until feat: gym example trainer #689 merges; they drop naturally on rebase.
  • Stacked on the UCXX PR (feat/ucxx-single-chunk-rotation) — the UCXX single-chunk commit (c9b2183 here, acec19a on the standalone branch — same content, different SHA due to rebase) is included so this branch is self-sufficient for bench runs. Drops on rebase once the UCXX PR merges.
  • Final standalone contribution: one commit (96941ac) once both prerequisite PRs land.

Suggested merge order: PR #689 → UCXX PR → this PR.


Summary

Adds RolloutGenerationMixin for RolloutBase subclasses that structures rollout_generation via five hooks and overlaps per-prompt preprocessing with in-flight generation when config.rollout.prefetch_rollout is set. Bundles in eight independent controller-rollout robustness fixes debugged alongside the mixin during the May-2026 bench cycle.

The two concerns are combined because they share hunks in cosmos_rl/rollout/worker/rollout_control.py (prefetch-loop instrumentation: Trace A/B/C, branch counters, _prompt_queue peeks) and were bench-validated together. Splitting them on a hunk basis produced an awkward "subset" commit that wasn't a clean cherry-pick; combining them keeps the provenance honest.

Background

Why a mixin

Today every RolloutBase subclass implements its own rollout_generation loop end-to-end, duplicating the prompt-pull → preflight → generate → postprocess scaffolding. This makes it hard to add cross-cutting features (prefetch, instrumentation, retry policy) without N-way drift. RolloutGenerationMixin reifies the loop in five hooks (_preflight_check, _prepare_sample, _collate_batch, _generate, _postprocess) so subclasses contribute the workload-specific pieces and the mixin owns the orchestration.

Why prefetch

On the disaggregated single-prompt path the rollout worker spends N ms per prompt on preprocessing (tokenisation, image decode, prompt rewrites) before reaching generate(). When generate() itself is also CPU-bound between kernel launches, the prep cost serialises the loop and inflates per-iteration latency. Prefetch overlaps the next prompt's preprocessing with the current generation, on the same worker thread, using a small bounded queue. Opt-in via config.rollout.prefetch_rollout so subclasses that don't want it pay no cost.

Why eight robustness fixes

The May-2026 bench cycle surfaced eight independent controller-rollout correctness / stability issues, each reproducible and small enough to land alongside the mixin work:

  • samples_on_the_fly over-counting drove the controller's soft throttle to engage spuriously.
  • Wedged controllers left workers hot-spinning the prompt_consume_end branch indefinitely.
  • A wedged controller's unregister could block worker exit.
  • Heartbeat subprocesses outlived their parent on SIGSEGV and lied to the controller.
  • The end-of-run mesh-rebuild could try to include dead replicas in the new mesh.
  • Etc.

Each is described in the commit body and is self-contained (could be reverted independently).

Changes

Commit 96941ac — RolloutGenerationMixin + prefetch + robustness

Prefetch mixin (cosmos_rl/rollout/worker/rollout_generation_mixin.py + plumbing):

  • Five hooks: _preflight_check, _prepare_sample, _collate_batch, _generate, _postprocess.
  • Opt-in via config.rollout.prefetch_rollout; reads the flag live (not snapshotted in __init__) so subclasses can flip it from post_init_hook.
  • Single-producer-mode property + bounded _prompt_queue for overlap.
  • Trace A (per-iteration branch counters), Trace B (queue depth), Trace C (per-generation entry/exit) at DEBUG.

Controller-rollout robustness fixes:

  • Controller samples_on_the_fly accounting — decrement on filter_outdated_rollouts discards; decrement train_ack by the actual per-step dispatch count via dispatched_rollouts_by_step (was sized by train_batch_per_replica * arrived_replicas, equal on normal steps but zero on is_fake_last_cmd); include in-flight samples in recompute_total_steps.
  • Rollout main_loop — set shutdown_signal on prompt_consume_end() (mirrors the async path; previously a crashed controller left workers hot-spinning the consume_end branch indefinitely); 50 ms backoff after weight-version rejection (was busy-spinning until the next P→R broadcast).
  • Rollout teardown — bound unregister / post_heartbeat / daemon joins with timeouts so a wedged controller can't prevent worker exit; skip BuildMeshCommand when this replica is draining (defence-in-depth for the controller-side filter below).
  • Controller end-of-run mesh management — filter RolloutStatusManager.unregister's BuildMesh recipient set to live survivors only (not status.ended) and require >= 2 survivors before triggering rebuild.
  • Process lifecyclePR_SET_PDEATHSIG on the heartbeat subprocess so a parent SIGSEGV doesn't leave it orphaned and lying to the controller; opt-in COSMOS_SHUTDOWN_ON_NO_POLICY_REPLICAS to exit the controller when all policy replicas are dead (off by default to preserve dynamic-replica deployments).
  • Log hygiene — soft-throttle engagement / heartbeat / release events + paired Trace D dispatch-state snapshot for the silent throttle path; Trace E for the controller-side filter decisions.

Commit c9b2183 — UCXX single-chunk + client rotation

Same content as feat/ucxx-single-chunk-rotation's acec19a; see that PR. Included here so this branch is self-sufficient for bench runs. Drops on rebase once that PR merges.

Validation

Local CPU test suite (driver: pytest -q --tb=short on branch tip c9b2183)

File Result
tests/test_rollout_generation_mixin.py (16 tests incl. test_prefetch_throughput_synthetic, test_subclass_preflight_check_propagates) passed
tests/test_rollout_prefetch_loop_integration.py passed
tests/test_put_rollouts.py passed
tests/test_gym_example.py::TestGymBackendStructured + ::TestGymBackendPrefetch passed
tests/test_ucxx_* (UCXX commit) 55 passed / 1 skipped
Stack total 171 passed / 1 skipped / 180 subtests

Bench validation

rl-gym companion branch feat/rollout-generation-mixin carries the SimpleRolloutWorker / ModularRolloutWorker adopters plus the bench matrix. May-2026 cycle's headline numbers are paste-ready in rl-gym/docs/BENCHMARKS.md. Five disaggregated reruns confirmed each robustness fix eliminates its corresponding wedge mode.

Caveats

  • The robustness portion has thin direct unit-test coveragetest_put_rollouts.py is the only CPU-runnable test that imports any of the eight changed controller/dispatcher modules. The dispatch-accounting / soft-throttle / RolloutStatusManager.unregister filter changes are validated by the May-2026 bench cycle (see commit body). Multi-rank integration tests (test_policy_overfit, test_policy_to_rollout) gate this on the CI side.
  • Local-only Failed to import VLLM Rollout log line — benign, pre-existing, caused by the local venv not having vllm installed. Not introduced by this PR.

Known follow-up (not blocking)

Bench data revealed a separate NCCL P2R recv backlog stalls the inference stream on the rollout side when the trainer is fast. This is a downstream effect that limits prefetch's measured upside in the May-2026 matrix to 5–7% (vs. the historical 17%), not a correctness issue. Investigation is complete; fix lives outside this stack and will ship as a separate PR. Cross-references:

  • rl-gym/docs/PREFETCH.md carries a note that prefetch's observed tail-latency cost is downstream of this bug, not intrinsic to prefetch.
  • Smoking-gun bench artefact: rollout_0 generate_ms=50026.17 on job 28243035 (vs. typical ~2,500 ms).
  • Three candidate fixes documented in cosmos_rl_in_flight.md.

Test plan

  • full CI matrix green — especially test_policy_overfit, test_policy_to_rollout, test_context_parallel (multi-GPU NCCL paths that exercise the controller-rollout protocol).
  • reviewer sanity-check on the eight robustness fixes — each is self-contained and could be reverted independently if needed.
  • reviewer sanity-check on the live prefetch_rollout flag read (deliberately not snapshotted in __init__ so subclasses can flip from post_init_hook — verify this is the contract you want).
  • rebase once PR feat: gym example trainer #689 lands (drops the two gym-example commits).
  • rebase once the UCXX PR lands (drops c9b2183).
  • confirm bench numbers reproduce on a clean cluster via rl-gym feat/rollout-generation-mixin companion branch.

@shuangwu shuangwu changed the title eat: RolloutGenerationMixin + per-prompt prefetch + controller-rollout robustness Feat: RolloutGenerationMixin + per-prompt prefetch + controller-rollout robustness May 27, 2026
@shuangwu shuangwu changed the title Feat: RolloutGenerationMixin + per-prompt prefetch + controller-rollout robustness feat: RolloutGenerationMixin + per-prompt prefetch + controller-rollout robustness May 27, 2026
@shuangwu shuangwu force-pushed the feat/rollout-generation-mixin branch 2 times, most recently from 4995882 to 48fb71f Compare May 27, 2026 15:26
@shuangwu

Copy link
Copy Markdown
Collaborator Author

The job did not fail on a Python exception in the test code; it timed out after 2 hours

@shuangwu shuangwu force-pushed the feat/rollout-generation-mixin branch 3 times, most recently from d7f5487 to cac10e6 Compare June 2, 2026 04:23
@shuangwu shuangwu changed the title feat: RolloutGenerationMixin + per-prompt prefetch + controller-rollout robustness feat: RolloutGenerationMixin + per-prompt prefetch Jun 2, 2026
@shuangwu shuangwu force-pushed the feat/rollout-generation-mixin branch 2 times, most recently from 20fa017 to 317abc0 Compare June 2, 2026 21:19
shuangwu added 5 commits June 5, 2026 14:16
Add Slurm CI tooling and harden replica teardown so test_process_flow no
longer hangs for hours: fail-fast Redis stream polls during shutdown,
controller finalize on heartbeat-death reap (with a startup guard so SFT
controllers do not self-destruct before workers register), bounded HTTP
timeouts, teardown progress logging, and a bounded test-side wait with
faulthandler diagnostics.
Reworks the UCXX payload transport for correctness and
simplicity, and removes a deprecated trainer-side mixin.

**Wire protocol (single chunk per slot).**  One connection, one
``send([slot])``, one ``recv(status)``, one ``recv(payload)``.
The handler that performs the slot's ``READY -> READING``
transition is the unique owner of its
``READING -> {FREE | READY}`` finalisation for that read
attempt.  No refcounting, no shared cross-handler state, no
multi-chunk fan-out.

**Two structural invariants** make orphan handlers harmless:

1. *Defensive slot-state transitions.*
   ``SharedRingBuffer.mark_consumed`` and ``release_reading``
   no-op (with a warning) on unexpected slot states instead of
   clobbering them.  A late finalise from an orphan handler --
   e.g. one still mid-``send`` after the client rotated to a
   different port -- cannot silently undo a writer's recycle.

2. *Single owner per read attempt.*  Each handler calls exactly
   one of ``mark_consumed`` (success) or ``release_reading``
   (failure) in its ``finally`` block.

**Three-layer client retry stack.**

* ``UCXXClient.read``: per-call port rotation with on-timeout
  fallback, plus a per-port skip-list so ports that have just
  produced ``UCXXConnectionResetError`` /
  ``UCXXMessageTruncated`` / timeout are quarantined for
  ``_PORT_QUARANTINE_SEC`` and skipped on subsequent calls
  while at least one healthy port remains.
* ``UCXXDataPackerMixin._read_one``: per-slot fresh-call retry,
  with ``StaleSlotError`` flagged non-retryable so a recycled
  slot is not re-fetched in the same batch.
* ``UCXXDataPackerMixin._ucxx_dp_fetch_all``: multi-round batch
  retry (``_MAX_FETCH_ROUNDS``).

**Audit-pass fixes.**

* Stop server-thread busy-spin: shutdown watchdog now sleeps
  for 50 ms instead of yielding via ``asyncio.sleep(0)``.
* ``SharedRingBuffer.get_ready_count`` actually counts
  ``READY`` slots instead of returning the saturating
  ``entry_count`` header.
* ``UCXXClient`` preemptively closes pooled endpoints older
  than ``_POOL_ENDPOINT_MAX_AGE_S`` on checkout, replacing
  them with fresh connections so a client never tries to
  reuse an endpoint the server has already idle-evicted.
* Demote the per-request ``[UCXXBuffer] req slot=...`` log
  from INFO to DEBUG.

**Cleanup.**  Removes the deprecated ``UCXXTrainerMixin``
(superseded by ``UCXXDataPackerMixin``) and the
``ucxx_n_chunks`` config knob.  ``read_raw`` still tolerates
re-entry from ``READING`` -- not for chunking, but to let a
fresh handler serve a slot when the original is still
mid-``send``; the defensive guards above make the duplicate
finalise harmless.

**Docs.**  ``ucxx/__init__.py`` and ``rl-gym/docs/UCXX.md``
rewritten with the slot-lifecycle diagram, the three-layer
retry table, port skip-list semantics, and the
preemptive-eviction self-healing window.

Tests: 55 unit tests in ``tests/test_ucxx_transport.py`` and
``tests/test_ucxx_data_packer_mixin.py`` covering the
defensive primitive guards, ``get_ready_count`` correctness,
stale-pooled-endpoint replacement, and the three-layer retry
behaviour.
Separate training completion from shutdown so rollouts can drain cleanly.
Defer TrainingComplete until policy replicas finish in-flight work, bounded-drain
P2R/teardown streams, and kill full subprocess trees in GPU integration tests
so controller/torchrun zombies do not poison later CI cases.
…klog/staleness logging

Defer TrainingComplete in finish_draining when buffered rollouts can still
fill remaining training steps, fixing custom_rollout computed_cnt CI failure.
…prep overlap)

Opt-in RolloutGenerationMixin for RolloutBase subclasses that structures
rollout_generation via five hooks (_preflight_check, _prepare_sample,
_collate_batch, _generate, _postprocess) and overlaps per-prompt
preprocessing with in-flight generation when config.rollout.prefetch_rollout
is set.  The gym example is the first in-tree adopter.  _single_producer_mode
reads prefetch_rollout live (not snapshotted) so subclasses can flip it from
post_init_hook.

Two overlap paths:
- Fetch overlap (single-process, world_size == 1): the dedicated rank-0
  _prefetch_loop hides the controller get_next_prompt round-trip and stages
  prep.  Requires world_size == 1 because its HTTP fetch ends in a
  broadcast/scatter every rank must join.
- Prep overlap (multi-rank, world_size > 1): main_loop fetches one batch
  ahead in lockstep (request_new_prompts max_pending_batches=2) and hands the
  freshly-fetched batch to _submit_prefetch_setup, so per-prompt
  _prepare_sample runs on the mixin's rank-local bg setup thread while the
  batch in front of it generates.  No collective beyond the existing fetch,
  so it is safe under DP/TP/PP/CP parallelism.  Previously prefetch was
  silently disabled for multi-rank workers.

Shared helpers _bind_prefetch_context_once / _submit_prefetch_setup are used
by both the rank-0 fetch loop and the multi-rank main_loop.  The legacy
enqueue_prefetch_payloads hook remains supported as a deprecation shim.

Tests: tests/test_rollout_generation_mixin.py,
tests/test_rollout_prefetch_loop_integration.py, and
TestGymBackendStructured / TestGymBackendPrefetch in tests/test_gym_example.py.
@shuangwu shuangwu force-pushed the feat/rollout-generation-mixin branch from 317abc0 to 176151a Compare June 10, 2026 04:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant