[refactor] moving memtracer to gemini (#801)

Jiarui Fang authored 2022-04-19 10:13:08 +08:00, committed by GitHub
parent 8711c706f4
commit 4d9332b4c5
24 changed files with 102 additions and 87 deletions


@@ -12,10 +12,10 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule


 def get_tensor_shape():
     if hasattr(gpc.config, 'TENSOR_SHAPE'):
         return gpc.config.TENSOR_SHAPE
@@ -23,7 +23,8 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
-    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
+            gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
         if gpc.is_initialized(ParallelMode.DATA):
             dp_size = gpc.get_world_size(ParallelMode.DATA)
         else:
@@ -34,12 +35,12 @@ def get_tensor_shape():
             seq_size = 1

         tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
-                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
-                        gpc.config.HIDDEN_SIZE)
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
         return tensor_shape
     else:
         return None


 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
     if isinstance(output[0], torch.Tensor):
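For context, `get_tensor_shape` derives the shape of the activation tensor exchanged between pipeline stages from global config values. A minimal sketch of the arithmetic, using made-up config numbers purely for illustration:

```python
# Hypothetical config values, chosen only to make the arithmetic concrete.
SEQ_LENGTH = 1024         # tokens per sequence
GLOBAL_BATCH_SIZE = 32    # sequences per step across all data-parallel ranks
NUM_MICRO_BATCHES = 4     # micro-batches the pipeline splits each step into
HIDDEN_SIZE = 768         # model hidden dimension
dp_size = 2               # data-parallel world size
seq_size = 1              # sequence-parallel world size (1 when unused)

# Same formula as in the diff: (seq_len, micro_batch_size, hidden_size)
tensor_shape = (SEQ_LENGTH // seq_size,
                GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES,
                HIDDEN_SIZE)
print(tensor_shape)  # (1024, 4, 768)
```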
@@ -114,7 +115,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
+        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
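The `isinstance(..., ShardedModelV2)` check becomes duck typing on a `colo_attr` marker attribute, so this scheduling module no longer has to import from `colossalai.zero` (the circular dependency called out in the FIXME in the last hunk). A minimal sketch of the pattern, with a hypothetical stand-in wrapper:

```python
import torch.nn as nn

class ShardedWrapper(nn.Module):
    """Hypothetical stand-in for a sharded model wrapper that tags itself."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.colo_attr = object()  # marker attribute; its contents don't matter here

def is_sharded(model) -> bool:
    # Duck typing: anything carrying `colo_attr` is treated as sharded,
    # without importing the sharded class itself.
    return hasattr(model, 'colo_attr')

assert is_sharded(ShardedWrapper(nn.Linear(4, 4)))
assert not is_sharded(nn.Linear(4, 4))
```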
@@ -125,7 +126,7 @@ class PipelineSchedule(BaseSchedule):
     def _call_engine(model, input_tensor, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
-        elif isinstance(model, ShardedModelV2):
+        elif hasattr(model, 'colo_attr'):
             sig = inspect.signature(model.module.forward)
         else:
             sig = inspect.signature(model.forward)
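`_call_engine` inspects the signature so it can route `batch_data` into the wrapped module's `forward` correctly; each branch just unwraps one level of model wrapping first. A sketch of signature-based keyword filtering in the same spirit (the helper name is hypothetical, not the method's actual logic):

```python
import inspect

def filter_batch_kwargs(forward_fn, batch_data: dict) -> dict:
    # Keep only the entries that forward_fn's signature actually accepts.
    sig = inspect.signature(forward_fn)
    return {k: v for k, v in batch_data.items() if k in sig.parameters}

def forward(input_ids, attention_mask=None):
    return input_ids

batch = {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1], 'labels': [0]}
print(sorted(filter_batch_kwargs(forward, batch)))  # ['attention_mask', 'input_ids']
```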
@@ -385,7 +386,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks

     def pre_processing(self, engine):
-        if isinstance(engine.model, ShardedModelV2):
+        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
+        if hasattr(engine.model, 'colo_attr'):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
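In the interleaved schedule, `engine.model` is a list of model chunks (one per virtual pipeline stage), which is why the AMP branch indexes `engine.model[0]` while the sharded check looks at the container itself. A hedged sketch of that dtype selection in isolation, with a hypothetical `is_amp_wrapped` marker standing in for `NaiveAMPModel`:

```python
import torch
import torch.nn as nn

def select_dtype(model) -> torch.dtype:
    # Sharded wrappers advertise themselves via the `colo_attr` marker.
    if hasattr(model, 'colo_attr'):
        return torch.half
    # Otherwise, peek at the first chunk in the list of model chunks.
    if isinstance(model, (list, tuple)) and getattr(model[0], 'is_amp_wrapped', False):
        return torch.half
    return torch.float

chunks = [nn.Linear(8, 8), nn.Linear(8, 8)]
chunks[0].is_amp_wrapped = True  # hypothetical marker, not a real ColossalAI flag
print(select_dtype(chunks))      # torch.float16
```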