[zero] fix memory leak for zero2 (#1955)
@@ -48,7 +48,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                  verbose=False,

                  # communication
-                 reduce_bucket_size=500000000,
+                 reduce_bucket_size=50000000,
                  communication_dtype=torch.float16,
                  overlap_communication=False,

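The only change in this hunk is a ten-fold smaller default reduce bucket, which directly caps the transient buffer used to stage gradients for reduction. A rough sizing check, assuming the bucket is flattened into a single buffer of `communication_dtype` (fp16, 2 bytes per element); actual peak usage depends on how the implementation allocates and frees the bucket:

    # Back-of-the-envelope bucket sizing for the old and new defaults above.
    BYTES_PER_FP16 = 2
    GIB = 1024 ** 3

    old_bucket_gib = 500_000_000 * BYTES_PER_FP16 / GIB   # ~0.93 GiB per bucket
    new_bucket_gib = 50_000_000 * BYTES_PER_FP16 / GIB    # ~0.09 GiB per bucket

    print(f"old default: {old_bucket_gib:.2f} GiB, new default: {new_bucket_gib:.2f} GiB")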
@@ -125,14 +125,14 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self._optimizer.param_groups):
-            params = param_group['params']
+            group_params = param_group['params']

             # add the fp16 params to fp16_param_groups for bookkeeping
-            self._fp16_param_groups[group_id] = params
+            self._fp16_param_groups[group_id] = group_params

             # assign parameters to ranks
             # the params in the list are sorted
-            params_per_rank = self._partition_param_list(params)
+            params_per_rank = self._partition_param_list(group_params)

             # store the mapping between param to rank
             # each param should belong to only one rank
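The rename from `params` to `group_params` is more than cosmetic if, as the later hunks suggest, the name `params` is rebound elsewhere in the same loop body (for instance while iterating `params_per_rank`). A hypothetical sketch of that shadowing hazard; the names below are illustrative and not taken from the source file:

    # If an inner loop rebinds 'params', later code that still says 'params'
    # silently sees only the last rank's slice instead of the whole group.
    whole_group = [1, 2, 3, 4]                  # stands in for param_group['params']
    per_rank = [[1, 2], [3, 4]]                 # stands in for params_per_rank

    for rank, params in enumerate(per_rank):    # rebinds the name 'params'
        pass

    moved = [p for p in params]                 # only [3, 4] would be "moved"
    assert moved == [3, 4] and moved != whole_group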
@@ -143,14 +143,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

             # move to cpu to make room to create the flat tensor
             # move_tensor(params, device='cpu')
-            for param in params:
+            for param in group_params:
                 param.data = param.data.cpu()

             # flatten the reordered tensors
             for rank in range(self._world_size):
                 tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
-                flat_tensor = flatten(tensor_list)
-                flat_tensor = flat_tensor.cuda()
+                with torch.no_grad():
+                    flat_tensor = flatten(tensor_list)
+                flat_tensor = flat_tensor.data.cuda()
                 self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)

             # sync parameters
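Building the flat tensor under `torch.no_grad()` keeps autograd from recording the concatenation, so the flat fp16 buffer carries no `grad_fn` that would pin the intermediate graph in memory. A self-contained sketch of the difference, assuming `flatten` concatenates its inputs much like `torch.cat` does:

    import torch

    # Stand-ins for fp16 parameters that require grad, as model params do.
    params = [torch.ones(4, dtype=torch.float16, requires_grad=True) for _ in range(2)]

    tracked = torch.cat([p.reshape(-1) for p in params])
    print(tracked.grad_fn is not None)      # True: a backward graph is recorded and kept alive

    with torch.no_grad():
        untracked = torch.cat([p.reshape(-1) for p in params])
    print(untracked.grad_fn is None)        # True: plain data copy, no autograd history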
@@ -161,7 +162,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

             # create a copy of fp32 weights of the parameters for which this rank is responsible
             fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
-            fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
+            fp32_flat_current_rank = fp16_flat_current_rank.float()
             device = 'cpu' if self._cpu_offload else get_current_device()
             fp32_flat_current_rank = fp32_flat_current_rank.to(device)
             fp32_flat_current_rank.requires_grad = True
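Because the flat fp16 shard is now created without autograd history, a plain `.float()` is enough: converting a half tensor already allocates fresh fp32 storage, so the previous `clone()` only added a redundant intermediate copy of the whole shard, and `detach()` had nothing left to detach. A small sketch with a stand-in tensor:

    import torch

    flat_fp16 = torch.ones(8, dtype=torch.float16)    # stand-in for the flat fp16 shard

    fp32 = flat_fp16.float()                          # dtype conversion -> new fp32 storage
    print(fp32.dtype)                                 # torch.float32
    print(fp32.data_ptr() != flat_fp16.data_ptr())    # True: no memory shared with the fp16 shard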
@@ -384,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # torch.optim.Optimizer methods
     ################################

-    def backward(self, loss, retain_graph=True):
+    def backward(self, loss, retain_graph=False):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)

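Defaulting `retain_graph` to `False` lets PyTorch release the autograd graph, and the saved tensors it references, as soon as `backward()` returns instead of keeping them alive across iterations; callers that genuinely need a second backward over the same graph can still pass `retain_graph=True` explicitly. A self-contained sketch of the underlying PyTorch behavior (plain tensors, not the optimizer itself):

    import torch

    x = torch.ones(3, requires_grad=True)
    loss = (x * x).sum()

    loss.backward()                     # retain_graph defaults to False: buffers are freed here
    try:
        loss.backward()                 # a second pass now fails because the graph was released
    except RuntimeError as err:
        print("graph already freed:", err)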