Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-25 23:11:55 +00:00
prepare allgather input in advance
parent aaafb38851
commit 348520de5d
@@ -544,12 +544,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
+        params_to_gather_buffer = dict()

         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             working_params = self._working_param_groups[group_id]
             real_working_params[group_id] = []
             real_master_params[group_id] = []
+            params_to_gather_buffer[group_id] = []
             working_grads = []
             for working_param, master_param in zip(working_params, master_params):
                 # if a working param requires grad and has no grad
@@ -596,13 +598,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
         }

+        device = get_accelerator().get_current_device()
+        for group_id in range(self.num_param_groups):
+            master_working_param = self.optim.param_groups[group_id]["params"]
+            for idx, master_param in enumerate(master_working_param):
+                param_to_gather = master_param.to(device).to(self._dtype)
+                params_to_gather_buffer[group_id].append(param_to_gather)
+
         # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, master_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                param_to_gather = master_param.to(device).to(self._dtype)
+                param_to_gather = params_to_gather_buffer[group_id][idx]
                 pg = self.param_to_pg[working_param]
                 padded_working_param = self._working_param_to_padded_working_param[working_param]
                 if self._overlap_allgather:
@@ -634,6 +643,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if not tensor_bucket.is_empty():
                 tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)

+        del params_to_gather_buffer
+
     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
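The change above boils down to a two-phase pattern: first materialize every master parameter on the communication device in the working dtype into `params_to_gather_buffer`, then run the all-gather loop against that buffer, and finally `del` the buffer once the collectives have been issued. Below is a minimal stand-alone sketch of that pattern; the function names (`prepare_allgather_inputs`, `allgather_prepared`) and the use of plain `torch.distributed.all_gather` are illustrative assumptions, not the ColossalAI API — the actual code in the diff routes the pre-converted tensors through its `TensorBucket` machinery instead.

```python
# Sketch only: assumes an initialized torch.distributed process group `pg`
# and equally sized parameter shards on every rank.
import torch
import torch.distributed as dist


def prepare_allgather_inputs(param_groups, device, dtype):
    # Phase 1: do every device/dtype conversion up front, before any
    # communication starts (mirrors building params_to_gather_buffer).
    buffer = {}
    for group_id, params in enumerate(param_groups):
        buffer[group_id] = [p.detach().to(device).to(dtype) for p in params]
    return buffer


def allgather_prepared(buffer, pg):
    # Phase 2: all-gather the pre-converted shards; no conversions happen
    # inside the communication loop anymore.
    gathered = {}
    world_size = dist.get_world_size(group=pg)
    for group_id, shards in buffer.items():
        outs = []
        for shard in shards:
            out = [torch.empty_like(shard) for _ in range(world_size)]
            dist.all_gather(out, shard, group=pg)
            outs.append(torch.cat(out))
        gathered[group_id] = outs
    return gathered
```

As in the diff, the buffer holds an extra converted copy of every master parameter, so it should be dropped as soon as the gathers are launched, which is what the added `del params_to_gather_buffer` does.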