Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-23)
[shardformer, pipeline] add gradient_checkpointing_ratio and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
This commit is contained in: parent df5e9c53cf, commit e614aa34f3
@@ -109,8 +109,8 @@ class MixtralPolicy(Policy):
             else:
                 module = self.model.model

-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=model_cls
@@ -129,10 +129,10 @@ class MixtralPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)
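Every policy hunk in this commit follows the same mechanical migration, shown once here for clarity: layer distribution moves off the policy and onto the stage manager, which already knows the stage count, the current stage, and any user-supplied per-stage layer counts. A minimal before/after sketch (names as they appear in the hunks above):

# Old API: the policy computes the split and passes stage information explicitly.
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)

# New API: the PipelineStageManager owns the defaults, so callers only pass the layer count.
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)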
@@ -26,7 +26,7 @@ from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """

@@ -969,6 +970,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
@@ -1043,6 +1045,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
+            gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
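A hedged usage sketch of the new plugin argument (the parallel sizes and ratio below are illustrative, not defaults; the wrapped model still needs gradient checkpointing enabled, e.g. via model.gradient_checkpointing_enable(), for the ratio to take effect):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

# Checkpoint roughly half of the layers held by each pipeline stage.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)
booster = Booster(plugin=plugin)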
@@ -114,12 +114,12 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
             held_layers.append(module.word_embeddings_layernorm)
             held_layers.append(self.model.lm_head)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
@@ -69,11 +69,11 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
         if stage_manager.is_first_stage():
             held_layers.append(module.embedding)
             held_layers.append(module.output_layer)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             if module.encoder.post_layer_norm:
@@ -194,11 +194,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
             held_layers.append(self.model.lm_head)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)
@@ -1,6 +1,7 @@
 import contextlib
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

+import numpy as np
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

@@ -29,6 +30,8 @@ class PipelineStageManager:
     ) -> None:
         assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

+        self.num_layers_per_stage = None
+
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -69,6 +72,88 @@ class PipelineStageManager:
         # for shardformer, hold model chunk id
         self.model_chunk_id: Optional[int] = None

+    @property
+    def control_distribute_layers(self) -> bool:
+        return self.num_layers_per_stage is not None
+
+    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
+        """Set the distribution configuration.
+        This allows user to customize the number of layers for each stage.
+
+        Args:
+            num_model_layers (int): Number of layers in the model.
+            num_layers_per_stage (List[int]): Number of layers for each stage.
+        """
+        assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
+        assert sum(num_layers_per_stage) == num_model_layers
+        assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
+        self.num_model_layers = num_model_layers
+        self.num_layers_per_stage = num_layers_per_stage
+
+    def distribute_layers(
+        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+    ) -> List[int]:
+        """Divide layers into stages"""
+        num_stages = self.num_stages if num_stages is None else num_stages
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+
+        if self.control_distribute_layers:
+            assert num_layers == self.num_model_layers
+            return self.num_layers_per_stage
+
+        else:
+            quotient = num_layers // (num_stages * num_model_chunks)
+            remainder = num_layers % (num_stages * num_model_chunks)
+
+            # calculate the num_layers per stage
+            layers_per_stage = [quotient] * num_stages * num_model_chunks
+
+            # deal with the rest layers
+            if remainder > 0:
+                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
+                for i in range(start_position, start_position + remainder):
+                    layers_per_stage[i] += 1
+            return layers_per_stage
+
+    def get_stage_index(
+        self,
+        layers_per_stage: List[int],
+        stage: Optional[int] = None,
+        num_model_chunks: Optional[int] = None,
+        num_stages: Optional[int] = None,
+    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
+        """
+        Get the start index and end index of layers for each stage.
+
+        Args:
+            layers_per_stage (List[int]): number of layers for each stage
+            stage (int): the stage index
+            num_stages (int): number of stages
+            num_model_chunks (int): number of model chunks
+
+        Returns:
+            - Tuple[int, int]: the start index and end index of this stage
+            - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
+
+        """
+        stage = self.stage if stage is None else stage
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+        num_stages = self.num_stages if num_stages is None else num_stages
+
+        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
+
+        stage_indices = []
+        for model_chunk in range(num_model_chunks):
+            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
+            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
+            stage_indices.append([start_idx, end_idx])
+
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+
     def is_first_stage(self, ignore_chunk: bool = False) -> bool:
         """Is the current stage the first stage.

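A worked example of the two new methods (numbers illustrative), assuming `stage_manager` is a 2-stage, non-interleaved PipelineStageManager built elsewhere:

layers_per_stage = stage_manager.distribute_layers(7)
# quotient = 3, remainder = 1, the extra layer is placed starting at index (2 // 2 - 1 // 2) = 1,
# so layers_per_stage == [3, 4]
stage_manager.get_stage_index(layers_per_stage, stage=0)  # -> [0, 3]
stage_manager.get_stage_index(layers_per_stage, stage=1)  # -> [3, 7]

# Heterogeneous split: pin the per-stage layer counts explicitly; subsequent
# distribute_layers(7) calls return them verbatim instead of the even split.
stage_manager.set_distribution_config(num_model_layers=7, num_layers_per_stage=[2, 5])
stage_manager.distribute_layers(7)  # -> [2, 5]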
@@ -1 +1 @@
-from .shard import ShardConfig, ShardFormer
+from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer
@@ -138,13 +138,25 @@ class LlamaPipelineForwards:
         next_decoder_cache = () if use_cache else None

         start_idx, end_idx = stage_index[0], stage_index[1]
+        num_ckpt_layers = 0
+        if self.gradient_checkpointing and self.training:
+            num_ckpt_layers = end_idx - start_idx
+            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+            if shard_config.gradient_checkpoint_config is not None:
+                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+                    stage=stage_manager.stage,
+                    num_layers=end_idx - start_idx,
+                    model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
+                )
+            assert num_ckpt_layers <= end_idx - start_idx
+
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if self.gradient_checkpointing and self.training:
+            if idx - start_idx < num_ckpt_layers:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
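For example (illustrative numbers): with `gradient_checkpointing_ratio=0.5` and a stage holding 9 decoder layers, `get_num_ckpt_layers` returns `int(0.5 * 9) = 4`, so only the first 4 layers of this stage's slice are recomputed in the backward pass and the remaining 5 run a plain forward. A schematic of the per-layer gate introduced above (not the full forward call):

for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
    if idx - start_idx < num_ckpt_layers:
        # gradient checkpointing: activations of this layer are recomputed during backward
        layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(decoder_layer), hidden_states)
    else:
        layer_outputs = decoder_layer(hidden_states)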
@@ -2,9 +2,8 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

-import numpy as np
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
@@ -196,49 +195,3 @@ class Policy(ABC):
             List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
         """
         return []
-
-    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
-        """Divide layers into stages"""
-        quotient = num_layers // num_stages
-        remainder = num_layers % num_stages
-
-        # calculate the num_layers per stage
-        layers_per_stage = [quotient] * num_stages
-
-        # deal with the rest layers
-        if remainder > 0:
-            start_position = num_stages // 2 - remainder // 2
-            for i in range(start_position, start_position + remainder):
-                layers_per_stage[i] += 1
-        return layers_per_stage
-
-    def get_stage_index(
-        self,
-        layers_per_stage: List[int],
-        stage: int,
-        num_model_chunks: int = 1,
-        num_stages: int = 0,
-    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
-        """
-        Get the start index and end index of layers for each stage.
-
-        Args:
-            layers_per_stage (List[int]): number of layers for each stage
-            stage (int): the stage index
-            num_stages (int): number of stages
-            num_model_chunks (int): number of model chunks
-
-        Returns:
-            - Tuple[int, int]: the start index and end index of this stage
-            - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
-
-        """
-        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
-
-        stage_indices = []
-        for model_chunk in range(num_model_chunks):
-            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
-            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
-            stage_indices.append([start_idx, end_idx])
-
-        return stage_indices[0] if num_model_chunks == 1 else stage_indices
@@ -279,16 +279,8 @@ class BertPolicy(Policy):
             module = self.model.bert

         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer),
-                stage_manager.num_stages * stage_manager.num_model_chunks,
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -298,8 +290,8 @@ class BertPolicy(Policy):
             }

         else:
-            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -324,16 +316,8 @@ class BertPolicy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer),
-                stage_manager.num_stages * stage_manager.num_model_chunks,
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embeddings)
             for start_idx, end_idx in stage_indices:
@@ -342,10 +326,10 @@ class BertPolicy(Policy):
             held_layers.append(module.pooler)

         else:
-            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
             if stage_manager.is_first_stage():
                 held_layers.append(module.embeddings)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.encoder.layer[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.pooler)
@@ -203,8 +203,8 @@ class BloomPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -226,11 +226,11 @@ class BloomPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
             held_layers.append(module.word_embeddings_layernorm)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
@@ -179,10 +179,10 @@ class ChatGLMPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
         if stage_manager.is_first_stage():
             held_layers.append(module.embedding)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             if module.encoder.post_layer_norm:
@@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -161,8 +161,8 @@ class FalconPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -181,10 +181,10 @@ class FalconPolicy(Policy):
         module = self.model.transformer
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
@@ -185,15 +185,8 @@ class GPT2Policy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.wte)
                 held_layers.append(module.wpe)
@@ -203,12 +196,12 @@ class GPT2Policy(Policy):
             if stage_manager.is_last_stage(ignore_chunk=True):
                 held_layers.append(module.ln_f)
         else:
-            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
             if stage_manager.is_first_stage():
                 held_layers.append(module.wte)
                 held_layers.append(module.wpe)
                 held_layers.append(module.drop)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.h[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.ln_f)
@@ -226,15 +219,8 @@ class GPT2Policy(Policy):
         module = self.model.transformer

         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -243,8 +229,8 @@ class GPT2Policy(Policy):
                 )
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -179,11 +179,11 @@ class GPTJPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.wte)
             held_layers.append(module.drop)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
@@ -200,8 +200,8 @@ class GPTJPolicy(Policy):
         else:
             module = self.model.transformer

-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward,
@@ -164,30 +164,20 @@ class LlamaPolicy(Policy):
             module = self.model.model

         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
             }

         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                 )
             }
-            self.append_or_create_method_replacement(
-                description=method_replacement, policy=policy, target_key=model_cls
-            )

         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

@@ -204,15 +194,8 @@ class LlamaPolicy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
@@ -221,10 +204,10 @@ class LlamaPolicy(Policy):
             held_layers.append(module.norm)

         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
             if stage_manager.is_first_stage():
                 held_layers.append(module.embed_tokens)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.layers[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.norm)
@@ -186,12 +186,12 @@ class OPTPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
             held_layers.append(module.embed_positions)
             held_layers.append(module.project_in)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.final_layer_norm)
@@ -208,8 +208,8 @@ class OPTPolicy(Policy):
         else:
             module = self.model.model.decoder

-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward,
@@ -251,6 +251,8 @@ class T5BasePolicy(Policy):
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."

         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -262,7 +264,7 @@ class T5BasePolicy(Policy):

         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages

         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -273,21 +275,26 @@ class T5BasePolicy(Policy):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages

     def get_t5_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."
+
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return stage_manager.get_stage_index(
+                layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
+            )

     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""
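A worked example of the encoder/decoder split above (illustrative numbers): for a T5 model with 12 encoder layers and 12 decoder layers on 4 pipeline stages, objective(i) = |12 / i - 12 / (4 - i)| evaluates to 8, 0, 8 for i = 1, 2, 3, so num_encoder_stages = 2 and num_decoder_stages = 2; distribute_layers(12, 2) then gives [6, 6] for each half, for a combined layers_per_stage of [6, 6, 6, 6] with the decoder starting at stage 2.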
@@ -134,10 +134,10 @@ class ViTPolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
         if stage_manager.is_first_stage():
             held_layers.append(module.embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layer[start_idx:end_idx])
         return held_layers

@@ -149,8 +149,8 @@ class ViTPolicy(Policy):
         else:
             module = self.model.vit

-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
@@ -300,6 +300,8 @@ class WhisperPolicy(Policy):
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"

         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -311,7 +313,7 @@ class WhisperPolicy(Policy):

         # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages

         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -322,21 +324,24 @@ class WhisperPolicy(Policy):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages

     def get_whisper_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"
+
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(
+            return stage_manager.get_stage_index(
                 layers_per_stage[decoder_starting_stage:],
                 stage - decoder_starting_stage,
             )
@@ -1,5 +1,6 @@
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 from .shard_config import ShardConfig
 from .sharder import ModelSharder
 from .shardformer import ShardFormer

-__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"]
+__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]
colossalai/shardformer/shard/grad_ckpt_config.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class GradientCheckpointConfig:
+    gradient_checkpointing_ratio: float = 0.0
+
+    def get_num_ckpt_layers(self, num_layers: int) -> int:
+        return int(self.gradient_checkpointing_ratio * num_layers)
+
+
+@dataclass
+class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
+    r"""
+    The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism.
+    Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism.
+    Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details.
+
+    It provides the following features:
+        1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing.
+        2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
+
+    """
+    """
+    Args:
+        gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
+        num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check.
+        num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check.
+        num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check.
+        num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None.
+
+    Example 1:
+        num_stages = 8
+        num_layers = 80
+        num_model_chunks = 1
+        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
+        num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0]
+
+    Example 2:
+        num_stages = 4
+        num_layers = 80
+        num_model_chunks = 2
+        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
+        # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers
+        ...
+
+    """
+    num_stages: Optional[int] = None
+    num_model_chunks: Optional[int] = None
+    num_model_layers: Optional[int] = None
+    num_ckpt_layers_per_stage: Optional[List[int]] = None
+
+    def __post_init__(self):
+        if self._enable_gradient_checkpointing_ratio:
+            if not (0 <= self.gradient_checkpointing_ratio <= 1):
+                raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")
+
+        if self._enable_customized_ckpt_layers_per_stage:
+            assert (
+                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
+            )
+            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
+            assert all(
+                [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
+            )
+            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
+
+    @property
+    def _enable_gradient_checkpointing_ratio(self) -> bool:
+        return self.gradient_checkpointing_ratio is not None
+
+    @property
+    def _enable_customized_ckpt_layers_per_stage(self) -> bool:
+        return self.num_ckpt_layers_per_stage is not None
+
+    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
+        if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
+            raise RuntimeError("No checkpointed layers information is provided")
+
+        if self._enable_customized_ckpt_layers_per_stage:
+            assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
+            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
+            assert num_ckpt_layers <= num_layers
+            return num_ckpt_layers
+        else:
+            return int(self.gradient_checkpointing_ratio * num_layers)
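A hedged usage sketch of the two configs above (values illustrative):

from colossalai.shardformer import GradientCheckpointConfig, PipelineGradientCheckpointConfig

# Ratio-only control: checkpoint ~25% of whatever layer count the caller passes in.
cfg = GradientCheckpointConfig(gradient_checkpointing_ratio=0.25)
cfg.get_num_ckpt_layers(num_layers=10)  # -> 2

# Per-stage control (takes precedence over the ratio): 8 stages, 1 chunk, 80 layers,
# matching Example 1 in the docstring above.
pp_cfg = PipelineGradientCheckpointConfig(
    num_stages=8,
    num_model_chunks=1,
    num_model_layers=80,
    num_ckpt_layers_per_stage=[4, 4, 2, 2, 0, 0, 0, 0],
)
pp_cfg.get_num_ckpt_layers(stage=1, num_layers=9)  # -> 4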
@@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup

from colossalai.pipeline.stage_manager import PipelineStageManager

+from .grad_ckpt_config import GradientCheckpointConfig
+
__all__ = ["ShardConfig"]


@@ -23,6 +25,7 @@ class ShardConfig:
        enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
    """
    tensor_parallel_process_group: Optional[ProcessGroup] = None
@@ -35,6 +38,7 @@ class ShardConfig:
    enable_sequence_parallelism: bool = False
    enable_sequence_overlap: bool = False
    parallel_output: bool = True
+    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
    # TODO padding vocab
    # make_vocab_size_divisible_by: int = 128
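As a hedged sketch of how the new field is consumed, a `ShardConfig` can now carry the checkpoint config next to the other shard switches; the surrounding keyword arguments below are illustrative rather than quoted from this commit:

from colossalai.shardformer import PipelineGradientCheckpointConfig, ShardConfig

# Ask every pipeline stage to recompute about half of its transformer layers.
shard_config = ShardConfig(
    enable_tensor_parallelism=False,  # illustrative; any valid ShardConfig arguments apply
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)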
@@ -1,4 +1,3 @@
-import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]


class OpenMoePolicy(Policy):

    def config_sanity_check(self):
        pass

@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError(
-                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+            )

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
        else:
            module = self.model.model

-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
        method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=model_cls
@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
        stage_manager = self.pipeline_stage_manager

        held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)
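Condensed, the call-site change in the two hunks above: layer distribution and stage-index lookup now live on the stage manager instead of the policy. A before/after sketch, assuming `stage_manager` is the policy's `pipeline_stage_manager`:

# Before: policy-level helpers that needed explicit stage information.
#   layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
#   stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)

# After: the stage manager already knows its stage count and rank.
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)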
@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):


class OpenMoeModelPolicy(OpenMoePolicy):

    def __init__(self) -> None:
        super().__init__()

@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):


class OpenMoeForCausalLMPolicy(OpenMoePolicy):

    def module_policy(self):
        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for casual lm
            new_item = {
-                OpenMoeForCausalLM:
+                OpenMoeForCausalLM: ModulePolicyDescription(
-                    ModulePolicyDescription(sub_module_replacement=[
+                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
-                    ])
+                    ]
+                )
            }
            policy.update(new_item)

@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        llama_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
-            if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+            if (
-                    and self.pipeline_stage_manager.num_stages > 1):
+                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
                # tie weights
-                return [{
+                return [
+                    {
                        0: llama_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                }]
+                    }
+                ]
        return []
@@ -247,12 +249,13 @@ class OpenMoePipelineForwards:

        logger = logging.get_logger(__name__)

-        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (output_hidden_states
+        output_hidden_states = (
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                use_cache = False

        # decoder layers
@@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

-            past_key_value = (past_key_values[idx] if past_key_values is not None else None)
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)
@@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
            router_z_loss = past_router_z_loss + router_z_loss

        if stage_manager.is_last_stage():
-            return tuple([
+            return tuple(
+                [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    router_aux_loss,
                    router_z_loss,
-            ])
+                ]
+            )
        # always return dict for imediate stage
        return {
            "hidden_states": hidden_states,
@@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)
-        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (output_hidden_states
+        output_hidden_states = (
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
@@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
        if chunk_head == True:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    logits = module(inputs[0])
                    logits = logits.float()
@@ -49,9 +49,9 @@ if HAS_LLAMA:
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = LlamaConfig(
-        num_hidden_layers=4,
+        num_hidden_layers=8,
-        hidden_size=128,
+        hidden_size=32,
-        intermediate_size=256,
+        intermediate_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
        num_labels=16,
@@ -1,4 +1,23 @@
+import random
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy
+from colossalai.shardformer.shard.shard_config import ShardConfig
+
+
+class _ShardConfig(ShardConfig):
+    def __post_init__(self):
+        pass
+
+
+class _PipelineStageManager(PipelineStageManager):
+    def __init__(self):
+        self.is_interleave = False
+        self.num_layers_per_stage = None
+
+    @property
+    def num_stages(self):
+        return random.randint(5, 10)


def test_t5_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
        "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
    }

+    stage_manager = _PipelineStageManager()
+    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = T5BasePolicy()
+    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        _, decoder_starting_stage = policy.distribute_t5_layers(
            test_dict["num_encoder_layers"][i],
@@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
    }

    for i in range(num_test_cases):
+        stage_manager = _PipelineStageManager()
+        shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
        policy = T5BasePolicy()
+        policy.set_shard_config(shard_config)
        layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
            test_dict["num_encoder_layers"][i],
            test_dict["num_decoder_layers"][i],
@@ -1,4 +1,23 @@
+import random
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.whisper import WhisperPolicy
+from colossalai.shardformer.shard.shard_config import ShardConfig
+
+
+class _ShardConfig(ShardConfig):
+    def __post_init__(self):
+        pass
+
+
+class _PipelineStageManager(PipelineStageManager):
+    def __init__(self):
+        self.is_interleave = False
+        self.num_layers_per_stage = None
+
+    @property
+    def num_stages(self):
+        return random.randint(5, 10)


def test_whisper_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
        "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
    }

+    stage_manager = _PipelineStageManager()
+    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = WhisperPolicy()
+    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        _, decoder_starting_stage = policy.distribute_whisper_layers(
            test_dict["num_encoder_layers"][i],
@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
        ],
    }

+    stage_manager = _PipelineStageManager()
+    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = WhisperPolicy()
+    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
            test_dict["num_encoder_layers"][i],
@@ -5,6 +5,7 @@ import torch

import colossalai
from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )
+    if enable_gradient_checkpointing:
+        org_model.gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
@@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
        },
        {
            "tp_size": 1,
@@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "num_microbatches": 4,
            "use_lazy_init": False,
            "precision": "fp32",
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
+            ),
        },
        {
            "tp_size": 4,
@@ -189,6 +200,13 @@ def run_llama_test(test_config):
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_stages=2,
+                num_model_chunks=2,
+                num_model_layers=8,
+                num_ckpt_layers_per_stage=[0, 1, 2, 2],
+            ),
        },
    ],
)
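To make the interleaved entry above concrete, a small worked example mirroring the last test config; the two-layers-per-chunk figure assumes the 8 decoder layers are split evenly across 2 stages and 2 model chunks:

from colossalai.shardformer import PipelineGradientCheckpointConfig

cfg = PipelineGradientCheckpointConfig(
    num_stages=2, num_model_chunks=2, num_model_layers=8, num_ckpt_layers_per_stage=[0, 1, 2, 2]
)
# The list is indexed by stage + model_chunk_id * num_stages.
assert cfg.get_num_ckpt_layers(stage=0, num_layers=2, model_chunk_id=0) == 0
assert cfg.get_num_ckpt_layers(stage=1, num_layers=2, model_chunk_id=0) == 1
assert cfg.get_num_ckpt_layers(stage=0, num_layers=2, model_chunk_id=1) == 2
assert cfg.get_num_ckpt_layers(stage=1, num_layers=2, model_chunk_id=1) == 2
assert cfg.gradient_checkpointing_ratio == 5 / 8  # derived in __post_init__ from the list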