diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f268e499584..272385b59f5 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -102,6 +102,39 @@ def get_sliding_windows() -> int: def init_cpu_threads_env(rank_id: int, world_size: int): + import psutil + allowed_cpus = psutil.Process().cpu_affinity() + if len(allowed_cpus) < psutil.cpu_count(logical=True): + _init_cpu_threads_env_use_allowed(rank_id, world_size, allowed_cpus) + else: + _init_cpu_threads_env_use_all(rank_id, world_size) + +def _init_cpu_threads_env_use_allowed(rank_id: int, world_size: int, allowed_cpus: list): + import importlib.util + + if os.getenv("OMP_NUM_THREADS") is None: + num_cpus_per_rank = max(int(len(allowed_cpus) / world_size), 1) + else: + num_cpus_per_rank = min(int(os.getenv("OMP_NUM_THREADS")), len(allowed_cpus)) + + if importlib.util.find_spec("numa") is not None: + import numa + + slice_info = f"slice {rank_id+1}/{world_size} of externally allowed {len(allowed_cpus)} CPUs" + allowed_mems = numa.memory.get_membind_nodes() + cpu_start = num_cpus_per_rank * rank_id + allowed_cpus_for_rank = allowed_cpus[cpu_start : cpu_start + num_cpus_per_rank] + numa.schedule.run_on_cpus(0, *allowed_cpus_for_rank) + effective_allowed_cpus = numa.schedule.get_affinitive_cpus(0) + else: + slice_info = "externally allowed, cannot import numa for slicing" + allowed_mems = "n/a" + effective_allowed_cpus = allowed_cpus + num_threads = num_cpus_per_rank + torch.set_num_threads(num_threads) + logger.info(f"affinity={effective_allowed_cpus} ({slice_info}), membind={allowed_mems}, threads={num_threads}") + +def _init_cpu_threads_env_use_all(rank_id: int, world_size: int): import importlib.util if importlib.util.find_spec("numa") is not None: