[feature]Token-Level Re-Inference for Fault Tolerance in vLLM-Ascend #5530
base: main
Conversation
Code Review
This pull request introduces a fault tolerance mechanism for vLLM on Ascend, enabling token-level re-inference to handle failures like network errors. The implementation includes a FaultAware thread for monitoring, a FaultTolerance decorator to wrap model execution, and a set of recovery handlers. While this is a significant feature, I've identified a few critical issues in the implementation that could compromise the robustness of the fault tolerance system itself. Specifically, the fault-aware thread is not resilient to distributed errors, and process groups are not correctly re-initialized after a fault, which would cause subsequent failures. Addressing these issues is crucial for the feature to work reliably.
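For readers unfamiliar with the pattern described above, here is a minimal, hypothetical sketch of a retry-style decorator around model execution. Only the general idea (a `FaultTolerance` decorator wrapping `execute_model`) comes from this PR; the retry loop, the `recover_from_fault` hook, and the exception choice below are illustrative assumptions, not the actual implementation.

```python
import functools
import logging

logger = logging.getLogger(__name__)

def fault_tolerance(max_retries: int = 1):
    """Sketch only: re-run model execution after a fault is recovered."""
    def decorator(execute_fn):
        @functools.wraps(execute_fn)
        def wrapper(self, *args, **kwargs):
            attempts = 0
            while True:
                try:
                    return execute_fn(self, *args, **kwargs)
                except RuntimeError as e:  # simplification: torch dist errors subclass RuntimeError
                    attempts += 1
                    if attempts > max_retries:
                        raise
                    logger.warning("Fault during execute_model, recovering: %s", e)
                    # Hypothetical hook: restore device and process-group state,
                    # then re-run the step from the last completed token.
                    self.recover_from_fault()
        return wrapper
    return decorator
```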
```python
except Exception as e:
    logger.error(f"Exception in fault aware handler:{e}")
    if not threading.main_thread().is_alive():
        break
    raise e
```
The _handler_loop runs in a background thread to monitor for faults. If a distributed operation like gather or broadcast fails (e.g., due to a worker failure), it will raise a torch.distributed.DistError. The current implementation catches this exception, logs it, and then re-raises it, which terminates the fault-aware thread. This defeats the purpose of the fault tolerance mechanism, as it will stop monitoring for faults after the first one. The thread should handle distributed errors gracefully and continue its monitoring loop.
I suggest catching torch.distributed.DistError specifically and continuing the loop after a short delay, while letting other unexpected exceptions terminate the thread.
Suggested change:

```diff
-except Exception as e:
-    logger.error(f"Exception in fault aware handler:{e}")
-    if not threading.main_thread().is_alive():
-        break
-    raise e
+except torch.distributed.DistError as e:
+    logger.warning(f"Fault aware handler caught a distributed error, will retry: {e}")
+    time.sleep(self.interval_s)
+except Exception as e:
+    logger.error(f"Exception in fault aware handler:{e}")
+    if not threading.main_thread().is_alive():
+        break
+    raise e
```
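To show the suggested pattern in one self-contained piece, here is an illustrative watcher-thread skeleton. The `interval_s` attribute and the logging messages follow the snippet above; the class name, constructor, and `_check_for_faults` placeholder are assumptions for the sketch, not code from this PR.

```python
import logging
import threading
import time

import torch.distributed as dist

logger = logging.getLogger(__name__)

class FaultAwareSketch(threading.Thread):
    """Illustrative only: a watcher thread whose loop survives collective failures."""

    def __init__(self, interval_s: float = 1.0):
        super().__init__(daemon=True)
        self.interval_s = interval_s

    def run(self):
        while True:
            try:
                self._check_for_faults()  # placeholder for the real gather/broadcast health check
            except dist.DistError as e:
                # A failed collective must not kill the watcher; pause and retry.
                logger.warning("Fault aware handler caught a distributed error, will retry: %s", e)
                time.sleep(self.interval_s)
            except Exception as e:
                logger.error("Exception in fault aware handler:%s", e)
                if not threading.main_thread().is_alive():
                    break
                raise

    def _check_for_faults(self):
        raise NotImplementedError  # hypothetical hook, not part of the PR
```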
```python
FaultAware(
    self.rank,self.world_size,self.fault_queue,aware_event=self.aware_event
).start()
```
The FaultAware instance is created and started, but no reference to it is stored. This makes it impossible to interact with it later, for example, to re-initialize its process group after the main process group has been re-initialized during fault recovery. A reference to the FaultAware instance should be stored in self.
Suggested change:

```diff
-FaultAware(
-    self.rank,self.world_size,self.fault_queue,aware_event=self.aware_event
-).start()
+self.fault_aware = FaultAware(
+    self.rank,self.world_size,self.fault_queue,aware_event=self.aware_event
+)
+self.fault_aware.start()
```
```python
try:
    torch_npu.npu.restart_device(torch.npu.current_device())
    torch.distributed.reinit_process_group(group=None, rebuild_link=False)
    self.model_runner.execute_model_state = None
    self._restore_essential_state(ctx.back_up)
    reinit_status = RecoveryStatus.SUCCESS
```
The _clean_fault method calls torch.distributed.reinit_process_group, which invalidates all existing process groups. However, it does not re-initialize the custom process groups used for fault tolerance (_recovery_group, _sync_group, and FaultAware._fault_aware_group). Subsequent distributed operations using these old groups will fail. All process groups must be re-created after re-initializing the distributed environment.
This requires the FaultAware instance to be stored as suggested in another comment.
Suggested change:

```diff
 try:
     torch_npu.npu.restart_device(torch.npu.current_device())
     torch.distributed.reinit_process_group(group=None, rebuild_link=False)
     self.model_runner.execute_model_state = None
     self._restore_essential_state(ctx.back_up)
+    self._init_recovery_group()
+    self._init_sync_group()
+    self.fault_aware.init_fault_aware_group()
     reinit_status = RecoveryStatus.SUCCESS
```
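To make the re-creation step concrete, here is a hypothetical helper using the standard `torch.distributed.new_group` API. The group names mirror the ones mentioned in the review (`_recovery_group`, `_sync_group`, `_fault_aware_group`); the group membership and the "gloo" backend for the watcher group are assumptions, not details from this PR.

```python
import torch.distributed as dist

def rebuild_fault_tolerance_groups(world_size: int):
    """Illustrative: re-create custom process groups after the default group
    has been re-initialized; handles created before the reinit are stale and
    must not be reused."""
    all_ranks = list(range(world_size))
    recovery_group = dist.new_group(ranks=all_ranks)
    sync_group = dist.new_group(ranks=all_ranks)
    fault_aware_group = dist.new_group(ranks=all_ranks, backend="gloo")  # assumption: CPU-side health checks
    return recovery_group, sync_group, fault_aware_group
```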
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?