[shardformer, pipeline] add gradient_checkpointing_ratio and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
Wenhao Chen, 2024-04-01 11:34:58 +08:00 (committed by GitHub)
parent df5e9c53cf
commit e614aa34f3
28 changed files with 396 additions and 213 deletions


@@ -109,8 +109,8 @@ class MixtralPolicy(Policy):
         else:
             module = self.model.model

-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
@@ -129,10 +129,10 @@ class MixtralPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)


@@ -26,7 +26,7 @@
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """
@@ -969,6 +970,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
@@ -1043,6 +1045,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
+            gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
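
The net effect on the user-facing API is one new keyword argument. Below is a minimal usage sketch (not part of this diff): the tp_size/pp_size/num_microbatches values are placeholders, and the model setup plus booster.boost(...) wiring follow ColossalAI's usual Booster workflow.

# Hedged usage sketch: route a gradient-checkpoint config through the plugin.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=4,
    # Checkpoint roughly half of the layers held by each pipeline stage.
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)
booster = Booster(plugin=plugin)
# The model still needs checkpointing switched on, as the tests below do:
# model.gradient_checkpointing_enable()
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)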


@@ -114,12 +114,12 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
             held_layers.append(module.word_embeddings_layernorm)
             held_layers.append(self.model.lm_head)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)


@@ -69,11 +69,11 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
         if stage_manager.is_first_stage():
             held_layers.append(module.embedding)
             held_layers.append(module.output_layer)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             if module.encoder.post_layer_norm:


@@ -194,11 +194,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
             held_layers.append(self.model.lm_head)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)


@@ -1,6 +1,7 @@
 import contextlib
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

+import numpy as np
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -29,6 +30,8 @@ class PipelineStageManager:
     ) -> None:
         assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

+        self.num_layers_per_stage = None
+
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -69,6 +72,88 @@ class PipelineStageManager:
         # for shardformer, hold model chunk id
         self.model_chunk_id: Optional[int] = None

+    @property
+    def control_distribute_layers(self) -> bool:
+        return self.num_layers_per_stage is not None
+
+    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
+        """Set the distribution configuration.
+        This allows user to customize the number of layers for each stage.
+
+        Args:
+            num_model_layers (int): Number of layers in the model.
+            num_layers_per_stage (List[int]): Number of layers for each stage.
+        """
+        assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
+        assert sum(num_layers_per_stage) == num_model_layers
+        assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
+        self.num_model_layers = num_model_layers
+        self.num_layers_per_stage = num_layers_per_stage
+
+    def distribute_layers(
+        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+    ) -> List[int]:
+        """Divide layers into stages"""
+        num_stages = self.num_stages if num_stages is None else num_stages
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+
+        if self.control_distribute_layers:
+            assert num_layers == self.num_model_layers
+            return self.num_layers_per_stage
+
+        else:
+            quotient = num_layers // (num_stages * num_model_chunks)
+            remainder = num_layers % (num_stages * num_model_chunks)
+
+            # calculate the num_layers per stage
+            layers_per_stage = [quotient] * num_stages * num_model_chunks
+
+            # deal with the rest layers
+            if remainder > 0:
+                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
+                for i in range(start_position, start_position + remainder):
+                    layers_per_stage[i] += 1
+            return layers_per_stage
+
+    def get_stage_index(
+        self,
+        layers_per_stage: List[int],
+        stage: Optional[int] = None,
+        num_model_chunks: Optional[int] = None,
+        num_stages: Optional[int] = None,
+    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
+        """
+        Get the start index and end index of layers for each stage.
+
+        Args:
+            layers_per_stage (List[int]): number of layers for each stage
+            stage (int): the stage index
+            num_stages (int): number of stages
+            num_model_chunks (int): number of model chunks
+
+        Returns:
+            - Tuple[int, int]: the start index and end index of this stage
+            - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
+        """
+        stage = self.stage if stage is None else stage
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+        num_stages = self.num_stages if num_stages is None else num_stages
+
+        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
+        stage_indices = []
+        for model_chunk in range(num_model_chunks):
+            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
+            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
+            stage_indices.append([start_idx, end_idx])
+
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+
     def is_first_stage(self, ignore_chunk: bool = False) -> bool:
         """Is the current stage the first stage.


@@ -1 +1 @@
-from .shard import ShardConfig, ShardFormer
+from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer


@@ -138,13 +138,25 @@ class LlamaPipelineForwards:
         next_decoder_cache = () if use_cache else None

         start_idx, end_idx = stage_index[0], stage_index[1]
+        num_ckpt_layers = 0
+        if self.gradient_checkpointing and self.training:
+            num_ckpt_layers = end_idx - start_idx
+            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+            if shard_config.gradient_checkpoint_config is not None:
+                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+                    stage=stage_manager.stage,
+                    num_layers=end_idx - start_idx,
+                    model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
+                )
+            assert num_ckpt_layers <= end_idx - start_idx
+
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
+            if idx - start_idx < num_ckpt_layers:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
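
Because the loop above checkpoints the first num_ckpt_layers of each stage's slice (the idx - start_idx < num_ckpt_layers test), the effect of a ratio is easy to trace by hand. A standalone sketch with assumed numbers:

# Sketch: gradient_checkpointing_ratio=0.5 on a stage holding layers [4, 10).
start_idx, end_idx = 4, 10
num_ckpt_layers = int(0.5 * (end_idx - start_idx))  # ratio mode -> 3 layers

for idx in range(start_idx, end_idx):
    mode = "checkpointed" if idx - start_idx < num_ckpt_layers else "normal forward"
    print(f"layer {idx}: {mode}")
# -> layers 4, 5, 6 are recomputed in backward; layers 7, 8, 9 keep their activations.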


@@ -2,9 +2,8 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

-import numpy as np
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
@@ -196,49 +195,3 @@ class Policy(ABC):
             List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
         """
         return []
-
-    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
-        """Divide layers into stages"""
-        quotient = num_layers // num_stages
-        remainder = num_layers % num_stages
-
-        # calculate the num_layers per stage
-        layers_per_stage = [quotient] * num_stages
-
-        # deal with the rest layers
-        if remainder > 0:
-            start_position = num_stages // 2 - remainder // 2
-            for i in range(start_position, start_position + remainder):
-                layers_per_stage[i] += 1
-        return layers_per_stage
-
-    def get_stage_index(
-        self,
-        layers_per_stage: List[int],
-        stage: int,
-        num_model_chunks: int = 1,
-        num_stages: int = 0,
-    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
-        """
-        Get the start index and end index of layers for each stage.
-
-        Args:
-            layers_per_stage (List[int]): number of layers for each stage
-            stage (int): the stage index
-            num_stages (int): number of stages
-            num_model_chunks (int): number of model chunks
-
-        Returns:
-            - Tuple[int, int]: the start index and end index of this stage
-            - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
-        """
-        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
-
-        stage_indices = []
-        for model_chunk in range(num_model_chunks):
-            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
-            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
-            stage_indices.append([start_idx, end_idx])
-
-        return stage_indices[0] if num_model_chunks == 1 else stage_indices


@@ -279,16 +279,8 @@ class BertPolicy(Policy):
             module = self.model.bert

         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer),
-                stage_manager.num_stages * stage_manager.num_model_chunks,
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -298,8 +290,8 @@ class BertPolicy(Policy):
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -324,16 +316,8 @@ class BertPolicy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer),
-                stage_manager.num_stages * stage_manager.num_model_chunks,
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embeddings)
             for start_idx, end_idx in stage_indices:
@@ -342,10 +326,10 @@ class BertPolicy(Policy):
                 held_layers.append(module.pooler)
         else:
-            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
             if stage_manager.is_first_stage():
                 held_layers.append(module.embeddings)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.encoder.layer[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.pooler)


@@ -203,8 +203,8 @@ class BloomPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -226,11 +226,11 @@ class BloomPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
             held_layers.append(module.word_embeddings_layernorm)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)


@@ -179,10 +179,10 @@ class ChatGLMPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
         if stage_manager.is_first_stage():
             held_layers.append(module.embedding)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             if module.encoder.post_layer_norm:
@@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config


@@ -161,8 +161,8 @@ class FalconPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -181,10 +181,10 @@ class FalconPolicy(Policy):
         module = self.model.transformer
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)


@@ -185,15 +185,8 @@ class GPT2Policy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.wte)
                 held_layers.append(module.wpe)
@@ -203,12 +196,12 @@ class GPT2Policy(Policy):
             if stage_manager.is_last_stage(ignore_chunk=True):
                 held_layers.append(module.ln_f)
         else:
-            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
             if stage_manager.is_first_stage():
                 held_layers.append(module.wte)
                 held_layers.append(module.wpe)
                 held_layers.append(module.drop)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.h[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.ln_f)
@@ -226,15 +219,8 @@ class GPT2Policy(Policy):
             module = self.model.transformer

         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -243,8 +229,8 @@ class GPT2Policy(Policy):
                 )
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,


@@ -179,11 +179,11 @@ class GPTJPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.wte)
             held_layers.append(module.drop)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
@@ -200,8 +200,8 @@ class GPTJPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward,


@@ -164,30 +164,20 @@ class LlamaPolicy(Policy):
             module = self.model.model

         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                 )
             }
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=model_cls
-        )
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -204,15 +194,8 @@ class LlamaPolicy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
@@ -221,10 +204,10 @@ class LlamaPolicy(Policy):
                 held_layers.append(module.norm)
         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
             if stage_manager.is_first_stage():
                 held_layers.append(module.embed_tokens)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.layers[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.norm)


@@ -186,12 +186,12 @@ class OPTPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
             held_layers.append(module.embed_positions)
             held_layers.append(module.project_in)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.final_layer_norm)
@@ -208,8 +208,8 @@ class OPTPolicy(Policy):
         else:
             module = self.model.model.decoder

-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward,


@@ -251,6 +251,8 @@ class T5BasePolicy(Policy):
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."

         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -262,7 +264,7 @@
         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages

         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -273,21 +275,26 @@
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages

     def get_t5_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."
+
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return stage_manager.get_stage_index(
+                layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
+            )

     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""


@@ -134,10 +134,10 @@ class ViTPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
         if stage_manager.is_first_stage():
             held_layers.append(module.embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layer[start_idx:end_idx])
         return held_layers
@@ -149,8 +149,8 @@
         else:
             module = self.model.vit

-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls


@@ -300,6 +300,8 @@ class WhisperPolicy(Policy):
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"

         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -311,7 +313,7 @@
         # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages

         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -322,21 +324,24 @@
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages

     def get_whisper_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(
+            return stage_manager.get_stage_index(
                 layers_per_stage[decoder_starting_stage:],
                 stage - decoder_starting_stage,
             )


@@ -1,5 +1,6 @@
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 from .shard_config import ShardConfig
 from .sharder import ModelSharder
 from .shardformer import ShardFormer

-__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"]
+__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]


@@ -0,0 +1,87 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class GradientCheckpointConfig:
+    gradient_checkpointing_ratio: float = 0.0
+
+    def get_num_ckpt_layers(self, num_layers: int) -> int:
+        return int(self.gradient_checkpointing_ratio * num_layers)
+
+
+@dataclass
+class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
+    r"""
+    The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism.
+    Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism.
+    Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details.
+
+    It provides the following features:
+        1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing.
+        2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
+    """
+
+    """
+    Args:
+        gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
+        num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check.
+        num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check.
+        num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check.
+        num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None.
+
+    Example 1:
+        num_stages = 8
+        num_layers = 80
+        num_model_chunks = 1
+        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
+        num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0]
+
+    Example 2:
+        num_stages = 4
+        num_layers = 80
+        num_model_chunks = 2
+        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
+        # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers
+        ...
+    """
+
+    num_stages: Optional[int] = None
+    num_model_chunks: Optional[int] = None
+    num_model_layers: Optional[int] = None
+    num_ckpt_layers_per_stage: Optional[List[int]] = None
+
+    def __post_init__(self):
+        if self._enable_gradient_checkpointing_ratio:
+            if not (0 <= self.gradient_checkpointing_ratio <= 1):
+                raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")
+
+        if self._enable_customized_ckpt_layers_per_stage:
+            assert (
+                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
+            )
+            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
+            assert all(
+                [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
+            )
+            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
+
+    @property
+    def _enable_gradient_checkpointing_ratio(self) -> bool:
+        return self.gradient_checkpointing_ratio is not None
+
+    @property
+    def _enable_customized_ckpt_layers_per_stage(self) -> bool:
+        return self.num_ckpt_layers_per_stage is not None
+
+    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
+        if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
+            raise RuntimeError("No checkpointed layers information is provided")
+
+        if self._enable_customized_ckpt_layers_per_stage:
+            assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
+            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
+            assert num_ckpt_layers <= num_layers
+            return num_ckpt_layers
+        else:
+            return int(self.gradient_checkpointing_ratio * num_layers)
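
A quick sketch of the two modes of the config defined above, reusing the numbers from Example 1 in its docstring (hedged: this relies only on behavior visible in this file, plus the import path added to colossalai.shardformer in this commit):

# Mode 1: ratio only -- each stage checkpoints the first int(ratio * num_layers) layers.
from colossalai.shardformer import PipelineGradientCheckpointConfig

ratio_cfg = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)
assert ratio_cfg.get_num_ckpt_layers(stage=0, num_layers=9) == 4  # int(0.5 * 9)

# Mode 2: explicit per-stage counts (Example 1); takes precedence over the ratio.
stage_cfg = PipelineGradientCheckpointConfig(
    num_stages=8,
    num_model_chunks=1,
    num_model_layers=80,
    num_ckpt_layers_per_stage=[4, 4, 2, 2, 0, 0, 0, 0],
)
assert stage_cfg.get_num_ckpt_layers(stage=1, num_layers=9) == 4
assert stage_cfg.gradient_checkpointing_ratio == 12 / 80  # derived in __post_init__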


@@ -6,6 +6,8 @@
 from torch.distributed import ProcessGroup

 from colossalai.pipeline.stage_manager import PipelineStageManager

+from .grad_ckpt_config import GradientCheckpointConfig
+
 __all__ = ["ShardConfig"]
@@ -23,6 +25,7 @@ class ShardConfig:
         enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """

     tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -35,6 +38,7 @@ class ShardConfig:
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
     parallel_output: bool = True
+    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # TODO padding vocab
     # make_vocab_size_divisible_by: int = 128


@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Union
@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]

 class OpenMoePolicy(Policy):
-
     def config_sanity_check(self):
         pass
@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             raise NotImplementedError(
-                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+            )

         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
         else:
             module = self.model.model

-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)
@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):

 class OpenMoeModelPolicy(OpenMoePolicy):
-
     def __init__(self) -> None:
         super().__init__()
@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):

 class OpenMoeForCausalLMPolicy(OpenMoePolicy):
-
     def module_policy(self):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                OpenMoeForCausalLM:
-                    ModulePolicyDescription(sub_module_replacement=[
+                OpenMoeForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=Linear1D_Col,
                             kwargs=dict(gather_output=True),
                         )
-                    ])
+                    ]
+                )
             }
             policy.update(new_item)
@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
-            if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
-                    and self.pipeline_stage_manager.num_stages > 1):
+            if (
+                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
                 # tie weights
-                return [{
-                    0: llama_model.embed_tokens.weight,
-                    self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                }]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
         return []
@@ -247,12 +249,13 @@ class OpenMoePipelineForwards:
         logger = logging.get_logger(__name__)

-        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False

         # decoder layers
@@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            past_key_value = (past_key_values[idx] if past_key_values is not None else None)
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
                         return module(*inputs, output_attentions, None)
@@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
             router_z_loss = past_router_z_loss + router_z_loss

         if stage_manager.is_last_stage():
-            return tuple([
-                hidden_states,
-                next_cache,
-                all_hidden_states,
-                all_self_attns,
-                router_aux_loss,
-                router_z_loss,
-            ])
+            return tuple(
+                [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    router_aux_loss,
+                    router_z_loss,
+                ]
+            )
         # always return dict for imediate stage
         return {
             "hidden_states": hidden_states,
@@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""
         logger = logging.get_logger(__name__)
-        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if output_attentions:
@@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
         if chunk_head == True:
-
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     logits = module(inputs[0])
                     logits = logits.float()
@@ -522,8 +527,8 @@ class OpenMoePipelineForwards:
             for batch_idx in range(hidden_states.shape[0]):
                 loss = loss + torch.utils.checkpoint.checkpoint(
                     create_custom_forward(self.lm_head),
-                    hidden_states[batch_idx:batch_idx + 1, :],
-                    labels[batch_idx:batch_idx + 1, :],
+                    hidden_states[batch_idx : batch_idx + 1, :],
+                    labels[batch_idx : batch_idx + 1, :],
                 )
             logits = None
         else:


@@ -49,9 +49,9 @@ if HAS_LLAMA:
     loss_fn_for_seq_classification = lambda output: output["logits"].mean()

     config = LlamaConfig(
-        num_hidden_layers=4,
-        hidden_size=128,
-        intermediate_size=256,
+        num_hidden_layers=8,
+        hidden_size=32,
+        intermediate_size=64,
         num_attention_heads=4,
         max_position_embeddings=128,
         num_labels=16,


@@ -1,4 +1,23 @@
+import random
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.t5 import T5BasePolicy
+from colossalai.shardformer.shard.shard_config import ShardConfig
+
+
+class _ShardConfig(ShardConfig):
+    def __post_init__(self):
+        pass
+
+
+class _PipelineStageManager(PipelineStageManager):
+    def __init__(self):
+        self.is_interleave = False
+        self.num_layers_per_stage = None
+
+    @property
+    def num_stages(self):
+        return random.randint(5, 10)


 def test_t5_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
         "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
     }

+    stage_manager = _PipelineStageManager()
+    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
     policy = T5BasePolicy()
+    policy.set_shard_config(shard_config)
     for i in range(num_test_cases):
         _, decoder_starting_stage = policy.distribute_t5_layers(
             test_dict["num_encoder_layers"][i],
@@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
     }

     for i in range(num_test_cases):
+        stage_manager = _PipelineStageManager()
+        shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
         policy = T5BasePolicy()
+        policy.set_shard_config(shard_config)
         layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
             test_dict["num_encoder_layers"][i],
             test_dict["num_decoder_layers"][i],


@@ -1,4 +1,23 @@
+import random
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.whisper import WhisperPolicy
+from colossalai.shardformer.shard.shard_config import ShardConfig
+
+
+class _ShardConfig(ShardConfig):
+    def __post_init__(self):
+        pass
+
+
+class _PipelineStageManager(PipelineStageManager):
+    def __init__(self):
+        self.is_interleave = False
+        self.num_layers_per_stage = None
+
+    @property
+    def num_stages(self):
+        return random.randint(5, 10)


 def test_whisper_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
         "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
     }

+    stage_manager = _PipelineStageManager()
+    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
     policy = WhisperPolicy()
+    policy.set_shard_config(shard_config)
     for i in range(num_test_cases):
         _, decoder_starting_stage = policy.distribute_whisper_layers(
             test_dict["num_encoder_layers"][i],
@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
         ],
     }

+    stage_manager = _PipelineStageManager()
+    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
     policy = WhisperPolicy()
+    policy.set_shard_config(shard_config)
     for i in range(num_test_cases):
         layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
             test_dict["num_encoder_layers"][i],


@@ -5,6 +5,7 @@
 import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import PipelineGradientCheckpointConfig
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

 def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
         model_fn, loss_fn, test_config
     )
+    if enable_gradient_checkpointing:
+        org_model.gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable()

     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
@@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": True,
             "precision": "fp16",
             "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
         },
         {
             "tp_size": 1,
@@ -108,6 +115,10 @@
             "num_microbatches": 4,
             "use_lazy_init": False,
             "precision": "fp32",
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
+            ),
         },
         {
             "tp_size": 4,
@@ -189,6 +200,13 @@ def run_llama_test(test_config):
             "precision": "fp16",
             "zero_stage": 1,
             "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_stages=2,
+                num_model_chunks=2,
+                num_model_layers=8,
+                num_ckpt_layers_per_stage=[0, 1, 2, 2],
+            ),
         },
     ],
 )
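
For the interleaved entry above (num_stages=2, num_model_chunks=2), get_num_ckpt_layers indexes num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages], so the per-device counts can be checked by hand:

# Sketch: per-device checkpoint counts for the interleaved test config above.
num_stages, num_model_chunks = 2, 2
num_ckpt_layers_per_stage = [0, 1, 2, 2]

for stage in range(num_stages):
    counts = [num_ckpt_layers_per_stage[stage + chunk * num_stages] for chunk in range(num_model_chunks)]
    print(f"pipeline stage {stage}: ckpt layers per chunk = {counts}")
# -> stage 0 checkpoints [0, 2] layers across its two chunks; stage 1 checkpoints [1, 2].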