Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 15:01:43 +00:00)
prepare allgather input in advance

parent aaafb38851
commit 348520de5d
@@ -544,12 +544,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
+        params_to_gather_buffer = dict()
 
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             working_params = self._working_param_groups[group_id]
             real_working_params[group_id] = []
             real_master_params[group_id] = []
+            params_to_gather_buffer[group_id] = []
             working_grads = []
             for working_param, master_param in zip(working_params, master_params):
                 # if a working param requires grad and has no grad
@@ -596,13 +598,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
         }
 
+        device = get_accelerator().get_current_device()
+        for group_id in range(self.num_param_groups):
+            master_working_param = self.optim.param_groups[group_id]["params"]
+            for idx, master_param in enumerate(master_working_param):
+                param_to_gather = master_param.to(device).to(self._dtype)
+                params_to_gather_buffer[group_id].append(param_to_gather)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, master_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                param_to_gather = master_param.to(device).to(self._dtype)
+                param_to_gather = params_to_gather_buffer[group_id][idx]
                 pg = self.param_to_pg[working_param]
                 padded_working_param = self._working_param_to_padded_working_param[working_param]
                 if self._overlap_allgather:
@@ -634,6 +643,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if not tensor_bucket.is_empty():
                 tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
+        del params_to_gather_buffer
+
     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
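Taken together, the three hunks split the old convert-and-gather loop into two phases: every master-parameter shard is first cast to the working device/dtype and stored in params_to_gather_buffer, the subsequent all-gather loop then reads only from that buffer, and the buffer is deleted once the gathers have been issued. The sketch below is a minimal illustration of that two-phase pattern using plain torch.distributed rather than ColossalAI's TensorBucket/accelerator internals; gather_updated_params, master_param_groups, working_dtype, and restore_working_param are illustrative placeholders, not ColossalAI APIs.

import torch
import torch.distributed as dist


def gather_updated_params(master_param_groups, working_dtype, device, restore_working_param):
    """Sketch of the two-phase 'prepare allgather input in advance' pattern."""
    # Phase 1: cast every master (fp32) shard to the working device/dtype up front,
    # so all copies/casts are issued before any collective starts.
    params_to_gather_buffer = {
        group_id: [p.detach().to(device).to(working_dtype) for p in params]
        for group_id, params in enumerate(master_param_groups)
    }

    # Phase 2: the all-gather loop only consumes the pre-built buffer.
    world_size = dist.get_world_size()
    for group_id, params in enumerate(master_param_groups):
        for idx in range(len(params)):
            shard = params_to_gather_buffer[group_id][idx]
            gathered = [torch.empty_like(shard) for _ in range(world_size)]
            dist.all_gather(gathered, shard)
            # Hypothetical callback that writes the full tensor back into the working param.
            restore_working_param(group_id, idx, torch.cat(gathered))

    # Phase 3: drop the temporary buffer, mirroring `del params_to_gather_buffer` in the diff.
    del params_to_gather_buffer

Separating the casts from the gather loop keeps the dtype conversions and device copies off the critical path of the collectives, which appears to be the intent behind preparing the all-gather input in advance.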