prepare allgather input in advance

BurkeHulk 2024-12-17 17:44:54 +08:00
parent aaafb38851
commit 348520de5d


@@ -544,12 +544,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
+        params_to_gather_buffer = dict()
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             working_params = self._working_param_groups[group_id]
             real_working_params[group_id] = []
             real_master_params[group_id] = []
+            params_to_gather_buffer[group_id] = []
             working_grads = []
             for working_param, master_param in zip(working_params, master_params):
                 # if a working param requires grad and has no grad
@@ -596,13 +598,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
         }
+        device = get_accelerator().get_current_device()
+        for group_id in range(self.num_param_groups):
+            master_working_param = self.optim.param_groups[group_id]["params"]
+            for idx, master_param in enumerate(master_working_param):
+                param_to_gather = master_param.to(device).to(self._dtype)
+                params_to_gather_buffer[group_id].append(param_to_gather)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, master_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                param_to_gather = master_param.to(device).to(self._dtype)
+                param_to_gather = params_to_gather_buffer[group_id][idx]
                 pg = self.param_to_pg[working_param]
                 padded_working_param = self._working_param_to_padded_working_param[working_param]
                 if self._overlap_allgather:
@@ -634,6 +643,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if not tensor_bucket.is_empty():
                     tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
+
+        del params_to_gather_buffer

     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
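
For context, the change splits the post-step parameter all-gather into two phases: every master shard is first moved to the accelerator and cast to the working dtype into params_to_gather_buffer, and the gather loop then only hands already-staged tensors to the collective before the buffer is deleted. The sketch below is a minimal, standalone illustration of that two-phase pattern using plain torch tensors on CPU; the helper names (stage_gather_inputs, fake_all_gather) and the simulated "gather" (a concatenation of shards) are hypothetical and not part of the ColossalAI API.

# Minimal sketch of "prepare all-gather input in advance" (hypothetical helpers).
# Phase 1 stages every shard (device move + dtype cast) into a per-group buffer;
# phase 2 consumes the buffer, so no transfer or cast happens inside the gather loop.
import torch


def stage_gather_inputs(param_groups, device, dtype):
    """Phase 1: pre-convert each master shard once, before any collective is issued."""
    buffer = {}
    for group_id, params in enumerate(param_groups):
        buffer[group_id] = [p.detach().to(device).to(dtype) for p in params]
    return buffer


def fake_all_gather(shards):
    """Stand-in for dist.all_gather_into_tensor: concatenates per-rank shards."""
    return torch.cat(shards, dim=0)


if __name__ == "__main__":
    # Two parameter groups of fp32 master shards (CPU stands in for the accelerator).
    param_groups = [
        [torch.randn(4), torch.randn(6)],
        [torch.randn(8)],
    ]
    device, dtype = torch.device("cpu"), torch.float16

    # Phase 1: build the staging buffer up front.
    buffer = stage_gather_inputs(param_groups, device, dtype)

    # Phase 2: the gather loop only reads pre-staged tensors.
    for group_id, params in enumerate(param_groups):
        for idx, _ in enumerate(params):
            param_to_gather = buffer[group_id][idx]        # already on-device, already fp16
            gathered = fake_all_gather([param_to_gather])  # single "rank" in this sketch
            assert gathered.dtype == torch.float16

    # Mirror the commit: drop the buffer once all gathers are done.
    del buffer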