mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[polish] use GLOBAL_MODEL_DATA_TRACER (#417)
This commit is contained in:
@@ -3,7 +3,7 @@ import functools
|
||||
import torch
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
|
||||
# Inserts _post_init_method at the end of init method
|
||||
|
||||
@@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
|
||||
ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
|
||||
if param.col_attr.grad and self.shard_grad:
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
|
||||
ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
||||
|
||||
Reference in New Issue
Block a user