[polish] use GLOBAL_MODEL_DATA_TRACER (#417)

This commit is contained in:
Jiarui Fang
2022-03-15 11:29:46 +08:00
committed by GitHub
parent 23ba3fc450
commit 56bb412e72
8 changed files with 25 additions and 25 deletions

View File

@@ -3,7 +3,7 @@ import functools
import torch
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
# Inserts _post_init_method at the end of init method
@@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.shard_param:
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
if param.col_attr.grad and self.shard_grad:
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)