From 7172459e74a46c7a1d9bdcdc5022d3757e4b88d6 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 28 Nov 2023 16:54:42 +0800 Subject: [PATCH] [shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088) * [shardformer] implement policy for all GPT-J models and test * [shardformer] support interleaved pipeline parallel for bert finetune * [shardformer] shardformer support falcon (#4883) * [shardformer]: fix interleaved pipeline for bert model (#5048) * [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093) * Add Mistral support for Shardformer (#5103) * [shardformer] add tests to mistral (#5105) --------- Co-authored-by: Pengtai Xu Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: eric8607242 --- .github/workflows/example_check_on_pr.yml | 2 +- .../booster/plugin/hybrid_parallel_plugin.py | 33 +- colossalai/legacy/zero/gemini/__init__.py | 3 + .../zero/gemini/colo_init_context.py | 0 .../pipeline/schedule/interleaved_pp.py | 180 ++-- colossalai/pipeline/stage_manager.py | 55 +- colossalai/shardformer/README.md | 2 + colossalai/shardformer/layer/normalization.py | 4 +- colossalai/shardformer/modeling/falcon.py | 772 ++++++++++++++++ colossalai/shardformer/modeling/gptj.py | 824 ++++++++++++++++++ colossalai/shardformer/modeling/mistral.py | 73 ++ .../shardformer/policies/auto_policy.py | 36 + .../shardformer/policies/base_policy.py | 31 +- colossalai/shardformer/policies/bert.py | 77 +- colossalai/shardformer/policies/bloom.py | 9 + colossalai/shardformer/policies/falcon.py | 392 +++++++++ colossalai/shardformer/policies/gptj.py | 318 +++++++ colossalai/shardformer/policies/mistral.py | 192 ++++ colossalai/shardformer/policies/opt.py | 9 + colossalai/shardformer/policies/whisper.py | 9 + colossalai/utils/memory.py | 77 ++ colossalai/zero/__init__.py | 11 +- colossalai/zero/gemini/__init__.py | 3 - colossalai/zero/gemini/placement_policy.py | 2 +- docs/source/en/features/shardformer.md | 12 + docs/source/zh-Hans/features/shardformer.md | 12 + examples/language/bert/data.py | 12 +- examples/language/bert/finetune.py | 19 +- tests/kit/model_zoo/registry.py | 8 +- tests/kit/model_zoo/transformers/__init__.py | 7 + tests/kit/model_zoo/transformers/falcon.py | 124 +++ tests/kit/model_zoo/transformers/gpt.py | 2 +- tests/kit/model_zoo/transformers/gptj.py | 109 +++ tests/kit/model_zoo/transformers/mistral.py | 78 ++ .../test_plugin/test_gemini_plugin.py | 5 + .../test_schedule/test_interleaved.py | 142 ++- .../test_model/test_shard_falcon.py | 202 +++++ .../test_model/test_shard_gptj.py | 227 +++++ .../test_model/test_shard_mistral.py | 168 ++++ 39 files changed, 4007 insertions(+), 234 deletions(-) rename colossalai/{ => legacy}/zero/gemini/colo_init_context.py (100%) create mode 100644 colossalai/shardformer/modeling/falcon.py create mode 100644 colossalai/shardformer/modeling/gptj.py create mode 100644 colossalai/shardformer/modeling/mistral.py create mode 100644 colossalai/shardformer/policies/falcon.py create mode 100644 colossalai/shardformer/policies/gptj.py create mode 100644 colossalai/shardformer/policies/mistral.py create mode 100644 colossalai/utils/memory.py create mode 100644 tests/kit/model_zoo/transformers/falcon.py create mode 100644 tests/kit/model_zoo/transformers/gptj.py create mode 100644 tests/kit/model_zoo/transformers/mistral.py create mode 100644 tests/test_shardformer/test_model/test_shard_falcon.py create mode 100644 tests/test_shardformer/test_model/test_shard_gptj.py create mode 100644 tests/test_shardformer/test_model/test_shard_mistral.py diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 5934704f4..859b6e4fb 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -79,7 +79,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 20 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} cancel-in-progress: true diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bbc36ceab..59a0deaeb 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -22,7 +22,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils @@ -911,6 +911,8 @@ class HybridParallelPlugin(PipelinePluginBase): communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. + pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. + num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. """ def __init__( @@ -946,6 +948,8 @@ class HybridParallelPlugin(PipelinePluginBase): communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, ) -> None: super().__init__() assert ( @@ -972,17 +976,38 @@ class HybridParallelPlugin(PipelinePluginBase): self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( num_microbatches is not None or microbatch_size is not None ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule( - self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=PP_AXIS, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, ) + + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) + else: + raise NotImplementedError() + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, diff --git a/colossalai/legacy/zero/gemini/__init__.py b/colossalai/legacy/zero/gemini/__init__.py index b272980d3..f30bccea4 100644 --- a/colossalai/legacy/zero/gemini/__init__.py +++ b/colossalai/legacy/zero/gemini/__init__.py @@ -1,3 +1,4 @@ +from .colo_init_context import ColoInitContext, post_process_colo_init_ctx from .ophooks import BaseOpHook, register_ophooks_recursively from .stateful_tensor import StatefulTensor from .stateful_tensor_mgr import StatefulTensorMgr @@ -11,4 +12,6 @@ __all__ = [ "AutoTensorPlacementPolicy", "register_ophooks_recursively", "BaseOpHook", + "ColoInitContext", + "post_process_colo_init_ctx", ] diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/legacy/zero/gemini/colo_init_context.py similarity index 100% rename from colossalai/zero/gemini/colo_init_context.py rename to colossalai/legacy/zero/gemini/colo_init_context.py diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index cbf6dd80f..7c3f15e80 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Iterable, List, Optional, Union import torch import torch.cuda -from torch.nn import Module +from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper @@ -16,18 +16,25 @@ from .base import PipelineSchedule class InterleavedSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: - self.num_model_chunks = num_model_chunks - assert ( - num_microbatches % self.num_model_chunks == 0 - ), "Number of microbatches should be an integer multiple of number of model chunks" + def __init__( + self, + stage_manager: PipelineStageManager, + num_model_chunks: int, + num_microbatch: Optional[int] = None, + microbatch_size: Optional[int] = None, + ) -> None: super().__init__(stage_manager) + assert ( + num_microbatch is not None or microbatch_size is not None + ), "Either num_microbatch or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) - self.num_microbatches = num_microbatches - self.batch: Optional[Any] = None - self.batch_size: Optional[int] = None - self.microbatch_offset: Optional[int] = None - self.microbatch_size: Optional[int] = None + self.num_microbatch = num_microbatch + self.microbatch_size = microbatch_size + self.num_model_chunks = num_model_chunks + + self.batch: Any + self.batch_size: int + self.microbatch_offset: List[int] def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -42,8 +49,22 @@ class InterleavedSchedule(PipelineSchedule): self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" - self.microbatch_size = self.batch_size // self.num_microbatches + if self.num_microbatch is not None: + assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" + self.microbatch_size = self.batch_size // self.num_microbatch + elif self.microbatch_size is not None: + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" + self.num_microbatch = self.batch_size // self.microbatch_size + else: + raise ValueError("Either num_microbatch or microbatch_size should be provided") + + assert ( + self.num_microbatch % self.num_model_chunks == 0 + ), "Number of microbatch should be an integer multiple of number of model chunks" + + assert ( + self.num_microbatch % self.stage_manager.num_stages == 0 + ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" def load_micro_batch(self, model_chunk_id: int) -> Any: """Load a micro batch from the current batch. @@ -58,7 +79,7 @@ class InterleavedSchedule(PipelineSchedule): self.microbatch_offset[model_chunk_id] += self.microbatch_size return tree_map(partial(to_device, device=get_current_device()), micro_batch) - def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: + def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. Args: @@ -70,36 +91,10 @@ class InterleavedSchedule(PipelineSchedule): """ microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages - if not forward: + if not is_forward: model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id - def is_first_stage(self, model_chunk_id: int) -> bool: - """Is the current virtual stage the first stage - - Args: - model_chunk_id (int): The current model chunk idx. - - Returns: - bool: Whether the current virtual stage is the first stage. - """ - if self.stage_manager.is_first_stage() and model_chunk_id == 0: - return True - return False - - def is_last_stage(self, model_chunk_id: int) -> bool: - """Is the current virtual stage the last stage - - Args: - model_chunk_id (int): The current model chunk idx. - - Returns: - bool: Whether the current virtual stage is the last stage. - """ - if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1: - return True - return False - def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For interleaved 1F1B. @@ -111,7 +106,7 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input tensor or input tensor list. """ - if self.is_first_stage(model_chunk_id): + if self.stage_manager.is_first_stage(model_chunk_id): input_tensor = None else: input_tensor = self.comm.recv_forward(prev_rank) @@ -129,7 +124,7 @@ class InterleavedSchedule(PipelineSchedule): Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.is_last_stage(model_chunk_id): + if self.stage_manager.is_last_stage(model_chunk_id): output_tensor_grad = None else: output_tensor_grad = self.comm.recv_backward(next_rank) @@ -145,7 +140,7 @@ class InterleavedSchedule(PipelineSchedule): output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.is_last_stage(model_chunk_id): + if not self.stage_manager.is_last_stage(model_chunk_id): self.comm.send_forward(output_object, next_rank) def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: @@ -157,12 +152,12 @@ class InterleavedSchedule(PipelineSchedule): input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.is_first_stage(model_chunk_id): + if not self.stage_manager.is_first_stage(model_chunk_id): self.comm.send_backward(input_object, prev_rank) def forward_step( self, - model_chunk: Module, + model_chunk: Union[ModuleList, Module], model_chunk_id: int, input_obj: Optional[dict], criterion: Callable, @@ -171,7 +166,7 @@ class InterleavedSchedule(PipelineSchedule): ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: - model (Module): Model Chunk to be run + model (ModuleList or Module): Model Chunk to be run input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None. criterion (Callable): Criterion to calculate loss. accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. @@ -184,10 +179,19 @@ class InterleavedSchedule(PipelineSchedule): # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict - output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) - if self.is_last_stage(model_chunk_id): - loss = criterion(output_obj, micro_batch) / self.num_microbatches + self.stage_manager.model_chunk_id = model_chunk_id + if isinstance(model_chunk, ModuleList): + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + else: + # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + self.stage_manager.model_chunk_id = None + + if self.stage_manager.is_last_stage(model_chunk_id): + loss = criterion(output_obj, micro_batch) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: @@ -243,17 +247,17 @@ class InterleavedSchedule(PipelineSchedule): def forward_backward_step( self, - model_chunk: Module, + model_chunk: Union[ModuleList, Module], data_iter: Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False, ) -> dict: - """Runs interleaved 1F1B schedule, with communication between pipeline stages. + """Runs interleaved schedule, with communication between pipeline stages. Args: - model_chunk (List[Module]): Model Chunk to be trained. + model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification data_iter (Iterable): Data iterator. criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. @@ -263,49 +267,46 @@ class InterleavedSchedule(PipelineSchedule): Returns: dict: A dict with keys: 'loss' and 'outputs'. """ + # TODO: handle arbitrary batch size when forward_only == True forward_only = not torch.is_grad_enabled() if optimizer is None: assert forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) - num_model_chunks = len(model_chunk) - # num_warmup_microbatches is the step when not all the processes are working - num_microbatches = self.num_microbatches * num_model_chunks + num_microbatch = self.num_microbatch * self.num_model_chunks if forward_only: - num_warmup_microbatches = num_microbatches + num_warmup_microbatch = num_microbatch else: - num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 + num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages + num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch) - num_microbatches_remaining = num_microbatches - num_warmup_microbatches + num_microbatch_remaining = num_microbatch - num_warmup_microbatch # Input, output tensors only need to be saved when doing backward passes input_objs = None output_objs = None if not forward_only: - input_objs = [[] for _ in range(num_model_chunks)] - output_objs = [[] for _ in range(num_model_chunks)] + input_objs = [[] for _ in range(self.num_model_chunks)] + output_objs = [[] for _ in range(self.num_model_chunks)] - outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None - if return_loss and self.stage_manager.is_last_stage(): + if return_loss and self.stage_manager.is_last_stage(-1): accum_loss = torch.zeros(1, device=get_current_device()) else: accum_loss = None # for ranks except the first one, get into recv state - # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining) input_obj = self.recv_forward(0) - input_objs[0].append(input_obj) - # Run warmup forward passes. - for i in range(num_warmup_microbatches): - model_chunk_id = self.get_model_chunk_id(i, forward=True) - # recv first on first rank to avoid sending or recving at the same time - if self.stage_manager.is_first_stage(): + # Run warmup forward passes. + for i in range(num_warmup_microbatch): + model_chunk_id = self.get_model_chunk_id(i, is_forward=True) + # recv first on first rank to avoid sending or receiving at the same time + if self.stage_manager.is_first_stage(-1): input_obj = self.recv_forward(model_chunk_id) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) self.send_forward(model_chunk_id, output_obj) @@ -315,21 +316,20 @@ class InterleavedSchedule(PipelineSchedule): else: output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if not forward_only: + input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) self.send_forward(model_chunk_id, output_obj) - if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches: - break - else: - model_chunk_id = self.get_model_chunk_id(i + 1, forward=True) - input_obj = self.recv_forward(model_chunk_id) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) + if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch: + break + + model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True) + input_obj = self.recv_forward(model_chunk_id) # Run 1F1B in steady state. - for i in range(num_microbatches_remaining): - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) - last_iteration = i == (num_microbatches_remaining - 1) + for i in range(num_microbatch_remaining): + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch, is_forward=True) + last_iteration = i == num_microbatch_remaining - 1 output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if forward_only: @@ -344,7 +344,7 @@ class InterleavedSchedule(PipelineSchedule): input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) - model_chunk_id = self.get_model_chunk_id(i, forward=False) + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) output_obj_grad = self.recv_backward(model_chunk_id) # Pop output_obj and output_obj from the start of the list for @@ -358,23 +358,25 @@ class InterleavedSchedule(PipelineSchedule): if last_iteration: input_obj = None else: - model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True) + model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) input_obj = self.recv_forward(model_chunk_id) - model_chunk_id = self.get_model_chunk_id(i, forward=False) + + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) self.send_backward(model_chunk_id, input_obj_grad) # Run cooldown backward passes. if not forward_only: - for i in range(num_microbatches_remaining, num_microbatches): - model_chunk_id = self.get_model_chunk_id(i, forward=False) - # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}") + for i in range(num_microbatch_remaining, num_microbatch): + model_chunk_id = self.get_model_chunk_id(i, is_forward=False) input_obj = input_objs[model_chunk_id].pop(0) output_obj = output_objs[model_chunk_id].pop(0) - output_obj_grad = self.recv_backward(model_chunk_id) input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(model_chunk_id, input_obj_grad) + if not forward_only: + assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) + if outputs is not None: outputs = merge_batch(outputs) return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index d988015ce..d7853938a 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -19,7 +19,15 @@ class PipelineStageManager: stage (int): The current stage. """ - def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None: + def __init__( + self, + pg_mesh: ProcessGroupMesh, + pipeline_axis: int, + enable_interleave: bool = False, + num_model_chunks: int = 1, + ) -> None: + assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" + self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None @@ -43,29 +51,62 @@ class PipelineStageManager: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - if is_virtual: + self.is_interleave = enable_interleave + if enable_interleave: + # use circle p2p communication # add the process group of the first rank and the last rank - # only used in interleaved pipeline for now group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]]) if self.stage in [stages[0], stages[-1]]: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - def is_first_stage(self) -> bool: + # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers + self.num_model_chunks: int = num_model_chunks + + # for shardformer, hold stage indices of model + self.stage_indices: List[Tuple[int, int]] + # for shardformer, hold model chunk id + self.model_chunk_id: Optional[int] = None + + def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the first stage. + NOTE: + 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device. + 2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device() + Returns: bool: Whether the current stage is the first stage. """ - return self.stage == 0 + if self.is_interleave and model_chunk_id is None: + model_chunk_id = self.model_chunk_id + assert self.is_interleave ^ ( + model_chunk_id is None + ), "model_chunk_id must be specified when using interleaved pipeline" + if not self.is_interleave or model_chunk_id < 0: + return self.stage == 0 + else: + return self.stage == 0 and model_chunk_id == 0 - def is_last_stage(self) -> bool: + def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: """Is the current stage the last stage. + NOTE: + 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device. + 2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device() + Returns: bool: Whether the current stage is the last stage. """ - return self.stage == self.num_stages - 1 + if self.is_interleave and model_chunk_id is None: + model_chunk_id = self.model_chunk_id + assert self.is_interleave ^ ( + model_chunk_id is None + ), "model_chunk_id must be specified when using interleaved pipeline" + if not self.is_interleave or model_chunk_id < 0: + return self.stage == self.num_stages - 1 + else: + return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index cf06eecd3..e475a607f 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -127,6 +127,7 @@ We will follow this roadmap to develop Shardformer: | whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| falcon | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] | | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | @@ -136,6 +137,7 @@ We will follow this roadmap to develop Shardformer: | swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | | qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| mistral | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | ## 💡 API Design diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 8387bb5e3..4aa281290 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm): ) LazyInitContext.materialize(module) - # to check if it is huggingface LlamaRMSNorm - if module.__class__.__name__ == "LlamaRMSNorm": + # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm + if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]: normalized_shape = module.weight.shape[0] eps = module.variance_epsilon elementwise_affine = True diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py new file mode 100644 index 000000000..4e271dfe0 --- /dev/null +++ b/colossalai/shardformer/modeling/falcon.py @@ -0,0 +1,772 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.models.falcon.modeling_falcon import ( + FalconForCausalLM, + FalconForQuestionAnswering, + FalconForSequenceClassification, + FalconForTokenClassification, + FalconModel, + build_alibi_tensor, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.shard import ShardConfig + + +def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: + def build_falcon_alibi_tensor( + self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype + ) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size(process_group) + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) + offset = dist.get_rank(process_group) * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset : num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + return build_falcon_alibi_tensor + + +def get_tp_falcon_decoder_layer_forward(): + from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, dropout_add + + def forward( + self: FalconDecoderLayer, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) + + outputs = attn_outputs[1:] + + # MLP. + mlp_output = self.mlp(mlp_layernorm_out) + + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output = mlp_output + attention_output + + output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + return forward + + +def get_falcon_flash_attention_forward(): + try: + from xformers.ops import memory_efficient_attention as me_attention + except: + raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") + from transformers.models.falcon.modeling_falcon import FalconAttention + + def forward( + self: FalconAttention, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_length, _ = key_layer.shape + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous() + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous() + + if alibi is not None: + attention_mask_float = ( + attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta + ) + + batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1] + tgt_len = key_layer_.size()[1] + attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous() + context_layer = me_attention( + query_layer_, + key_layer_, + value_layer_, + attn_bias=attention_mask_float, + scale=self.inv_norm_factor, + p=self.attention_dropout.p, + ) + batch_size, seq_length, _, _ = context_layer.shape + context_layer = context_layer.reshape(batch_size, seq_length, -1) + + output_tensor = self.dense(context_layer) + + return output_tensor, present + + return forward + + +class FalconPipelineForwards: + """ + This class serves as a micro library for falcon pipeline forwards. + """ + + @staticmethod + def falcon_model_forward( + self: FalconModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + logger.warning_once("past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + else: + past_key_values = self._convert_to_rw_cache(past_key_values) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # case: First stage of training + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = inputs_embeds + + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + if self.use_alibi: + alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + else: + alibi = None + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + start_idx, end_idx = stage_index[0], stage_index[1] + for i, (block, layer_past) in enumerate( + zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx + ): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + if stage_manager.is_last_stage(): + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if presents is not None: + presents = self._convert_cache_to_standard_format(presents, batch_size) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for imediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def falcon_for_causal_lm_forward( + self: FalconForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def falcon_for_sequence_classification_forward( + self: FalconForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + if stage_manager.is_last_stage(): + batch_size = hidden_states.shape[0] + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def falcon_for_token_classification_forward( + self: FalconForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + transformer_outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + else: + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} + + @staticmethod + def falcon_for_question_answering_forward( + self: FalconForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + outputs = FalconPipelineForwards.falcon_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + if stage_manager.is_last_stage(): + sequence_output = outputs[0] + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py new file mode 100644 index 000000000..ad51bf2c7 --- /dev/null +++ b/colossalai/shardformer/modeling/gptj.py @@ -0,0 +1,824 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, +) +from transformers.models.gptj.modeling_gptj import ( + GPTJForCausalLM, + GPTJForQuestionAnswering, + GPTJForSequenceClassification, + GPTJModel, + apply_rotary_pos_emb, + get_embed_positions, +) +from transformers.utils import is_torch_fx_proxy, logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + + +class GPTJPipelineForwards: + """ + This class serves as a micro library for forward function substitution of GPTJ models + under pipeline setting. + """ + + @staticmethod + def gptj_model_forward( + self: GPTJModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPast]: + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJModel.forward. + # Please refer to original code of transformers for more details. + # GPTJ has no cross attention in comparison to GPT2 + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") + past_key_values = None + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + input_shape = input_ids.size() + input_ids = input_ids.view(-1, seq_length) + + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, seq_length) + else: + if hidden_states is None: + raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") + input_shape = hidden_states.size()[:-1] + batch_size, seq_length = input_shape[0], input_shape[1] + device = hidden_states.device + + # Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + # position id to be asssigned not just for the first stage for attn input + if position_ids is not None: + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + if stage_manager.is_first_stage(): + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + if shard_config.enable_sequence_parallelism: + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + # Going through held blocks. + start_idx, end_idx = stage_index[0], stage_index[1] + for i in range(start_idx, end_idx): + block = self.h[i] + torch.cuda.set_device(hidden_states.device) + + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=None, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + if shard_config.enable_sequence_parallelism: + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + if stage_manager.is_last_stage(): + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if stage_manager.is_last_stage(): + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + else: + # always return dict for intermediate stage + return {"hidden_states": hidden_states} + + @staticmethod + def gptj_causallm_model_forward( + self: GPTJForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForCausalLM.forward. + # Please refer to original code of transformers for more details. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = GPTJPipelineForwards.gptj_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPTJModel + if not stage_manager.is_last_stage(): + return {"hidden_states": transformer_outputs["hidden_states"]} + + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def gptj_for_sequence_classification_forward( + self: GPTJForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification.forward. + # Please refer to original code of transformers for more details. + """ + logger = logging.get_logger(__name__) + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = hidden_states.shape[:2] + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = GPTJPipelineForwards.gptj_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPTJModel + if not stage_manager.is_last_stage(): + return {"hidden_states": transformer_outputs["hidden_states"]} + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(pooled_logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def gptj_for_question_answering_forward( + self: GPTJForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + # This function is modified on the basis of transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering.forward. + # Please refer to original code of transformers for more details. + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = GPTJPipelineForwards.gptj_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) + + # If not at the last stage, return hidden_states as in GPTJModel + if not stage_manager.is_last_stage(): + return {"hidden_states": outputs["hidden_states"]} + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def get_gptj_flash_attention_forward(): + from transformers.models.gptj.modeling_gptj import GPTJAttention + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def split_heads(tensor, num_attention_heads, attn_head_size, rotary): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(new_shape) + if rotary or len(tensor.shape) in [4, 5]: + return tensor + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def forward( + self: GPTJAttention, + hidden_states: torch.FloatTensor, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = split_heads(query, self.num_attention_heads, self.head_dim, True) + key = split_heads(key, self.num_attention_heads, self.head_dim, True) + value = split_heads(value, self.num_attention_heads, self.head_dim, False) + + if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): + # The logic to conditionally copy to GPU could not be traced, so we do this + # every time in the torch.fx case + embed_positions = get_embed_positions(self.embed_positions, position_ids) + else: + embed_positions = self._get_embed_positions(position_ids) + + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + # key = key.permute(0, 2, 1, 3) + # query = query.permute(0, 2, 1, 3) + key = key.to(dtype=value.dtype) # fp16 compatability + query = query.to(dtype=value.dtype) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + present = (key, value) + else: + present = None + + # use AttnMaskType and ColoAttention + attn_mask_type = AttnMaskType.causal + flash_attention_mask = None + if attention_mask != None: + if attn_mask_type == AttnMaskType.causal: + attn_mask_type == AttnMaskType.paddedcausal + else: + attn_mask_type = AttnMaskType.padding + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + + # use coloattention + scale = value.size(-1) ** -0.5 + + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale + ) + + attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + outputs = (attn_output, present, None) + + return outputs # a, present, (attentions) + + return forward + + +def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]).long() + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py new file mode 100644 index 000000000..1ddb26c25 --- /dev/null +++ b/colossalai/shardformer/modeling/mistral.py @@ -0,0 +1,73 @@ +from typing import Optional, Tuple + +import torch + + +def get_mistral_flash_attention_forward(): + from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv + + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + def forward( + self: MistralAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = ( + self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) + key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) + value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) + + flash_attention_mask = None + attn_mask_type = AttnMaskType.causal + if attention_mask != None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attn_mask_type = AttnMaskType.paddedcausal + + attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index b01896e48..0991ace2c 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -85,6 +85,17 @@ _POLICY_LIST = { "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" ), + # GPTJ + "transformers.models.gptj.modeling_gptj.GPTJModel": PolicyLocation(file_name="gptj", class_name="GPTJModelPolicy"), + "transformers.models.gptj.modeling_gptj.GPTJForCausalLM": PolicyLocation( + file_name="gptj", class_name="GPTJForCausalLMPolicy" + ), + "transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering": PolicyLocation( + file_name="gptj", class_name="GPTJForQuestionAnsweringPolicy" + ), + "transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification": PolicyLocation( + file_name="gptj", class_name="GPTJForSequenceClassificationPolicy" + ), # ViT "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( @@ -146,6 +157,31 @@ _POLICY_LIST = { "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), + # Falcon + "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation( + file_name="falcon", class_name="FalconModelPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForCausalLM": PolicyLocation( + file_name="falcon", class_name="FalconForCausalLMPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForSequenceClassification": PolicyLocation( + file_name="falcon", class_name="FalconForSequenceClassificationPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForTokenClassification": PolicyLocation( + file_name="falcon", class_name="FalconForTokenClassificationPolicy" + ), + "transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation( + file_name="falcon", class_name="FalconForQuestionAnsweringPolicy" + ), + "transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation( + file_name="mistral", class_name="MistralModelPolicy" + ), + "transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation( + file_name="mistral", class_name="MistralForCausalLMPolicy" + ), + "transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation( + file_name="mistral", class_name="MistralForSequenceClassificationPolicy" + ), } diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 003c9322a..1d2b7a570 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch.nn as nn @@ -214,13 +214,32 @@ class Policy(ABC): return layers_per_stage @staticmethod - def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: + def get_stage_index( + layers_per_stage: List[int], + stage: int, + num_model_chunks: int = 1, + num_stages: int = 0, + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: """ - get the start index and end index of layers for each stage. + Get the start index and end index of layers for each stage. + + Args: + layers_per_stage (List[int]): number of layers for each stage + stage (int): the stage index + num_stages (int): number of stages + num_model_chunks (int): number of model chunks + + Returns: + - Tuple[int, int]: the start index and end index of this stage + - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk + """ num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - start_idx = num_layers_per_stage_accumulated[stage] - end_idx = num_layers_per_stage_accumulated[stage + 1] + stage_indices = [] + for model_chunk in range(num_model_chunks): + start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] + end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] + stage_indices.append([start_idx, end_idx]) - return [start_idx, end_idx] + return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index c31327a6c..78363bf5e 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -21,7 +21,7 @@ __all__ = [ "BertPolicy", "BertModelPolicy", "BertForPreTrainingPolicy", - "BertLMdHeadModelPolicy", + "BertLMHeadModelPolicy", "BertForMaskedLMPolicy", "BertForNextSentencePredictionPolicy", "BertForSequenceClassificationPolicy", @@ -249,15 +249,34 @@ class BertPolicy(Policy): return self.model def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "BertModel": - module = self.model - else: - module = self.model.bert + """ + If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy. + """ + if self.pipeline_stage_manager is None: + return + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "BertModel": + module = self.model + else: + module = self.model.bert + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + + else: layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { @@ -265,11 +284,8 @@ class BertPolicy(Policy): new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config ) } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) - return + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -282,13 +298,32 @@ class BertPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.pooler) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(-1): + held_layers.append(module.embeddings) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(-1): + held_layers.append(module.pooler) + + else: + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.pooler) return held_layers @@ -464,7 +499,7 @@ class BertForSequenceClassificationPolicy(BertPolicy): """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index c8687a1ac..eddfafdcb 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe class BloomPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + from packaging.version import Version + + assert Version(transformers.__version__) <= Version( + "4.33.0" + ), "The Bloom model should run on a transformers version not greater than 4.33.0." + def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py new file mode 100644 index 000000000..f2eeb9d69 --- /dev/null +++ b/colossalai/shardformer/policies/falcon.py @@ -0,0 +1,392 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn +from torch.nn import Module + +import colossalai.shardformer.layer as col_nn + +from ..modeling.falcon import ( + FalconPipelineForwards, + build_falcon_alibi_tensor_fn, + get_falcon_flash_attention_forward, + get_tp_falcon_decoder_layer_forward, +) +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["FalconPolicy"] + + +class FalconPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + from packaging.version import Version + + assert Version(transformers.__version__) <= Version( + "4.33.0" + ), "The Falcon model should run on a transformers version not greater than 4.33.0." + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel + + if not self.model.config.new_decoder_architecture and self.model.config.multi_query: + warnings.warn( + "Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." + ) + self.shard_config.enable_tensor_parallelism = False + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + + policy = {} + if self.shard_config.enable_tensor_parallelism: + attn_attribute_replacement = { + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.num_kv_heads": self.model.config.num_kv_heads // self.shard_config.tensor_parallel_size, + } + + policy[FalconDecoderLayer] = ModulePolicyDescription( + attribute_replacement=attn_attribute_replacement, + method_replacement={"forward": get_tp_falcon_decoder_layer_forward()}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row), + ], + ) + + policy[FalconModel] = ModulePolicyDescription( + attribute_replacement={ + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + method_replacement={ + "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # handle falcon model + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=FalconModel, + ) + + # handle falcon decoder layer + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="ln_mlp", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="input_layernorm", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=FalconDecoderLayer, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={"forward": get_falcon_flash_attention_forward()}, + policy=policy, + target_key=FalconAttention, + ) + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "FalconModel": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + if self.model.__class__.__name__ == "FalconModel": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.word_embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + + return held_layers + + +class FalconModelPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + from transformers.models.falcon.modeling_falcon import FalconModel + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconModel, new_forward=FalconPipelineForwards.falcon_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[Module]: + """ + get pipeline layers for current stage + """ + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """no shared params in falcon model""" + return [] + + +class FalconForCausalLMPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForCausalLM + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=FalconForCausalLM, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForCausalLM, + new_forward=FalconPipelineForwards.falcon_for_causal_lm_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + falcon_model = self.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(falcon_model.transformer.word_embeddings.weight) == id(falcon_model.lm_head.weight): + # tie weights + return [ + { + 0: falcon_model.transformer.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: falcon_model.lm_head.weight, + } + ] + return [] + + +class FalconForSequenceClassificationPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForSequenceClassification + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=FalconForSequenceClassification, + ) + + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForSequenceClassification, + new_forward=FalconPipelineForwards.falcon_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in falcon for sequence classification model""" + return [] + + +class FalconForTokenClassificationPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForTokenClassification + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=FalconForTokenClassification, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForTokenClassification, + new_forward=FalconPipelineForwards.falcon_for_token_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.dropout) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in falcon for token classification model""" + return [] + + +class FalconForQuestionAnsweringPolicy(FalconPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.falcon.modeling_falcon import FalconForQuestionAnswering + + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=FalconForQuestionAnswering, + ) + if self.pipeline_stage_manager: + self.set_pipeline_forward( + model_cls=FalconForQuestionAnswering, + new_forward=FalconPipelineForwards.falcon_for_question_answering_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in falcon for question answering model""" + return [] diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py new file mode 100644 index 000000000..9feb826c4 --- /dev/null +++ b/colossalai/shardformer/policies/gptj.py @@ -0,0 +1,318 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List + +from torch import Tensor, nn + +import colossalai.shardformer.layer as col_nn + +from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + "GPTJPolicy", + "GPTJModelPolicy", + "GPTJForCausalLMPolicy", + "GPTJForSequenceClassificationPolicy", + "GPTJForQuestionAnsweringPolicy", + "FlaxGPTJPolicy", + "FlaxGPTJForCausalLMPolicy", +] + + +class GPTJPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel + + policy = {} + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + overlap = self.shard_config.enable_sequence_overlap + if self.shard_config.enable_tensor_parallelism: + policy[GPTJModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) + + policy[GPTJBlock] = ModulePolicyDescription( + attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.k_proj", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.q_proj", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.v_proj", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.out_proj", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_in", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="mlp.fc_out", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + policy=policy, + target_key=GPTJModel, + ) + + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=GPTJBlock, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_gptj_flash_attention_forward(), + }, + policy=policy, + target_key=GPTJAttention, + ) + + return policy + + def postprocess(self): + return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "GPTJModel": + module = self.model + else: + module = self.model.transformer + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "GPTJModel": + module = self.model + else: + module = self.model.transformer + + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + + +# GPTJModel +class GPTJModelPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJModel + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2Model.""" + return [] + + +# GPTJForCausalLM +class GPTJForCausalLMPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPTJForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) + } + policy.update(addon_module) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """The weights of wte and lm_head are shared.""" + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager is not None: + if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): + first_stage, last_stage = 0, stage_manager.num_stages - 1 + return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [] + + +# GPTJForSequenceClassification +class GPTJForSequenceClassificationPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJForSequenceClassification + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJForSequenceClassification, + new_forward=GPTJPipelineForwards.gptj_for_sequence_classification_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.score) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPTJForSequenceClassification.""" + return [] + + +# GPTJForQuestionAnswering +class GPTJForQuestionAnsweringPolicy(GPTJPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gptj.modeling_gptj import GPTJForQuestionAnswering + + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward( + model_cls=GPTJForQuestionAnswering, + new_forward=GPTJPipelineForwards.gptj_for_question_answering_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.qa_outputs) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in GPT2ForQuestionAnswering.""" + return [] diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py new file mode 100644 index 000000000..c16aa6dea --- /dev/null +++ b/colossalai/shardformer/policies/mistral.py @@ -0,0 +1,192 @@ +import warnings +from typing import Dict, Union + +import torch.nn as nn + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D + +from ..modeling.mistral import get_mistral_flash_attention_forward +from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"] + + +class MistralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel + + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + warnings.warn( + "Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + + policy[MistralDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ), + ], + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=MistralModel, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=MistralDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=MistralModel, + ) + + if self.shard_config.enable_flash_attention: + self.append_or_create_method_replacement( + description={ + "forward": get_mistral_flash_attention_forward(), + }, + policy=policy, + target_key=MistralAttention, + ) + + return policy + + def postprocess(self): + return self.model + + +class MistralModelPolicy(MistralPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + if self.pipeline_stage_manager: + warnings.warn("Mistral dosen't support pipeline parallelism now.") + + return super().module_policy() + + +class MistralForCausalLMPolicy(MistralPolicy): + def module_policy(self): + from transformers import MistralForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MistralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + + if self.pipeline_stage_manager: + warnings.warn("Mistral dosen't support pipeline parallelism now.") + + policy.update(new_item) + + return policy + + +class MistralForSequenceClassificationPolicy(MistralPolicy): + def module_policy(self): + from transformers import MistralForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + MistralForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) + } + + if self.pipeline_stage_manager: + warnings.warn("Mistral dosen't support pipeline parallelism now.") + + policy.update(new_item) + return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 0b5c767e1..e2f3a829c 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -22,6 +22,15 @@ __all__ = [ class OPTPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + from packaging.version import Version + + assert Version(transformers.__version__) <= Version( + "4.33.0" + ), "The OPT model should run on a transformers version not greater than 4.33.0." + def config_sanity_check(self): pass diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 3ce198e9e..6dae99e8c 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -26,6 +26,15 @@ __all__ = [ class WhisperPolicy(Policy): + def __init__(self) -> None: + super().__init__() + import transformers + from packaging.version import Version + + assert Version(transformers.__version__) <= Version( + "4.33.0" + ), "The Whisper model should run on a transformers version not greater than 4.33.0." + def config_sanity_check(self): pass diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py new file mode 100644 index 000000000..efe4b4f28 --- /dev/null +++ b/colossalai/utils/memory.py @@ -0,0 +1,77 @@ +from collections import namedtuple + +import psutil +import torch +import torch.distributed as dist + +from colossalai.utils import get_current_device + +_GLOBAL_CUDA_MEM_FRACTION = 1.0 +_GLOBAL_CPU_MEM_CAPACITY = -1 + + +# copy from PatrickStar +def _get_cpu_memory_info(): + ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"]) + try: + # psutil reads the memory info from /proc/memory_info, + # which results in returning the host memory instead of + # that of container. + # Here we try to read the container memory with method in: + # https://stackoverflow.com/a/46213331/5163915 + mems = {} + with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f: + for line in f: + fields = line.split() + mems[fields[0]] = int(fields[1]) * 1024 + total = mems[b"MemTotal:"] + free = mems[b"MemFree:"] + cached = mems[b"Cached:"] + buffers = mems[b"Buffers:"] + used = total - free - cached - buffers + if used < 0: + used = total - free + mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used) + except FileNotFoundError: + mems = psutil.virtual_memory() + mem_info = ps_mem_info( + total=mems.total, + free=mems.free, + cached=mems.cached, + buffers=mems.buffers, + used=mems.used, + ) + return mem_info + + +def colo_device_memory_capacity(device: torch.device) -> int: + """ + Get the capacity of the memory of the device + + Args: + device (torch.device): a device + + Returns: + int: size in byte + """ + # TODO: add NPU support + assert isinstance(device, torch.device) + if device.type == "cpu": + # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. + return colo_get_cpu_memory_capacity() // dist.get_world_size() + if device.type == "cuda": + return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + + +def colo_get_cpu_memory_capacity() -> int: + """ + Get the cpu memory capacity. We may not use all of it. + Returns: + int: _description_ + """ + global _GLOBAL_CPU_MEM_CAPACITY + if _GLOBAL_CPU_MEM_CAPACITY == -1: + mem_info = _get_cpu_memory_info() + return mem_info.total + else: + return _GLOBAL_CPU_MEM_CAPACITY diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 90d0f8de1..5ad59e832 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -1,11 +1,4 @@ -from .gemini import ( - ColoInitContext, - GeminiAdamOptimizer, - GeminiDDP, - GeminiOptimizer, - get_static_torch_model, - post_process_colo_init_ctx, -) +from .gemini import GeminiAdamOptimizer, GeminiDDP, GeminiOptimizer, get_static_torch_model from .low_level import LowLevelZeroOptimizer from .wrapper import zero_model_wrapper, zero_optim_wrapper @@ -16,7 +9,5 @@ __all__ = [ "zero_model_wrapper", "zero_optim_wrapper", "LowLevelZeroOptimizer", - "ColoInitContext", - "post_process_colo_init_ctx", "get_static_torch_model", ] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py index 358d5c7fd..6d93ca8ed 100644 --- a/colossalai/zero/gemini/__init__.py +++ b/colossalai/zero/gemini/__init__.py @@ -1,5 +1,4 @@ from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration -from .colo_init_context import ColoInitContext, post_process_colo_init_ctx from .gemini_ddp import GeminiDDP from .gemini_mgr import GeminiManager from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer @@ -15,6 +14,4 @@ __all__ = [ "get_static_torch_model", "GeminiAdamOptimizer", "GeminiOptimizer", - "ColoInitContext", - "post_process_colo_init_ctx", ] diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 8a74eb587..c410ad379 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type import torch -from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index bf7b2b3e4..1e633ebc0 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -178,6 +178,18 @@ Model/Feature Compatibility Matrix: ❌ ❌ + + Falcon + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ❌ + ✔️ + ❌ + ❌ + diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 99752a1ce..972c48b0c 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -174,6 +174,18 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github. ❌ ❌ + + Falcon + ✔️ + ✔️ + ✔️ + ✔️ + ✔️ + ❌ + ✔️ + ❌ + ❌ + diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py index ef51f938d..31c6937ee 100644 --- a/examples/language/bert/data.py +++ b/examples/language/bert/data.py @@ -88,20 +88,24 @@ class GLUEDataBuilder: ) def val_dataloader(self): + # TODO: drop_last is set to True for now to avoid error when using PP + # as the last batch may not be divisible by the number of microbatches if len(self.eval_splits) == 1: - return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + return self.plugin.prepare_dataloader( + self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True + ) elif len(self.eval_splits) > 1: return [ - self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True) for x in self.eval_splits ] def test_dataloader(self): if len(self.eval_splits) == 1: - return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True) elif len(self.eval_splits) > 1: return [ - self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True) for x in self.eval_splits ] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 563cfa58d..b349d7edf 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -57,7 +57,9 @@ def evaluate_model( def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( + None if not booster.plugin.stage_manager.is_interleave else -1 + ) accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: @@ -69,9 +71,10 @@ def evaluate_model( current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) current_rank = dist.get_rank() batch = iter([batch]) + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True) - if is_pp_last_stage: + if is_pp_last_device: logits = outputs["outputs"]["logits"] val_loss = outputs["loss"] accum_loss.add_(val_loss) @@ -133,8 +136,10 @@ def train_epoch( coordinator: DistCoordinator, ): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( + None if not booster.plugin.stage_manager.is_interleave else -1 + ) + print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) total_step = len(train_dataloader) model.train() @@ -148,7 +153,7 @@ def train_epoch( train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True ) # Backward and optimize - if is_pp_last_stage: + if is_pp_last_device: loss = outputs["loss"] pbar.set_postfix({"loss": loss.item()}) else: @@ -222,7 +227,9 @@ def main(): tp_size=1, pp_size=2, num_microbatches=None, - microbatch_size=1, + pp_style="interleaved", + num_model_chunks=2, + microbatch_size=16, enable_all_optimization=True, zero_stage=1, precision="fp16", diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index b90972291..bb522778b 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -71,8 +71,12 @@ class ModelZooRegistry(dict): new_dict = dict() for k, v in self.items(): - if keyword in k: - new_dict[k] = v + if keyword == "transformers_gpt": + if keyword in k and not "gptj" in k: # ensure GPT2 does not retrieve GPTJ models + new_dict[k] = v + else: + if keyword in k: + new_dict[k] = v assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index 2a492361b..be6d92f01 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -3,10 +3,17 @@ from .bert import * from .blip2 import * from .bloom import * from .chatglm2 import * +from .falcon import * from .gpt import * +from .gptj import * from .llama import * from .opt import * from .sam import * from .t5 import * from .vit import * from .whisper import * + +try: + from .mistral import * +except ImportError: + print("This version of transformers doesn't support mistral.") diff --git a/tests/kit/model_zoo/transformers/falcon.py b/tests/kit/model_zoo/transformers/falcon.py new file mode 100644 index 000000000..d28d44634 --- /dev/null +++ b/tests/kit/model_zoo/transformers/falcon.py @@ -0,0 +1,124 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register Falcon +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([0], dtype=torch.int64) + return data + + +def data_gen_for_question_answering(): + input_ids = torch.tensor( + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], + dtype=torch.int64, + ) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + start_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([10], dtype=torch.int64) + return dict( + input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions + ) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_falcon_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn_for_causal_lm = lambda x: x.loss +loss_fn_for_classification = lambda x: x.loss +loss_fn_for_question_answering = lambda x: x.loss + +config = transformers.FalconConfig( + num_hidden_layers=2, + num_attention_heads=4, + vocab_size=250880, + hidden_dropout=0, + attention_dropout=0, + hidden_size=64, + multi_query=False, + new_decoder_architecture=True, + pad_token_id=-1, +) + +model_zoo.register( + name="transformers_falcon", + model_fn=lambda: transformers.FalconModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_falcon_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_falcon_for_causal_lm", + model_fn=lambda: transformers.FalconForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_falcon_for_sequence_classification", + model_fn=lambda: transformers.FalconForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_falcon_for_token_classification", + model_fn=lambda: transformers.FalconForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_falcon_for_question_answering", + model_fn=lambda: transformers.FalconForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 5e98c02fd..24f9627c2 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -14,7 +14,7 @@ def data_gen(): # Generated from following code snippet # # from transformers import GPT2Tokenizer - # input = 'Hello, my dog is cute' + # input = 'Hello, my dog is cute is cute' (last two words repeated to satisfy length requirement) # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py new file mode 100644 index 000000000..263978512 --- /dev/null +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -0,0 +1,109 @@ +import copy + +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence GPT +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoTokenizer + # input = 'Hello, my dog is cute is cute' (last two words repeated to satisfy length requirement) + # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_question_answering(): + # question answering data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + start_positions = torch.tensor([0], dtype=torch.int64) + data["start_positions"] = start_positions + end_positions = torch.tensor([1], dtype=torch.int64) + data["end_positions"] = end_positions + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_gptj_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss + +config = transformers.GPTJConfig( + n_layer=2, + n_head=16, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + hidden_dropout=0, + problem_type="single_label_classification", + pad_token_id=50256, +) + +config_for_token_classification = copy.deepcopy(config) +config_for_token_classification.num_labels = 2 + +# register the following models +model_zoo.register( + name="transformers_gptj", + model_fn=lambda: transformers.GPTJModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gptj_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gptj_lm", + model_fn=lambda: transformers.GPTJForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gptj_for_question_answering", + model_fn=lambda: transformers.GPTJForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gptj_for_sequence_classification", + model_fn=lambda: transformers.GPTJForSequenceClassification(config_for_token_classification), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py new file mode 100644 index 000000000..37f875857 --- /dev/null +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -0,0 +1,78 @@ +import torch +import transformers +from transformers import MistralConfig + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register single-sentence Mistral +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import AutoModelForCausalLM, AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement) + # tokenized_input = tokenizer([input], return_tensors="pt") + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data["labels"] = data["input_ids"].clone() + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data["labels"] = torch.tensor([1], dtype=torch.int64) + return data + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) +loss_fn = lambda x: x.loss +loss_fn_for_seq_classification = lambda output: output.logits.mean() + +config = MistralConfig( + hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258 +) + +model_zoo.register( + name="transformers_mistral", + model_fn=lambda: transformers.MistralModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_mistral_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_mistral_for_casual_lm", + model_fn=lambda: transformers.MistralForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_mistral_for_sequence_classification", + model_fn=lambda: transformers.MistralForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 61debe47b..ddb4484ff 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -105,6 +105,11 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool "transformers_sam", "transformers_vit", "transformers_gpt_double_heads", # TODO check why does the model fail to run using Gemini + "transformers_falcon", # TODO check why falcon fails to run Gemini + "transformers_falcon_for_causal_lm", + "transformers_falcon_for_sequence_classification", + "transformers_falcon_for_token_classification", + "transformers_falcon_for_question_answering", ]: continue diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index f181453ea..4de50245f 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -4,6 +4,7 @@ from types import MethodType import pytest import torch +import torch.distributed as dist import torch.nn as nn import colossalai @@ -11,31 +12,21 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all +NUM_LAYER = 8 +DIM = 4 + class MlpModel(nn.Module): def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(4, 8) - self.linear2 = nn.Linear(8, 8) - self.linear3 = nn.Linear(8, 8) - self.linear4 = nn.Linear(8, 8) - self.linear5 = nn.Linear(8, 8) - self.linear6 = nn.Linear(8, 8) - self.linear7 = nn.Linear(8, 8) - self.linear8 = nn.Linear(8, 4) + super().__init__() + self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)]) def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - x = self.linear3(x) - x = self.linear4(x) - x = self.linear5(x) - x = self.linear6(x) - x = self.linear7(x) - x = self.linear8(x) + for layer in self.layers: + x = layer(x) return x @@ -44,70 +35,71 @@ def pp_linear_fwd( data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None, - num_chunks: int = None, model_chunk_id: int = None, ): - if stage_mgr.is_first_stage() and model_chunk_id == 0: + if stage_mgr.is_first_stage(model_chunk_id): return {"input_obj": forward(data)} - elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: + elif stage_mgr.is_last_stage(model_chunk_id): return forward(input_obj) else: return {"input_obj": forward(input_obj)} -@parameterize("num_micro_batches", [4, 8, 12]) -def examine_pp(num_micro_batches): +def run_pp( + rank: int, + world_size: int, + port: int, + num_microbatch: int, + batch_size: int, + num_model_chunk: int, +): """ This test is to examine the correctness of interleaved 1F1B, compared with torch. Be aware it contains some hardcodes. """ - world_size = torch.distributed.get_world_size() - local_rank = torch.distributed.get_rank() - seed_all(1453) - - NUM_MICRO_BATCHS = num_micro_batches - BATCH_SIZE = num_micro_batches - NUM_CHUNKS = 2 + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") # create model + seed_all(1453) torch_model = MlpModel().cuda() - pp_model = copy.deepcopy(torch_model).cuda() - DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - pg_mesh = ProcessGroupMesh(1, world_size, 1) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True) - schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager) + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) + schedule = InterleavedSchedule( + stage_manager=stage_manager, + num_model_chunks=num_model_chunk, + num_microbatch=num_microbatch, + ) sharded_model = torch.nn.ModuleList() - for idx, (_, sub_model) in enumerate(pp_model.named_children()): - if idx % (world_size) == local_rank: + for idx, sub_model in enumerate(pp_model.layers): + if idx % world_size == rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( - partial( - pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model) - ), + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)), sub_model._forward, ) sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct" # create optimizer - torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) - pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1)) + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5) + pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5)) - # create - seed_all(1453) - if local_rank == 0: - input_list = [torch.rand(BATCH_SIZE, 4).cuda()] - else: - input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] - torch.distributed.all_reduce(input_list[0]) + # create data + seed_all(115) + input_list = [torch.rand(batch_size, DIM).cuda()] + dist.all_reduce(input_list[0]) - criterion = lambda x, y: torch.mean(x) + def criterion(x, *args, **kwargs): + return (x * x).mean() # forward and backward torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output, _) + torch_loss = criterion(torch_output) torch_loss.backward() pp_ret = schedule.forward_backward_step( @@ -115,45 +107,41 @@ def examine_pp(num_micro_batches): ) # check loss - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(-1): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients - torch_grad = [] - for torch_p in torch_model.parameters(): - torch_grad.append(torch_p.grad.data) - - for idx, pp_p in enumerate(sharded_model.parameters()): - if idx < 2: - assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) - else: - assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data) + for i in range(num_model_chunk): + idx = world_size * i + rank + assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() pp_optimizer.step() # check updated param - torch_param = [] - for torch_p in torch_model.parameters(): - torch_param.append(torch_p.data) - for idx, pp_p in enumerate(sharded_model.parameters()): - if idx < 2: - assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) - else: - assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - examine_pp() + for i in range(num_model_chunk): + idx = world_size * i + rank + assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) + assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) @pytest.mark.dist +@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("batch_size", [12]) +@pytest.mark.parametrize("num_model_chunk", [2, 4]) @rerun_if_address_is_in_use() -def test_pp(): - spawn(run_dist, 4) +def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): + assert NUM_LAYER % num_model_chunk == 0 + spawn( + run_pp, + nprocs=NUM_LAYER // num_model_chunk, + num_microbatch=num_microbatch, + batch_size=batch_size, + num_model_chunk=num_model_chunk, + ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4) diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py new file mode 100644 index 000000000..963045179 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_falcon.py @@ -0,0 +1,202 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + falcon = unwrap_model(org_model, "FalconModel", "transformer") + sharded_falcon = unwrap_model(sharded_model, "FalconModel", "transformer") + + row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] + col_layer_for_check = ["h[0].self_attention.dense"] + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + falcon, sharded_falcon, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if org_model.__class__.__name__ == "FalconModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 2e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_falcon_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +def run_falcon_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_falcon(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_falcon_test() + + +def check_falcon_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_falcon_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_falcon(): + spawn(check_falcon, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_falcon_3d(): + spawn(check_falcon_3d, 8) + + +if __name__ == "__main__": + test_falcon() + test_falcon_3d() diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py new file mode 100644 index 000000000..a946aacfd --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -0,0 +1,227 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + gptj = unwrap_model(org_model, "GPTJModel", "transformer") + sharded_gptj = unwrap_model(sharded_model, "GPTJModel", "transformer") + + col_layer_for_check = ["h[0].attn.k_proj"] + row_layer_for_check = ["h[0].mlp.fc_out"] # use dim=0 for wte get_grad_tensors_for_check + + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + col_layer_grads = get_grad_tensors_for_check( + gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + + row_layer_grads = get_grad_tensors_for_check( + gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "GPTJModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + + # check grads + check_all_grad_tensors(grads_to_check) + + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + #'use_lazy_init': True, GPTJ currently do not support lazy init; model training has issue even without sharding + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + #'use_lazy_init': True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + #'use_lazy_init': True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + #'use_lazy_init': True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + #'use_lazy_init': True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_gptj_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) +@clear_cache_before_run() +def run_gptj_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + torch.cuda.empty_cache() + + +def check_gptj(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_gptj_test() + + +def check_gptj_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_gptj_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gptj(): + spawn(check_gptj, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gptj_3d(): + spawn(check_gptj_3d, 8) + + +if __name__ == "__main__": + test_gptj() + test_gptj_3d() diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py new file mode 100644 index 000000000..07bc91b33 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -0,0 +1,168 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # unwrap model + mistral_model = unwrap_model(org_model, "MistralModel", "model") + shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model") + + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] + + # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if test_config["precision"] == "fp32": + atol, rtol = 5e-5, 1e-4 + else: + atol, rtol = 5e-3, 5e-3 + row_layer_grads = get_grad_tensors_for_check( + mistral_model, + shard_mistral_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + col_layer_grads = get_grad_tensors_for_check( + mistral_model, + shard_mistral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == "MistralModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if stage_manager is None or stage_manager.is_first_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + check_weight( + mistral_model, + shard_mistral_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) + + # check grads + check_all_grad_tensors(grads_to_check) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) +def run_mistral_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_mistral(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_mistral_test() + + +@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_mistral(): + spawn(check_mistral, 4) + + +if __name__ == "__main__": + test_mistral()