mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 19:36:13 +00:00
[shardformer, pipeline] add gradient_checkpointing_ratio
and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests
This commit is contained in:
parent
df5e9c53cf
commit
e614aa34f3
@ -109,8 +109,8 @@ class MixtralPolicy(Policy):
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
@ -129,10 +129,10 @@ class MixtralPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
@ -26,7 +26,7 @@ from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
||||
@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
|
||||
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
|
||||
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
|
||||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
"""
|
||||
|
||||
@ -969,6 +970,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
custom_policy: Policy = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -1043,6 +1045,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
enable_sequence_overlap=enable_sequence_overlap,
|
||||
parallel_output=parallel_output,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
initial_scale=initial_scale,
|
||||
|
@ -114,12 +114,12 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
held_layers.append(self.model.lm_head)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
|
@ -69,11 +69,11 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embedding)
|
||||
held_layers.append(module.output_layer)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
if module.encoder.post_layer_norm:
|
||||
|
@ -194,11 +194,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
held_layers.append(self.model.lm_head)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
@ -29,6 +30,8 @@ class PipelineStageManager:
|
||||
) -> None:
|
||||
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
|
||||
|
||||
self.num_layers_per_stage = None
|
||||
|
||||
self.pg_mesh = pg_mesh
|
||||
self.pipeline_axis = pipeline_axis
|
||||
self.prev_rank: Optional[Tuple[int, ...]] = None
|
||||
@ -69,6 +72,88 @@ class PipelineStageManager:
|
||||
# for shardformer, hold model chunk id
|
||||
self.model_chunk_id: Optional[int] = None
|
||||
|
||||
@property
|
||||
def control_distribute_layers(self) -> bool:
|
||||
return self.num_layers_per_stage is not None
|
||||
|
||||
def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
|
||||
"""Set the distribution configuration.
|
||||
This allows user to customize the number of layers for each stage.
|
||||
|
||||
Args:
|
||||
num_model_layers (int): Number of layers in the model.
|
||||
num_layers_per_stage (List[int]): Number of layers for each stage.
|
||||
"""
|
||||
assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
|
||||
assert sum(num_layers_per_stage) == num_model_layers
|
||||
assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
|
||||
self.num_model_layers = num_model_layers
|
||||
self.num_layers_per_stage = num_layers_per_stage
|
||||
|
||||
def distribute_layers(
|
||||
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
|
||||
) -> List[int]:
|
||||
"""Divide layers into stages"""
|
||||
num_stages = self.num_stages if num_stages is None else num_stages
|
||||
num_model_chunks = (
|
||||
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
|
||||
)
|
||||
|
||||
if self.control_distribute_layers:
|
||||
assert num_layers == self.num_model_layers
|
||||
return self.num_layers_per_stage
|
||||
|
||||
else:
|
||||
quotient = num_layers // (num_stages * num_model_chunks)
|
||||
remainder = num_layers % (num_stages * num_model_chunks)
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages * num_model_chunks
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
|
||||
|
||||
def get_stage_index(
|
||||
self,
|
||||
layers_per_stage: List[int],
|
||||
stage: Optional[int] = None,
|
||||
num_model_chunks: Optional[int] = None,
|
||||
num_stages: Optional[int] = None,
|
||||
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
|
||||
"""
|
||||
Get the start index and end index of layers for each stage.
|
||||
|
||||
Args:
|
||||
layers_per_stage (List[int]): number of layers for each stage
|
||||
stage (int): the stage index
|
||||
num_stages (int): number of stages
|
||||
num_model_chunks (int): number of model chunks
|
||||
|
||||
Returns:
|
||||
- Tuple[int, int]: the start index and end index of this stage
|
||||
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
|
||||
|
||||
"""
|
||||
stage = self.stage if stage is None else stage
|
||||
num_model_chunks = (
|
||||
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
|
||||
)
|
||||
num_stages = self.num_stages if num_stages is None else num_stages
|
||||
|
||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||
|
||||
stage_indices = []
|
||||
for model_chunk in range(num_model_chunks):
|
||||
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
|
||||
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
|
||||
stage_indices.append([start_idx, end_idx])
|
||||
|
||||
return stage_indices[0] if num_model_chunks == 1 else stage_indices
|
||||
|
||||
def is_first_stage(self, ignore_chunk: bool = False) -> bool:
|
||||
"""Is the current stage the first stage.
|
||||
|
||||
|
@ -1 +1 @@
|
||||
from .shard import ShardConfig, ShardFormer
|
||||
from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
|
@ -138,13 +138,25 @@ class LlamaPipelineForwards:
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
@ -2,9 +2,8 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
@ -196,49 +195,3 @@ class Policy(ABC):
|
||||
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
|
||||
"""
|
||||
return []
|
||||
|
||||
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
|
||||
"""Divide layers into stages"""
|
||||
quotient = num_layers // num_stages
|
||||
remainder = num_layers % num_stages
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = num_stages // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
|
||||
|
||||
def get_stage_index(
|
||||
self,
|
||||
layers_per_stage: List[int],
|
||||
stage: int,
|
||||
num_model_chunks: int = 1,
|
||||
num_stages: int = 0,
|
||||
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
|
||||
"""
|
||||
Get the start index and end index of layers for each stage.
|
||||
|
||||
Args:
|
||||
layers_per_stage (List[int]): number of layers for each stage
|
||||
stage (int): the stage index
|
||||
num_stages (int): number of stages
|
||||
num_model_chunks (int): number of model chunks
|
||||
|
||||
Returns:
|
||||
- Tuple[int, int]: the start index and end index of this stage
|
||||
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
|
||||
|
||||
"""
|
||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||
|
||||
stage_indices = []
|
||||
for model_chunk in range(num_model_chunks):
|
||||
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
|
||||
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
|
||||
stage_indices.append([start_idx, end_idx])
|
||||
|
||||
return stage_indices[0] if num_model_chunks == 1 else stage_indices
|
||||
|
@ -279,16 +279,8 @@ class BertPolicy(Policy):
|
||||
module = self.model.bert
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.encoder.layer),
|
||||
stage_manager.num_stages * stage_manager.num_model_chunks,
|
||||
)
|
||||
stage_manager.stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
@ -298,8 +290,8 @@ class BertPolicy(Policy):
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
@ -324,16 +316,8 @@ class BertPolicy(Policy):
|
||||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.encoder.layer),
|
||||
stage_manager.num_stages * stage_manager.num_model_chunks,
|
||||
)
|
||||
stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embeddings)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
@ -342,10 +326,10 @@ class BertPolicy(Policy):
|
||||
held_layers.append(module.pooler)
|
||||
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.pooler)
|
||||
|
@ -203,8 +203,8 @@ class BloomPolicy(Policy):
|
||||
else:
|
||||
module = self.model.transformer
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
@ -226,11 +226,11 @@ class BloomPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
|
@ -179,10 +179,10 @@ class ChatGLMPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embedding)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
if module.encoder.post_layer_norm:
|
||||
@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy):
|
||||
else:
|
||||
module = self.model.transformer
|
||||
|
||||
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
|
@ -161,8 +161,8 @@ class FalconPolicy(Policy):
|
||||
else:
|
||||
module = self.model.transformer
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
@ -181,10 +181,10 @@ class FalconPolicy(Policy):
|
||||
module = self.model.transformer
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
|
@ -185,15 +185,8 @@ class GPT2Policy(Policy):
|
||||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.wte)
|
||||
held_layers.append(module.wpe)
|
||||
@ -203,12 +196,12 @@ class GPT2Policy(Policy):
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(module.ln_f)
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.wte)
|
||||
held_layers.append(module.wpe)
|
||||
held_layers.append(module.drop)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
@ -226,15 +219,8 @@ class GPT2Policy(Policy):
|
||||
module = self.model.transformer
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_manager.stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
@ -243,8 +229,8 @@ class GPT2Policy(Policy):
|
||||
)
|
||||
}
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
|
@ -179,11 +179,11 @@ class GPTJPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.wte)
|
||||
held_layers.append(module.drop)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
@ -200,8 +200,8 @@ class GPTJPolicy(Policy):
|
||||
else:
|
||||
module = self.model.transformer
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
|
@ -164,30 +164,20 @@ class LlamaPolicy(Policy):
|
||||
module = self.model.model
|
||||
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_manager.stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
)
|
||||
}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
@ -204,15 +194,8 @@ class LlamaPolicy(Policy):
|
||||
held_layers = []
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
num_stages=stage_manager.num_stages,
|
||||
)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
@ -221,10 +204,10 @@ class LlamaPolicy(Policy):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
@ -186,12 +186,12 @@ class OPTPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
held_layers.append(module.embed_positions)
|
||||
held_layers.append(module.project_in)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.final_layer_norm)
|
||||
@ -208,8 +208,8 @@ class OPTPolicy(Policy):
|
||||
else:
|
||||
module = self.model.model.decoder
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
|
@ -251,6 +251,8 @@ class T5BasePolicy(Policy):
|
||||
Return the layer distribution as a list and the starting stage of decoder.
|
||||
If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
|
||||
"""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
assert stage_manager is not None, "Pipeline stage manager is not set."
|
||||
|
||||
# number of encoder layers must be a positive integer
|
||||
if num_encoder_layers <= 0:
|
||||
@ -262,7 +264,7 @@ class T5BasePolicy(Policy):
|
||||
|
||||
# in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
|
||||
if num_decoder_layers == 0:
|
||||
return self.distribute_layers(num_encoder_layers, num_stages), num_stages
|
||||
return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages
|
||||
|
||||
# the number of stages distributed between encoder and decoder is optimized in this way:
|
||||
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
|
||||
@ -273,21 +275,26 @@ class T5BasePolicy(Policy):
|
||||
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
|
||||
num_decoder_stages = num_stages - num_encoder_stages
|
||||
|
||||
encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||
decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
|
||||
encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||
decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
|
||||
return encoder_distribution + decoder_distribution, num_encoder_stages
|
||||
|
||||
def get_t5_stage_index(
|
||||
self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
|
||||
) -> Tuple[bool, int, int]:
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Input the distribution of layers among stages, the current stage and the first stage of decoder.
|
||||
Return the starting/ending idx of layers in encoder/decoder
|
||||
"""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
assert stage_manager is not None, "Pipeline stage manager is not set."
|
||||
|
||||
if stage < decoder_starting_stage:
|
||||
return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
|
||||
return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
|
||||
else:
|
||||
return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
|
||||
return stage_manager.get_stage_index(
|
||||
layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
|
||||
)
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
|
@ -134,10 +134,10 @@ class ViTPolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embeddings)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
return held_layers
|
||||
|
||||
@ -149,8 +149,8 @@ class ViTPolicy(Policy):
|
||||
else:
|
||||
module = self.model.vit
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
|
@ -300,6 +300,8 @@ class WhisperPolicy(Policy):
|
||||
Return the layer distribution as a list and the starting stage of decoder.
|
||||
If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
|
||||
"""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
assert stage_manager is not None, "pipeline_stage_manager is None"
|
||||
|
||||
# number of encoder layers must be a positive integer
|
||||
if num_encoder_layers <= 0:
|
||||
@ -311,7 +313,7 @@ class WhisperPolicy(Policy):
|
||||
|
||||
# in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
|
||||
if num_decoder_layers == 0:
|
||||
return self.distribute_layers(num_encoder_layers, num_stages), num_stages
|
||||
return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages
|
||||
|
||||
# the number of stages distributed between encoder and decoder is optimized in this way:
|
||||
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
|
||||
@ -322,21 +324,24 @@ class WhisperPolicy(Policy):
|
||||
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
|
||||
num_decoder_stages = num_stages - num_encoder_stages
|
||||
|
||||
encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||
decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
|
||||
encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||
decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
|
||||
return encoder_distribution + decoder_distribution, num_encoder_stages
|
||||
|
||||
def get_whisper_stage_index(
|
||||
self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
|
||||
) -> Tuple[bool, int, int]:
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Input the distribution of layers among stages, the current stage and the first stage of decoder.
|
||||
Return the starting/ending idx of layers in encoder/decoder
|
||||
"""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
assert stage_manager is not None, "pipeline_stage_manager is None"
|
||||
|
||||
if stage < decoder_starting_stage:
|
||||
return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
|
||||
return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
|
||||
else:
|
||||
return self.get_stage_index(
|
||||
return stage_manager.get_stage_index(
|
||||
layers_per_stage[decoder_starting_stage:],
|
||||
stage - decoder_starting_stage,
|
||||
)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
|
||||
from .shard_config import ShardConfig
|
||||
from .sharder import ModelSharder
|
||||
from .shardformer import ShardFormer
|
||||
|
||||
__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"]
|
||||
__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]
|
||||
|
87
colossalai/shardformer/shard/grad_ckpt_config.py
Normal file
87
colossalai/shardformer/shard/grad_ckpt_config.py
Normal file
@ -0,0 +1,87 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradientCheckpointConfig:
|
||||
gradient_checkpointing_ratio: float = 0.0
|
||||
|
||||
def get_num_ckpt_layers(self, num_layers: int) -> int:
|
||||
return int(self.gradient_checkpointing_ratio * num_layers)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
|
||||
r"""
|
||||
The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism.
|
||||
Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism.
|
||||
Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details.
|
||||
|
||||
It provides the following features:
|
||||
1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing.
|
||||
2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
|
||||
|
||||
"""
|
||||
"""
|
||||
Args:
|
||||
gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
|
||||
num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check.
|
||||
num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check.
|
||||
num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check.
|
||||
num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None.
|
||||
|
||||
Example 1:
|
||||
num_stages = 8
|
||||
num_layers = 80
|
||||
num_model_chunks = 1
|
||||
num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
|
||||
num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0]
|
||||
|
||||
Example 2:
|
||||
num_stages = 4
|
||||
num_layers = 80
|
||||
num_model_chunks = 2
|
||||
num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
|
||||
# device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers
|
||||
...
|
||||
|
||||
"""
|
||||
num_stages: Optional[int] = None
|
||||
num_model_chunks: Optional[int] = None
|
||||
num_model_layers: Optional[int] = None
|
||||
num_ckpt_layers_per_stage: Optional[List[int]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self._enable_gradient_checkpointing_ratio:
|
||||
if not (0 <= self.gradient_checkpointing_ratio <= 1):
|
||||
raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")
|
||||
|
||||
if self._enable_customized_ckpt_layers_per_stage:
|
||||
assert (
|
||||
self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
|
||||
)
|
||||
assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
|
||||
assert all(
|
||||
[0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
|
||||
)
|
||||
self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
|
||||
|
||||
@property
|
||||
def _enable_gradient_checkpointing_ratio(self) -> bool:
|
||||
return self.gradient_checkpointing_ratio is not None
|
||||
|
||||
@property
|
||||
def _enable_customized_ckpt_layers_per_stage(self) -> bool:
|
||||
return self.num_ckpt_layers_per_stage is not None
|
||||
|
||||
def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
|
||||
if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
|
||||
raise RuntimeError("No checkpointed layers information is provided")
|
||||
|
||||
if self._enable_customized_ckpt_layers_per_stage:
|
||||
assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
|
||||
num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
|
||||
assert num_ckpt_layers <= num_layers
|
||||
return num_ckpt_layers
|
||||
else:
|
||||
return int(self.gradient_checkpointing_ratio * num_layers)
|
@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from .grad_ckpt_config import GradientCheckpointConfig
|
||||
|
||||
__all__ = ["ShardConfig"]
|
||||
|
||||
|
||||
@ -23,6 +25,7 @@ class ShardConfig:
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
|
||||
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
|
||||
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
|
||||
"""
|
||||
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
||||
@ -35,6 +38,7 @@ class ShardConfig:
|
||||
enable_sequence_parallelism: bool = False
|
||||
enable_sequence_overlap: bool = False
|
||||
parallel_output: bool = True
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
|
||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
# TODO padding vocab
|
||||
# make_vocab_size_divisible_by: int = 128
|
||||
|
@ -1,4 +1,3 @@
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
|
||||
|
||||
|
||||
class OpenMoePolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
|
||||
@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):
|
||||
|
||||
|
||||
class OpenMoeModelPolicy(OpenMoePolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):
|
||||
|
||||
|
||||
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
OpenMoeForCausalLM:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
OpenMoeForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
])
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1):
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}]
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
@ -247,12 +249,13 @@ class OpenMoePipelineForwards:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (past_key_values[idx] if past_key_values is not None else None)
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
|
||||
router_z_loss = past_router_z_loss + router_z_loss
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
return tuple([
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
router_aux_loss,
|
||||
router_z_loss,
|
||||
])
|
||||
return tuple(
|
||||
[
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
router_aux_loss,
|
||||
router_z_loss,
|
||||
]
|
||||
)
|
||||
# always return dict for imediate stage
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
|
||||
output_hidden_states = (output_hidden_states
|
||||
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
|
||||
if chunk_head == True:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
logits = module(inputs[0])
|
||||
logits = logits.float()
|
||||
@ -522,8 +527,8 @@ class OpenMoePipelineForwards:
|
||||
for batch_idx in range(hidden_states.shape[0]):
|
||||
loss = loss + torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.lm_head),
|
||||
hidden_states[batch_idx:batch_idx + 1, :],
|
||||
labels[batch_idx:batch_idx + 1, :],
|
||||
hidden_states[batch_idx : batch_idx + 1, :],
|
||||
labels[batch_idx : batch_idx + 1, :],
|
||||
)
|
||||
logits = None
|
||||
else:
|
||||
|
@ -49,9 +49,9 @@ if HAS_LLAMA:
|
||||
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
|
||||
|
||||
config = LlamaConfig(
|
||||
num_hidden_layers=4,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
num_hidden_layers=8,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
num_attention_heads=4,
|
||||
max_position_embeddings=128,
|
||||
num_labels=16,
|
||||
|
@ -1,4 +1,23 @@
|
||||
import random
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||
from colossalai.shardformer.shard.shard_config import ShardConfig
|
||||
|
||||
|
||||
class _ShardConfig(ShardConfig):
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class _PipelineStageManager(PipelineStageManager):
|
||||
def __init__(self):
|
||||
self.is_interleave = False
|
||||
self.num_layers_per_stage = None
|
||||
|
||||
@property
|
||||
def num_stages(self):
|
||||
return random.randint(5, 10)
|
||||
|
||||
|
||||
def test_t5_pipeline_distribution():
|
||||
@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
|
||||
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
|
||||
}
|
||||
|
||||
stage_manager = _PipelineStageManager()
|
||||
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
|
||||
policy = T5BasePolicy()
|
||||
policy.set_shard_config(shard_config)
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = policy.distribute_t5_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
stage_manager = _PipelineStageManager()
|
||||
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
|
||||
policy = T5BasePolicy()
|
||||
policy.set_shard_config(shard_config)
|
||||
layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
test_dict["num_decoder_layers"][i],
|
||||
|
@ -1,4 +1,23 @@
|
||||
import random
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.whisper import WhisperPolicy
|
||||
from colossalai.shardformer.shard.shard_config import ShardConfig
|
||||
|
||||
|
||||
class _ShardConfig(ShardConfig):
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class _PipelineStageManager(PipelineStageManager):
|
||||
def __init__(self):
|
||||
self.is_interleave = False
|
||||
self.num_layers_per_stage = None
|
||||
|
||||
@property
|
||||
def num_stages(self):
|
||||
return random.randint(5, 10)
|
||||
|
||||
|
||||
def test_whisper_pipeline_distribution():
|
||||
@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
|
||||
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
|
||||
}
|
||||
|
||||
stage_manager = _PipelineStageManager()
|
||||
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
|
||||
policy = WhisperPolicy()
|
||||
policy.set_shard_config(shard_config)
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = policy.distribute_whisper_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
|
||||
],
|
||||
}
|
||||
|
||||
stage_manager = _PipelineStageManager()
|
||||
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
|
||||
policy = WhisperPolicy()
|
||||
policy.set_shard_config(shard_config)
|
||||
for i in range(num_test_cases):
|
||||
layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
|
||||
|
||||
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
|
||||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||
model_fn, loss_fn, test_config
|
||||
)
|
||||
if enable_gradient_checkpointing:
|
||||
org_model.gradient_checkpointing_enable()
|
||||
sharded_model.unwrap().gradient_checkpointing_enable()
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
"use_lazy_init": True,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
"enable_gradient_checkpointing": True,
|
||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
"num_microbatches": 4,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"enable_gradient_checkpointing": True,
|
||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||
num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
|
||||
),
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
@ -189,6 +200,13 @@ def run_llama_test(test_config):
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
"enable_gradient_checkpointing": True,
|
||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||
num_stages=2,
|
||||
num_model_chunks=2,
|
||||
num_model_layers=8,
|
||||
num_ckpt_layers_per_stage=[0, 1, 2, 2],
|
||||
),
|
||||
},
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user