Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4c60148
fixed inference skipping
guipenedo Nov 7, 2025
97bd686
nit
guipenedo Nov 7, 2025
00045c6
nit
guipenedo Nov 7, 2025
16d20e0
refactored inf runner
guipenedo Nov 11, 2025
8a2f019
style
guipenedo Nov 11, 2025
9819f29
drop documents with 0 successful rollouts
guipenedo Nov 11, 2025
f798ddd
add requests cache
guipenedo Nov 12, 2025
50f3157
nit
guipenedo Nov 12, 2025
c1e0d60
perf improvements (less aggressive fs hits)
guipenedo Nov 12, 2025
9358be6
improved writes with queue
guipenedo Nov 12, 2025
fb01939
aiosqlite
guipenedo Nov 12, 2025
fb8413e
tmp sync on cluster
hynky1999 Nov 14, 2025
f5d97af
working version for slurm
Nov 17, 2025
023c538
fix master node import
Nov 17, 2025
1876c8b
capture output of ray stop
Nov 17, 2025
4ea13d8
sync locally
Nov 19, 2025
5e81239
final polishes
hynky1999 Nov 19, 2025
f92057b
nit condition during distributed check
Nov 19, 2025
a7752bd
Merge with main
hynky1999 Nov 19, 2025
1914b1d
push ray
hynky1999 Nov 20, 2025
b20218c
fix issues with vllm and sglang on slurm
Nov 20, 2025
51c61e7
Merge branch 'multi-node-inference' of github.com:huggingface/datatro…
Nov 20, 2025
0282f60
get ray + sglang working
Nov 20, 2025
46e0d50
logging node in multinode, fixes from debugging + prettier
hynky1999 Nov 20, 2025
261f680
prettier
hynky1999 Nov 20, 2025
0d654af
removed auto restart and distributed coordinator + small nits
guipenedo Nov 21, 2025
5278c66
remove wal
guipenedo Nov 21, 2025
e212a7f
envs vars consistency + vllm master node tracking or nodes
Nov 24, 2025
45b5eaa
fmt
hynky1999 Nov 24, 2025
ba8ad67
add example
hynky1999 Nov 24, 2025
0715bff
make ray checks async and with longer timeout
Nov 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/inference_example_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from datatrove.data import Document
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.inference.run_inference import InferenceConfig, InferenceResult, InferenceRunner
from datatrove.pipeline.writers import JsonlWriter

Expand Down Expand Up @@ -189,6 +190,27 @@ async def process_page(page: int) -> InferenceResult:
tasks=1,
)

# Example 3: Distributed inference
# Same chunked rollout, but launched on Slurm with multi-node model serving:
# each task gets 2 nodes x 8 GPUs (nodes_per_task=2, gpus_per_task=8).
pipeline_executor_distributed = SlurmPipelineExecutor(
    tasks=100,
    time="10:00:00",  # Slurm walltime per task
    partition="hopper-prod",
    gpus_per_task=8,  # GPUs requested on each node
    nodes_per_task=2,  # multi-node serving: two nodes back one model replica
    logging_dir=LOGS_PATH,
    pipeline=[
        documents,
        InferenceRunner(
            rollout_fn=chunked_rollout,
            config=InferenceConfig(
                server_type="vllm",
                model_name_or_path="deepseek-ai/DeepSeek-R1",
                # tp=16 matches the total GPU count (2 nodes x 8 GPUs) --
                # presumably tensor parallelism spanning both nodes; confirm
                # against InferenceConfig's definition of `tp`.
                tp=16,
            ),
            output_writer=JsonlWriter(OUTPUT_PATH),
        ),
    ],
)

# NOTE(review): only `pipeline_executor` (the earlier example) is launched
# below; `pipeline_executor_distributed` is constructed but never run here.
if __name__ == "__main__":
    # Run the pipeline
    pipeline_executor.run()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ inference = [
"aiosqlite",
]
ray = [
"ray"
"ray[default]"
]
quality = [
"ruff>=0.1.5"
Expand Down
64 changes: 53 additions & 11 deletions src/datatrove/executor/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
import json
import os
import random
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Sequence
from typing import Callable
from typing import Callable, TypedDict

from datatrove.io import DataFolderLike, get_datafolder
from datatrove.pipeline.base import PipelineStep
Expand All @@ -20,6 +21,19 @@
from datatrove.utils.stats import PipelineStats


class DistributedEnvVars(TypedDict):
    """Environment variables that every `get_distributed_env` implementation must provide.

    All values must be strings: `_set_distributed_environment` writes them
    verbatim into `os.environ` under the corresponding `DATATROVE_*` names.
    """

    datatrove_node_ips: str  # comma-separated list of node IPs/hostnames ("localhost" for the LOCAL executor)
    datatrove_cpus_per_task: str  # number of CPUs per task (LOCAL executor uses "-1" as an "unknown" sentinel)
    datatrove_mem_per_cpu: str  # memory per CPU in GB (LOCAL executor uses "-1" as an "unknown" sentinel)
    datatrove_gpus_on_node: str  # number of GPUs on the node (LOCAL executor uses "-1" as an "unknown" sentinel)
    datatrove_executor: str  # executor type identifier, e.g. "LOCAL"


class PipelineExecutor(ABC):
"""Base class for pipeline executors (local, slurm, etc.)

Expand Down Expand Up @@ -62,22 +76,50 @@ def world_size(self) -> int:
"""
return 0

def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:
    @abstractmethod
    def get_distributed_env(self, node_rank: int = -1) -> DistributedEnvVars:
        """
        Returns a dictionary of environment variables to set for distributed execution.
        Called by `_run_for_rank` (through `_set_distributed_environment`), which exports
        the returned values into `os.environ` as `DATATROVE_*` variables.

        Args:
            node_rank: node rank/ID. -1 means single node mode (default).

        Returns: DistributedEnvVars dictionary with all required environment variables.
            All values must be strings, since they are written directly to `os.environ`.
        """
        pass

def _set_distributed_environment(self, node_rank: int):
env_vars = self.get_distributed_env(node_rank)
os.environ["DATATROVE_NODE_RANK"] = str(node_rank)
os.environ["DATATROVE_EXECUTOR"] = env_vars["datatrove_executor"]
os.environ["DATATROVE_NODE_IPS"] = env_vars["datatrove_node_ips"]
os.environ["DATATROVE_CPUS_PER_TASK"] = env_vars["datatrove_cpus_per_task"]
os.environ["DATATROVE_MEM_PER_CPU"] = env_vars["datatrove_mem_per_cpu"]
os.environ["DATATROVE_GPUS_ON_NODE"] = env_vars["datatrove_gpus_on_node"]

def _run_for_rank(self, rank: int, local_rank: int = 0, node_rank: int = -1) -> PipelineStats:
"""
Main executor's method. Sets up logging, pipes data from each pipeline step to the next, saves statistics
and marks tasks as completed.
and marks tasks as completed. We assume node_rank == 0 is the master node. node_rank == -1 means single node mode.
Completion is only marked on the master node, all other nodes are ignored in terms of job completion as we use 1-master, many-workers mode.
In this case it's master responsibility to check for workers completion and mark the job as complete.
Args:
rank: the rank that we want to run the pipeline for
local_rank: at the moment this is only used for logging.
Any task with local_rank != 0 will not print logs to console.

node_rank: node rank/ID for logging prefix. Logs will be prefixed with [NODE X] if node_rank != -1. We assume node_rank == 0 is the master node. -1 means single node mode (default).
Returns: the stats for this task

"""
if self.is_rank_completed(rank):
logger.info(f"Skipping {rank=} as it has already been completed.")
return PipelineStats()
logfile = add_task_logger(self.logging_dir, rank, local_rank)

self._set_distributed_environment(node_rank)

logfile = add_task_logger(self.logging_dir, rank, local_rank, node_rank=node_rank)
log_pipeline(self.pipeline)

if self.randomize_start_duration > 0:
Expand All @@ -97,13 +139,13 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats:

logger.success(f"Processing done for {rank=}")

# stats
# stats - only save on master node in distributed setting (or when node_rank <= 0 for single node)
stats = PipelineStats(self.pipeline)
with self.logging_dir.open(f"stats/{rank:05d}.json", "w") as f:
stats.save_to_disk(f)
logger.info(stats.get_repr(f"Task {rank}"))
# completed
self.mark_rank_as_completed(rank)
if node_rank <= 0:
with self.logging_dir.open(f"stats/{rank:05d}.json", "w") as f:
stats.save_to_disk(f)
logger.info(stats.get_repr(f"Task {rank}"))
self.mark_rank_as_completed(rank)
except Exception as e:
logger.exception(e)
raise e
Expand Down
15 changes: 14 additions & 1 deletion src/datatrove/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import multiprocess

from datatrove.executor.base import PipelineExecutor
from datatrove.executor.base import DistributedEnvVars, PipelineExecutor
from datatrove.io import DataFolderLike
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.logging import logger
Expand Down Expand Up @@ -150,6 +150,19 @@ def run(self):
logger.success(stats.get_repr(f"All {self.local_tasks} tasks"))
return stats

def get_distributed_env(self, node_rank: int = -1) -> DistributedEnvVars:
"""Get distributed environment variables for LOCAL executor."""
# Default values for local execution - these can be overridden if needed
# For now, we'll use reasonable defaults

return DistributedEnvVars(
datatrove_node_ips="localhost",
datatrove_cpus_per_task="-1",
datatrove_mem_per_cpu="-1",
datatrove_gpus_on_node="-1",
datatrove_executor="LOCAL",
)

@property
def world_size(self) -> int:
"""
Expand Down
Loading