Mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] moving memtracer to gemini (#801)
@@ -12,10 +12,10 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model import ShardedModelV2
 
 from ._base_schedule import BaseSchedule
 
+
 def get_tensor_shape():
     if hasattr(gpc.config, 'TENSOR_SHAPE'):
         return gpc.config.TENSOR_SHAPE
@@ -23,7 +23,8 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
 
-    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
+            gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
         if gpc.is_initialized(ParallelMode.DATA):
             dp_size = gpc.get_world_size(ParallelMode.DATA)
         else:
@@ -34,12 +35,12 @@ def get_tensor_shape():
             seq_size = 1
 
         tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
-                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
-                        gpc.config.HIDDEN_SIZE)
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
         return tensor_shape
     else:
         return None
 
 
 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
     if isinstance(output[0], torch.Tensor):
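In the get_tensor_shape() hunks above, only the line-wrapping changes; the function itself derives, from the global config, the shape of the activation tensor that one pipeline stage sends to the next per micro-batch. Note that GLOBAL_BATCH_SIZE is hasattr-checked twice while NUM_MICRO_BATCHES, which the arithmetic also reads, is never checked; the duplicate was presumably meant to be NUM_MICRO_BATCHES. A minimal worked example of the arithmetic, with hypothetical config values that are not from the commit:

# Hypothetical config values, for illustration only (not from the commit):
SEQ_LENGTH = 1024        # tokens per sequence
GLOBAL_BATCH_SIZE = 128  # samples per step across all data-parallel ranks
NUM_MICRO_BATCHES = 4    # micro-batches per pipeline step
HIDDEN_SIZE = 768        # model hidden dimension
dp_size = 2              # data-parallel world size
seq_size = 1             # sequence-parallel world size (1 = disabled)

# Mirrors the arithmetic in get_tensor_shape(): the per-micro-batch
# activation shape exchanged between adjacent pipeline stages.
tensor_shape = (SEQ_LENGTH // seq_size,
                GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES,
                HIDDEN_SIZE)
assert tensor_shape == (1024, 16, 768)

The micro-batch dimension of the buffer shrinks as data parallelism or the micro-batch count grows, while the sequence and hidden dimensions are fixed by the model.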
@@ -114,7 +115,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
+        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
@@ -125,7 +126,7 @@ class PipelineSchedule(BaseSchedule):
     def _call_engine(model, input_tensor, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
-        elif isinstance(model, ShardedModelV2):
+        elif hasattr(model, 'colo_attr'):
             sig = inspect.signature(model.module.forward)
         else:
             sig = inspect.signature(model.forward)
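The two PipelineSchedule hunks above swap isinstance checks against ShardedModelV2 for duck typing on a colo_attr attribute, so this module no longer references the class at all. A minimal sketch of the pattern; the classes here are hypothetical, and only the hasattr probe itself comes from the commit:

import inspect

class Plain:
    def forward(self, x, y=None):
        return x

class FakeShardedModel:
    # Hypothetical stand-in: the real sharded wrapper exposes a 'colo_attr'
    # attribute (its exact semantics are not shown in this commit).
    def __init__(self, module):
        self.module = module
        self.colo_attr = object()  # marker only

def resolve_forward_signature(model):
    # Same probe as the patched _call_engine: detect the sharded wrapper
    # by attribute instead of importing its class.
    if hasattr(model, 'colo_attr'):
        return inspect.signature(model.module.forward)
    return inspect.signature(model.forward)

print(resolve_forward_signature(Plain()))                    # (x, y=None)
print(resolve_forward_signature(FakeShardedModel(Plain())))  # (x, y=None)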
@@ -385,7 +386,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks
 
     def pre_processing(self, engine):
-        if isinstance(engine.model, ShardedModelV2):
+        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
+        if hasattr(engine.model, 'colo_attr'):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
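The new FIXME in the last hunk records the motivation for the whole change: importing ShardedModelV2 into the schedule reportedly creates a circular dependency between the pipeline and zero modules, while probing for an attribute requires no import. A single-file sketch of the two styles, with hypothetical names:

# Hypothetical single-file sketch, not ColossalAI's API.

def detect_by_isinstance(model, sharded_cls):
    # Old style: the caller must import sharded_cls, so this module depends
    # on the module that defines it; that import closed the cycle.
    return isinstance(model, sharded_cls)

def detect_by_attribute(model):
    # New style: probe a marker attribute instead; no import, no cycle.
    return hasattr(model, 'colo_attr')

The cost is weaker typing: any object that happens to define colo_attr is treated as a sharded model, which is presumably why the FIXME still aims to remove ShardedModelV2 from pipeline mode entirely.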