39 changes: 39 additions & 0 deletions vllm_ascend/worker/common.py
@@ -0,0 +1,39 @@
import torch
from enum import Enum

class FaultToleranceLevel(Enum):
"""
Fault tolerance level
level 0: disable fault tolerance
level 1: enable base fault tolerance for L1->L2 Network Error
"""
OFF = 0
BASIC = 1

class FaultStatus(Enum):
"""
Fault status which fault_tolerance put into fault_queue
"""
ACTIVE = torch.tensor([0])
FORCE_STOP = torch.tensor([1])
NETWORK_ERR = torch.tensor([2])

class FaultCommand:
"""
Fault command which rank 0 broadcast in fault_aware
"""
INIT_CMD = torch.tensor([0])
SILENCE_CMD = torch.tensor([1])
STOP_DEVICE_CMD = torch.tensor([2])

class RecoveryStatus:
"""
Recovery status
"""
SUCCESS = torch.tensor([0])
FAILED = torch.tensor([1])

class FaultAction:
RAISE_EXCEPTION = torch.tensor([0])
RETURN = torch.tensor([1])
RECOMPUTE = torch.tensor([2])
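
A minimal sketch of how these constants are meant to be consumed (inferred from the usage in fault_aware.py below; the fault_queue wiring here is illustrative): FaultStatus members wrap torch.Tensor values, so comparisons go through torch.equal rather than ==, which would return a tensor instead of a bool.

# Illustrative only: a worker-side component reports a fault by enqueueing a
# FaultStatus member; the consumer reads it back and compares tensor values.
import queue

import torch

from vllm_ascend.worker.common import FaultCommand, FaultStatus

fault_queue: queue.Queue = queue.Queue()

# A component that detects a communication failure enqueues the status.
fault_queue.put(FaultStatus.NETWORK_ERR)

# The consumer drains the queue and inspects the wrapped tensor value.
status = fault_queue.get_nowait()
assert torch.equal(status.value, FaultStatus.NETWORK_ERR.value)

# FaultCommand attributes are bare tensors (not Enum members), so they are
# compared directly.
cmd = FaultCommand.STOP_DEVICE_CMD
assert torch.equal(cmd, FaultCommand.STOP_DEVICE_CMD)
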
170 changes: 170 additions & 0 deletions vllm_ascend/worker/fault_aware.py
@@ -0,0 +1,170 @@
import time
import threading
import torch
import queue
import torch_npu

from datetime import timedelta
from vllm.logger import logger
from vllm_ascend.worker.common import FaultStatus,FaultCommand

class FaultAware:
_fault_aware_group = None

def __init__(self,rank:int,world_size:int,fault_queue:queue.Queue,interval_s=1,
aware_event:threading.Event=None):
self.rank = rank
self.world_size = world_size
self.npu_id = torch.npu.current_device()
self.fault_queue = fault_queue
self.interval_s = interval_s

Check failure on line 21 in vllm_ascend/worker/fault_aware.py (GitHub Actions / lint / pre-commit): Incompatible default for argument "aware_event" (default has type "None", argument has type "Event") [assignment]
self._fault_aware_thread = None
self.aware_event = aware_event

def init_fault_aware_group(self):
"""
Initialize the Torch process group for fault aware.
Rank 0 is the coordinator rank,
the other ranks are the normal rank,which is used for sending status to rank 0.

Rank 0 will collect the status from all the other ranks and broadcast stop_device
command to all the other ranks through `_fault_aware_group`
"""
if not torch.distributed.is_initialized():
raise RuntimeError("Default torch process group must be initialized")

if not torch.distributed.is_gloo_available():
raise RuntimeError("Gloo backend must be available")

logger.info(
f"init fault aware process group: "
f"rank={self.rank},world_size={self.world_size},backend=gloo"
)
try:
FaultAware._fault_aware_group = torch.distributed.new_group(
ranks=None,
timeout=timedelta(minutes=5),
backend="gloo"
)
logger.info(f"Rank {self.rank} successfully initialized fault aware process group")
except Exception as e:
logger.error(f"Rank {self.rank} failed to initialize fault aware group:{e}")
raise e

def start(self):
"""Start the fault aware"""
if self._fault_aware_thread is not None and self._fault_aware_thread.is_alive():
logger.warning("Fault aware thread is already running")
return
self.init_fault_aware_group()
logger.info(f"Rank {self.rank} starting fault aware thread")
try:
self._fault_aware_thread = threading.Thread(
target=self._handler_loop,
name=f"FaultAware-Rank{self.rank}",
daemon=True,
)
self._fault_aware_thread.start()
logger.info(f"Rank {self.rank} successfully started fault aware thread")
except Exception as e:

Check failure on line 70 in vllm_ascend/worker/fault_aware.py (GitHub Actions / lint / pre-commit): Incompatible types in assignment (expression has type "Thread", variable has type "None") [assignment]
logger.error(f"Rank {self.rank} failed to start fault aware thread:{e}")
raise e

def _handler_loop(self):
current_status = FaultStatus.ACTIVE.value

Check failure on line 75 in vllm_ascend/worker/fault_aware.py (GitHub Actions / lint / pre-commit): "None" has no attribute "start" [attr-defined]
status_list = (
[torch.zeros([1],dtype=torch.int64) for _ in range(self.world_size)]
if self.rank == 0
else None
)
while True:
try:
current_status = self._update_status_from_queue(current_status)
self._gather_statuses(current_status,status_list)
fault_cmd = self._determine_fault_command(status_list)
self.broadcast_command(fault_cmd)
current_status = self._execute_command(fault_cmd, current_status)
except Exception as e:
logger.error(f"Exception in fault aware handler:{e}")
if not threading.main_thread().is_alive():
break
raise e
Comment on lines +88 to +92
Contributor

critical

The _handler_loop runs in a background thread to monitor for faults. If a distributed operation like gather or broadcast fails (e.g., due to a worker failure), it will raise a torch.distributed.DistError. The current implementation catches this exception, logs it, and then re-raises it, which terminates the fault-aware thread. This defeats the purpose of the fault tolerance mechanism, as it will stop monitoring for faults after the first one. The thread should handle distributed errors gracefully and continue its monitoring loop.

I suggest catching torch.distributed.DistError specifically and continuing the loop after a short delay, while letting other unexpected exceptions terminate the thread.

Suggested change
-            except Exception as e:
-                logger.error(f"Exception in fault aware handler:{e}")
-                if not threading.main_thread().is_alive():
-                    break
-                raise e
+            except torch.distributed.DistError as e:
+                logger.warning(f"Fault aware handler caught a distributed error, will retry: {e}")
+                time.sleep(self.interval_s)
+            except Exception as e:
+                logger.error(f"Exception in fault aware handler:{e}")
+                if not threading.main_thread().is_alive():
+                    break
+                raise e

logger.info(f"Fault aware handler exiting")

def _update_status_from_queue(self,current_status):
try:
msg = self.fault_queue.get_nowait()
if msg:
logger.info(f"Received new status: {msg.name},updating status")
current_status = msg.value
except queue.Empty:
if not threading.main_thread().is_alive():
raise RuntimeError("Main thread is not alive")
except Exception as e:
logger.error(f"Error reading from fault queue:{e}")
raise e

return current_status

def _gather_statuses(self,current_status,status_list):
""" Gather statuses from all ranks to rank 0"""
try:
torch.distributed.gather(
tensor=current_status,
gather_list=status_list,
dst=0,
group=FaultAware._fault_aware_group,
)
except Exception as e:
logger.error(f"Rank {self.rank} failed to gather status:{e}")
raise e

def _determine_fault_command(self,status_list):
"""Determine the command to run"""
fault_cmd = FaultCommand.INIT_CMD
if self.rank == 0:
if all(torch.equal(t, FaultStatus.ACTIVE.value) for t in status_list):
fault_cmd = FaultCommand.SILENCE_CMD
else:
fault_cmd = FaultCommand.STOP_DEVICE_CMD
return fault_cmd

def broadcast_command(self,fault_cmd):
""" BroadCast the fault command to all ranks"""
try:
torch.distributed.broadcast(
tensor=fault_cmd,
src=0,
group=FaultAware._fault_aware_group,
)
except Exception as e:
logger.error(f"Rank {self.rank} failed to broadcast command:{e}")
raise e

def _execute_command(self,fault_cmd,current_status):
""" Execute the fault command"""
if torch.equal(fault_cmd,FaultCommand.SILENCE_CMD):
time.sleep(self.interval_s)
elif torch.equal(fault_cmd,FaultCommand.STOP_DEVICE_CMD):
logger.info(f"Error detected in cluster,executing stop_device on NPU {self.npu_id}")
self._stop_device()
current_status = FaultStatus.ACTIVE.value
else:
logger.error(f"Unknown fault command received:{fault_cmd}")

return current_status

def _stop_device(self):
try:
torch_npu.npu.stop_device(self.npu_id)
logger.info(f"NPU {self.npu_id} execute stop device")

if self.aware_event:
logger.info("Waiting for recovery event")
self.aware_event.wait()
self.aware_event.clear()
logger.info("Recovery event received,resuming operation")
except Exception as e:
logger.error(f"Error during stop_device or recovery:{e}")
raise e
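
A minimal usage sketch for context, assuming the default torch.distributed process group has already been initialized by the worker; the integration point is not part of this diff, so the wiring below is illustrative rather than the actual worker code.

# Illustrative wiring only (assumes an Ascend environment with torch_npu and a
# gloo backend available for the side-channel group).
import queue
import threading

import torch

from vllm_ascend.worker.common import FaultStatus
from vllm_ascend.worker.fault_aware import FaultAware

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

fault_queue: queue.Queue = queue.Queue()
aware_event = threading.Event()

# Every rank starts the monitor; rank 0 acts as the coordinator that gathers
# statuses and broadcasts SILENCE_CMD / STOP_DEVICE_CMD over the gloo group.
fault_aware = FaultAware(rank, world_size, fault_queue, interval_s=1,
                         aware_event=aware_event)
fault_aware.start()

# A fault-tolerance component that detects a communication error enqueues a
# status; on the next loop iteration rank 0 sees it via gather and every rank
# executes stop_device on its NPU.
fault_queue.put(FaultStatus.NETWORK_ERR)

# Once recovery is complete, setting the event lets _stop_device() return and
# the handler loop resume reporting ACTIVE.
aware_event.set()

The separate gloo group appears to serve as a CPU-side signalling channel, presumably so that status gathering and the stop_device broadcast still work even when the device-side collectives are stuck.
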