From 272ebfb57d9ba87f25d42f00ec40b9097956aacb Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 10 Mar 2022 19:28:03 +0800 Subject: [PATCH] [bug] shard param during initializing the ShardedModelV2 (#381) --- colossalai/zero/init_ctx/init_context.py | 2 +- colossalai/zero/sharded_model/sharded_model_v2.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 3cc32f49e..17e89cbf7 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -139,7 +139,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if self.convert_fp16: param.data = param.data.to(torch.half) if param.grad is not None: - param.grad = param.grad.to(torch.half).to(target_device) + param.grad = param.grad.to(torch.half) # move torch parameters to the target device param.data = param.data.to(target_device) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index a1172fdaa..55e7b26f0 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -1,3 +1,4 @@ +from ast import Try import functools from collections import OrderedDict from typing import Any, Optional @@ -54,7 +55,7 @@ class ShardedModelV2(nn.Module): # In case user didn't use ZeroInitContext for param in self.module.parameters(): if not hasattr(param, 'col_attr'): - param.col_attr = ShardedParamV2(param, process_group) + param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True) if self.shard_param: self.shard_strategy.shard([param.col_attr.data])