Repository: https://github.com/hpcaitech/ColossalAI.git
add interleaved pipeline, fix naive amp and update pipeline model initializer (#80)
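The change repeated across all three test files is a single API migration: the two-step PipelineModelInitializer class is replaced by the functional builder build_pipeline_model_from_cfg. As a minimal before/after sketch, where model_cfg stands in for whatever pipeline model config each test passes (vit_t_2d.model_cfg or gpc.config.model in the diffs below):

-from colossalai.builder import PipelineModelInitializer
-initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
-model = initializer.initialize()
+from colossalai.builder import build_pipeline_model_from_cfg
+model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1)

The first file additionally updates the __main__ entry point to call test_hybrid_parallel() instead of main().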
@@ -6,7 +6,7 @@ from colossalai.logging import get_dist_logger
 import colossalai
 import torch
 import os
-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_dataloader, MultiTimer
 from colossalai.nn.loss import CrossEntropyLoss2D
@@ -50,8 +50,7 @@ def test_hybrid_parallel():
     # suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')

     # build vit-t-32
-    initializer = PipelineModelInitializer(vit_t_2d.model_cfg, num_chunks=1)
-    model = initializer.initialize()
+    model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1)

     # build dataloaders
     train_dataset = CIFAR10(
@@ -139,4 +138,4 @@ def test_hybrid_parallel():


 if __name__ == '__main__':
-    main()
+    test_hybrid_parallel()
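All of the call sites in this commit pass num_chunks=1, i.e. one contiguous partition per pipeline stage. The interleaved pipeline named in the commit title presumably comes into play when num_chunks > 1, with each stage holding several non-adjacent model chunks that the schedule alternates between. A hedged sketch; the num_chunks=2 value and the ModuleList return type are assumptions, not shown in this diff:

from colossalai.builder import build_pipeline_model_from_cfg

# model_cfg: hypothetical pipeline model config, as in the tests above.
# num_chunks=1: one contiguous slice of the model per stage (what these
# tests use); the builder returns a single torch.nn.Module.
model = build_pipeline_model_from_cfg(model_cfg, num_chunks=1)

# num_chunks=2 (assumption): each stage holds two non-adjacent chunks so
# an interleaved schedule can cycle between them; the builder would then
# return the chunks together, e.g. as an nn.ModuleList.
chunks = build_pipeline_model_from_cfg(model_cfg, num_chunks=2)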
@@ -5,7 +5,7 @@ import torch
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader

-from colossalai.builder.pipeline import PipelineModelInitializer
+from colossalai.builder.pipeline import build_pipeline_model_from_cfg
 from colossalai.core import global_context
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
@@ -28,7 +28,7 @@ def run_partition(rank, world_size):
     logger.info('finished initialization')

     # build model
-    model = PipelineModelInitializer(global_context.config.model, 1, verbose=True).initialize()
+    model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
     assert isinstance(model, torch.nn.Module)
     logger.info('model is created')
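This second test file drives run_partition(rank, world_size) through torch.multiprocessing, with colossalai's launch doing per-process distributed initialization. Only the imports appear in the diff, so the harness below is a sketch: the config path, port, and world size are placeholders, while launch's keyword arguments follow colossalai.initialize.launch's public signature.

import torch.multiprocessing as mp
from colossalai.builder.pipeline import build_pipeline_model_from_cfg
from colossalai.core import global_context
from colossalai.initialize import launch

def run_partition(rank, world_size):
    # Placeholder config path and rendezvous address.
    launch(config='./pipeline_config.py', rank=rank, world_size=world_size,
           host='localhost', port=29500, backend='nccl')
    # Build one pipeline partition on this rank, as the test above does.
    model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)

if __name__ == '__main__':
    world_size = 4  # placeholder number of pipeline processes
    mp.spawn(run_partition, args=(world_size,), nprocs=world_size)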
@@ -8,7 +8,7 @@ import torch
 import torch.multiprocessing as mp
 import model

-from colossalai.builder import PipelineModelInitializer
+from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.communication import p2p as p2p_communication
 from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
 from colossalai.context.parallel_mode import ParallelMode
@@ -39,7 +39,7 @@ def run_schedule(rank, world_size):
           backend='nccl')

     # build model
-    model = PipelineModelInitializer(gpc.config.model, 1).initialize()
+    model = build_pipeline_model_from_cfg(gpc.config.model, 1)
     print_rank_0('model is created')

     train_dataset = CIFAR10(
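This third file exercises the pipeline schedule itself, which is why it also imports the p2p helpers: before a stage can receive an activation it needs the tensor's shape, so send_tensor_meta / recv_tensor_meta exchange metadata ahead of the payload. A rough sketch of that handshake, assuming the early send_tensor_meta(tensor)/recv_tensor_meta(shape) signatures, which this diff does not confirm:

import torch
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta

# Sending stage: advertise the activation's shape before sending the
# tensor itself through the p2p communication layer.
output = torch.randn(4, 128, device='cuda')  # illustrative activation
send_tensor_meta(output)

# Receiving stage: learn the incoming shape first so a correctly sized
# buffer can be allocated before the real p2p receive.
input_shape = recv_tensor_meta(None)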