From 05e33b257886d78322b2d94badb632a172f9cd9b Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 25 Mar 2022 18:23:25 +0800
Subject: [PATCH] [zero] fix grad offload (#528)

* [zero] fix grad offload

* polish code
---
 colossalai/utils/memory_utils/utils.py       | 17 +++++++++++++++++
 .../zero/sharded_model/sharded_model_v2.py   | 14 +++++++-------
 .../zero/sharded_optim/sharded_optim_v2.py   |  4 ++++
 3 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py
index df41ac95d..763ec3358 100644
--- a/colossalai/utils/memory_utils/utils.py
+++ b/colossalai/utils/memory_utils/utils.py
@@ -114,3 +114,20 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
+
+
+def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
+    """
+    Clone a model data tensor.
+
+    Args:
+        t (Union[ShardedTensor, torch.Tensor]): a model data tensor
+        target_device (torch.device): the target device
+    Returns:
+        torch.Tensor: a cloned torch tensor
+    """
+    t_payload = t.payload if isinstance(t, ShardedTensor) else t
+
+    ret = t_payload.to(target_device)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
+    return ret
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 9860f91dd..ee1016f64 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -11,7 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity
+from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -198,16 +198,16 @@ class ShardedModelV2(nn.Module):
             # the shape of `grad` is the same as the unsharded param,
             # so we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad = p.col_attr.sharded_data_tensor.payload
+                grad_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad)
+                grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
             if p.col_attr.fp32_grad is not None:
                 assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
-                grad = p.col_attr.fp32_grad
-            p.grad.data = grad
+                p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
+                grad_payload = p.col_attr.fp32_grad
+            p.grad.data = grad_payload
             p.col_attr.fp16_grad = None
             p.col_attr.fp32_grad = None
 
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 3ba5fa4bd..51e1e9aa4 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
@@ -217,6 +218,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None,
         # because we will judge whether local grad accumulation
         # is enabled by whether grad is None
+        for group in self.param_groups:
+            for p in group['params']:
+                GLOBAL_MODEL_DATA_TRACER.delete_tensor(p.grad)
         self.optim.zero_grad(set_to_none=True)
 
     def sync_grad(self):
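
Why the clone matters: with reuse_fp16_shard=True the gradient buffer is the fp16
parameter shard itself, so the old in-place colo_model_data_move_to_cpu(grad) would
offload the parameter shard together with the grad. colo_model_tensor_clone instead
copies the payload to the target device and registers the copy with
GLOBAL_MODEL_DATA_TRACER; the matching delete_tensor calls added to zero_grad()
unregister those offloaded grads so the tracer's accounting stays balanced. Below is
a minimal, self-contained sketch of the aliasing hazard, separate from the patch; the
tensors are stand-ins, and the device change is only observable on a CUDA machine:

    import torch

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # With reuse_fp16_shard=True the grad buffer *is* the fp16 param shard.
    shard = torch.ones(4, device=device)
    grad = shard  # same tensor object, not a copy

    # Old path (like colo_model_data_move_to_cpu): reassigning .data moves
    # the shared storage, so the param shard is offloaded along with the grad.
    grad.data = grad.data.cpu()
    print(shard.device)  # cpu -- the param shard followed the grad

    # New path (like colo_model_tensor_clone): .to() with a different target
    # device returns a fresh tensor; the original shard keeps its storage.
    shard = torch.ones(4, device=device)
    grad_copy = shard.to(torch.device('cpu'))
    print(shard.device, grad_copy.device)  # shard stays put; copy is on cpu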