[zero] fix memory leak for zero2 (#1955)

This commit is contained in:
HELSON
2022-11-16 11:43:24 +08:00
committed by GitHub
parent 60abd86d6a
commit 7066dfbf82
2 changed files with 171 additions and 9 deletions

View File

@@ -48,7 +48,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
verbose=False,
# communication
reduce_bucket_size=500000000,
reduce_bucket_size=50000000,
communication_dtype=torch.float16,
overlap_communication=False,
@@ -125,14 +125,14 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self._optimizer.param_groups):
params = param_group['params']
group_params = param_group['params']
# add the fp16 params to fp16_param_groups for bookkeeping
self._fp16_param_groups[group_id] = params
self._fp16_param_groups[group_id] = group_params
# assign parameters to ranks
# the params in the list are sorted
params_per_rank = self._partition_param_list(params)
params_per_rank = self._partition_param_list(group_params)
# store the mapping between param to rank
# each param should belong to only one rank
@@ -143,14 +143,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# move to cpu to make room to create the flat tensor
# move_tensor(params, device='cpu')
for param in params:
for param in group_params:
param.data = param.data.cpu()
# flatten the reordered tensors
for rank in range(self._world_size):
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.cuda()
with torch.no_grad():
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.data.cuda()
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
# sync parameters
@@ -161,7 +162,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# create a copy of fp32 weights of the parameters for which this rank is responsible
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
fp32_flat_current_rank = fp16_flat_current_rank.float()
device = 'cpu' if self._cpu_offload else get_current_device()
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
fp32_flat_current_rank.requires_grad = True
@@ -384,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# torch.optim.Optimizer methods
################################
def backward(self, loss, retain_graph=True):
def backward(self, loss, retain_graph=False):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)