[hotfix] fix zero comm buffer init (#6154)

This commit is contained in:
Hongxin Liu 2024-12-10 16:46:15 +08:00 committed by GitHub
parent 8d826a336e
commit de3d371f65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -371,7 +371,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for i, sz in enumerate(bucket_store.sizes):
grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i]
flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz))
-received_grad = torch.zeros_like(flat_grads_list[0])
+received_grad = torch.empty_like(flat_grads_list[0])
if self._fp8_communication:
reduce_scatter_fp8(
received_grad,