diff --git a/swift/megatron/utils/megatron_lm_utils.py b/swift/megatron/utils/megatron_lm_utils.py index 1710d333aa..d7a7b16247 100644 --- a/swift/megatron/utils/megatron_lm_utils.py +++ b/swift/megatron/utils/megatron_lm_utils.py @@ -82,6 +82,26 @@ def _initialize_mpu(args): f'VPP: {args.virtual_pipeline_model_parallel_size}, CP: {args.context_parallel_size}, ' f'EP: {args.expert_model_parallel_size}, ETP: {args.expert_tensor_parallel_size}') + # Force NCCL bootstrap of INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP before model weights are loaded onto GPU. + # + # Background: PyTorch lazily initializes NCCL communicators — dist.new_group() only registers + # metadata; the actual ncclCommInitRankConfig bootstrap runs on first collective use. + # For MoE models this group's first use is the very first optimizer step, by which point GPU + # memory is near its limit (~125-130 GiB on H200). If any rank cannot allocate the NCCL + # temporary buffer at that moment, its bootstrap thread stalls silently — NCCL_TIMEOUT does + # NOT cover the bootstrap phase — and all other ranks wait forever (deadlock). + # + # Calling a barrier here forces bootstrap while GPU memory is still empty, eliminating the race. + # get_inter_distributed_optimizer_instance_group returns _INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP, + # which is None for dense models or when EP=1, so the guard is safe in all configurations. + if hasattr(mpu, 'get_inter_distributed_optimizer_instance_group'): + inter_ep_dp_group = mpu.get_inter_distributed_optimizer_instance_group(check_initialized=False) + if inter_ep_dp_group is not None: + logger.info('Pre-initializing INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP NCCL communicator ' + 'to avoid lazy-init deadlock during first optimizer step.') + torch.distributed.barrier(group=inter_ep_dp_group, device_ids=[torch.cuda.current_device()]) + torch.cuda.synchronize() + def initialize_megatron(args): # Pytorch distributed.