[pipeline] test pure pipeline process using llama (#4218)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be reverted and copied to feature/bloom

* revert the bloom changes

* remove unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision

* add pure pipeline test

* fixed version

* fixed version

* pure pipeline
Authored by Jianghai on 2023-07-25 14:31:21 +08:00; committed by Hongxin Liu
parent 36e546b2cc
commit d0807122e2
2 changed files with 30 additions and 18 deletions


@@ -9,6 +9,7 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
+from version_parser.version import Version

 from .stage_manager import PipelineStageManager

@@ -61,17 +62,6 @@ def _broadcast_object_list(object_list: List[Any],
         c10d._warn_not_in_group("broadcast_object_list")
         return

-    my_rank = dist.get_rank()
-    # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
-        if torch.__version__ >= "1.13.0":
-            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=device) for obj in object_list])
-        else:
-            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
-        object_sizes_tensor = torch.cat(size_list)
-    else:
-        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
-
     is_nccl_backend = c10d._check_for_nccl_backend(group)
     current_device = None

@@ -83,6 +73,18 @@ def _broadcast_object_list(object_list: List[Any],
         current_device = torch.device("cpu")
         if is_nccl_backend:
             current_device = torch.device("cuda", torch.cuda.current_device())

+    my_rank = dist.get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        if Version(torch.__version__) >= Version("1.13.0"):
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
+        else:
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
     if is_nccl_backend:
         object_sizes_tensor = object_sizes_tensor.to(current_device)
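
For context on the gate above: the old code compared torch.__version__ against "1.13.0" as raw strings, and plain string comparison is lexicographic (for example, the string "1.9.0" sorts after "1.13.0"), so parsing with Version makes the ordering numeric and unambiguous. The patch also serializes only after current_device has been resolved, so the device handed to _object_to_tensor is the one the backend will actually use. A minimal, standalone sketch of the version gate, using packaging.version.Version as an assumed stand-in for the Version class imported in the diff:

# Standalone sketch, not part of the patch; packaging.version.Version is an
# assumed stand-in for the Version imported above.
import torch
from packaging.version import Version

# Lexicographic string comparison mis-orders multi-digit components:
# '9' > '1', so the string "1.9.0" is treated as newer than "1.13.0".
assert "1.9.0" >= "1.13.0"

# Parsed versions compare component by component and get the ordering right.
assert Version("1.9.0") < Version("1.13.0")

# Same pattern as the patched gate: only pass a device to the serializer
# when the running torch really is 1.13 or newer.
use_device_arg = Version(torch.__version__) >= Version("1.13.0")
print(f"torch {torch.__version__}: pass device to serializer -> {use_device_arg}")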


@@ -1,3 +1,4 @@
+import copy
 import random
 from contextlib import nullcontext
 from typing import Any, Callable, Iterator, List, Optional, Tuple

@@ -6,7 +7,6 @@ import numpy as np
 import pytest
 import torch
 import torch.distributed as dist
-from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

@@ -94,10 +94,10 @@ def execute_pipeline(
     return outputs


-class data_iter():
+class data_loader():

     def __getitem__(self, x):
-        return torch.randint(0, 100, (4, 128)).cuda()
+        return torch.ones((4, 128), dtype=torch.int).cuda() * 10


 def loss(x, y):

@@ -127,20 +127,30 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != 'transformers_llama':
+            continue
         num_microbatches = 2
         org_model = model_fn().cuda()
+        data_iter = iter(data_loader())
+
+        model_copy = copy.deepcopy(org_model)
+        batch = next(data_iter)
+        with torch.no_grad():
+            y = model_copy(batch)
+            org_loss = loss(batch, y)
+
         optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
-        #dataloader=prepare_dataloader(dataset=dataset['train'],batch_size=4)
         schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
         shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                    enable_tensor_parallelism=enable_tensor_parallelism,
                                    pipeline_stage_manager=stage_manager)
         pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
         pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
-        data_it = iter(data_iter())
-        results = execute_pipeline(data_it, pipelined_model, loss, pp_optimizer, schedule=schedule)
+        results = execute_pipeline(data_iter, pipelined_model, loss, pp_optimizer, schedule=schedule)

         if stage_manager.is_last_stage():
-            assert results['loss'] is not None
+            assert results['loss'] == org_loss
+        else:
+            assert results['loss'] is None
         assert results['outputs'] is None
     torch.cuda.empty_cache()
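
The reworked test above turns a smoke check (loss is not None) into an exact-agreement check: it deep-copies the unsharded llama model, runs the copy once on a fixed batch to obtain a reference loss, and asserts that the pipeline's last stage reproduces that value while earlier stages report no loss. Because data_loader now yields a constant tensor, every microbatch is identical, which is what makes exact equality a reasonable expectation. A condensed sketch of that comparison pattern follows; build_model, run_pipeline and loss_fn are hypothetical placeholders, not names from the test:

# Condensed sketch of the reference-vs-pipeline comparison; not the test itself.
# build_model, run_pipeline and loss_fn are hypothetical placeholders.
import copy

import torch

def check_pipeline_matches_reference(build_model, run_pipeline, loss_fn, batch, is_last_stage):
    org_model = build_model().cuda()

    # Deep-copy first: the pipeline wrapper is expected to shard org_model in
    # place, so the reference forward pass must use an untouched copy of the weights.
    model_copy = copy.deepcopy(org_model)
    with torch.no_grad():
        org_loss = loss_fn(batch, model_copy(batch))

    # Drive the pipeline schedule with identical microbatches built from the
    # same constant data.
    results = run_pipeline(org_model, batch)

    if is_last_stage:
        assert results['loss'] == org_loss
    else:
        assert results['loss'] is None
    assert results['outputs'] is None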