Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[shardformer, pipeline] add gradient_checkpointing_ratio and heterogeneous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
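The headline feature is the `gradient_checkpointing_ratio` knob carried by the new config objects. As a rough illustration, here is a minimal sketch of how such a ratio could be turned into a per-stage count of recomputed layers; the class and attribute names follow the commit message, but the helper method and its arithmetic are illustrative assumptions, not ColossalAI's exact implementation.

from dataclasses import dataclass


@dataclass
class PipelineGradientCheckpointConfig:
    # Fraction of each stage's transformer layers to recompute in backward.
    gradient_checkpointing_ratio: float = 0.0

    def get_num_ckpt_layers(self, num_layers: int) -> int:  # hypothetical helper
        # Round to whole layers and clamp to the stage's layer count.
        n = int(round(self.gradient_checkpointing_ratio * num_layers))
        return max(0, min(num_layers, n))


cfg = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)
assert cfg.get_num_ckpt_layers(8) == 4  # checkpoint half of an 8-layer stage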
@@ -1,4 +1,3 @@
 import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Union
-
@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]


 class OpenMoePolicy(Policy):
-
     def config_sanity_check(self):
         pass

@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             raise NotImplementedError(
-                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+                "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
+            )

         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
         else:
             module = self.model.model

-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
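The hunk above moves layer distribution from the policy onto `PipelineStageManager`, which is what the commit message means by moving `distribute_layer` and `get_stage_index`. A minimal sketch of the semantics these helpers conventionally have, assuming an even split with the remainder handed out to the earliest stages; the function bodies are illustrative, not the library's verbatim code (and the real methods live on the stage manager, which supplies its own stage index).

from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Give every stage the floor share, then spread the remainder one
    # layer at a time so stage sizes differ by at most one.
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if stage < rem else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Convert per-stage counts into this stage's [start, end) layer slice.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


assert distribute_layers(10, 4) == [3, 3, 2, 2]
assert get_stage_index([3, 3, 2, 2], 1) == (3, 6)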
@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)
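Put together, `get_held_layers` amounts to slicing the decoder stack per stage. A hypothetical two-stage walk-through, reusing the helpers sketched above:

# Hypothetical 2-stage split of a 4-layer model, reusing the helpers above.
layers_per_stage = distribute_layers(num_layers=4, num_stages=2)  # [2, 2]
start_idx, end_idx = get_stage_index(layers_per_stage, stage=0)   # (0, 2)
# Stage 0 holds embed_tokens plus layers[0:2]; the last stage also holds
# the final norm (and, in the causal-LM policy, lm_head).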
@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):


 class OpenMoeModelPolicy(OpenMoePolicy):
-
     def __init__(self) -> None:
         super().__init__()

@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):


 class OpenMoeForCausalLMPolicy(OpenMoePolicy):
-
     def module_policy(self):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                OpenMoeForCausalLM:
-                    ModulePolicyDescription(sub_module_replacement=[
+                OpenMoeForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=Linear1D_Col,
                             kwargs=dict(gather_output=True),
                         )
-                    ])
+                    ]
+                )
             }
             policy.update(new_item)

@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
-            if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
-                    and self.pipeline_stage_manager.num_stages > 1):
+            if (
+                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                and self.pipeline_stage_manager.num_stages > 1
+            ):
                 # tie weights
-                return [{
-                    0: llama_model.embed_tokens.weight,
-                    self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                }]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
         return []

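The reflowed condition is pure formatting, but the logic deserves a gloss: when input and output embeddings are tied and the model is split across pipeline stages, the first and last stage each hold a copy of the same parameter, and the runtime must know to synchronize their gradients. A self-contained sketch of the id()-based tie detection; the modules below are stand-ins, not ColossalAI code.

import torch.nn as nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight  # tie: both modules share one Parameter

# The policy detects tying by object identity, exactly as in the hunk above.
assert id(embed.weight) == id(lm_head.weight)

# The shared parameter is reported as {first_stage: ..., last_stage: ...}
# so the two stages can all-reduce gradients for the tied weight.
num_stages = 4
shared_params = [{0: embed.weight, num_stages - 1: lm_head.weight}]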
@@ -247,12 +249,13 @@ class OpenMoePipelineForwards:

         logger = logging.get_logger(__name__)

-        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache

-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False

         # decoder layers
@@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            past_key_value = (past_key_values[idx] if past_key_values is not None else None)
+            past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
-
                     def custom_forward(*inputs):
                         # None for past_key_value
                         return module(*inputs, output_attentions, None)
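The dropped blank line is cosmetic, but the closure it sits in is the standard recipe for `torch.utils.checkpoint` when a module's forward takes extra non-tensor arguments: capture those in the closure so the checkpointed callable receives only tensors. A minimal runnable sketch with a toy layer; only the pattern, not the model, matches the code above.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)


def create_custom_forward(module):
    # checkpoint() stores no activations for this region; it re-runs
    # custom_forward during backward to recompute them instead.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(create_custom_forward(layer), x, use_reentrant=False)
y.sum().backward()  # forward recomputed here, trading compute for memory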
@@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
             router_z_loss = past_router_z_loss + router_z_loss

         if stage_manager.is_last_stage():
-            return tuple([
-                hidden_states,
-                next_cache,
-                all_hidden_states,
-                all_self_attns,
-                router_aux_loss,
-                router_z_loss,
-            ])
+            return tuple(
+                [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    router_aux_loss,
+                    router_z_loss,
+                ]
+            )
         # always return dict for imediate stage
         return {
             "hidden_states": hidden_states,
@@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""
         logger = logging.get_logger(__name__)
-        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
-        output_hidden_states = (output_hidden_states
-                                if output_hidden_states is not None else self.config.output_hidden_states)
-        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
         if output_attentions:
@@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
         if chunk_head == True:

             def create_custom_forward(module):
-
                 def custom_forward(*inputs):
                     logits = module(inputs[0])
                     logits = logits.float()
@@ -522,8 +527,8 @@ class OpenMoePipelineForwards:
             for batch_idx in range(hidden_states.shape[0]):
                 loss = loss + torch.utils.checkpoint.checkpoint(
                     create_custom_forward(self.lm_head),
-                    hidden_states[batch_idx:batch_idx + 1, :],
-                    labels[batch_idx:batch_idx + 1, :],
+                    hidden_states[batch_idx : batch_idx + 1, :],
+                    labels[batch_idx : batch_idx + 1, :],
                 )
             logits = None
         else:
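The slice-spacing change (`batch_idx : batch_idx + 1`) is again black's style, but the surrounding `chunk_head` technique is the interesting part: float32 logits over the whole batch are the LM head's memory peak, so the loop pushes one sample at a time through a checkpointed head-plus-loss region and never materializes the full logits tensor. A runnable sketch of the same idea with toy sizes; the loss normalization is simplified relative to the real code.

import torch
from torch.utils.checkpoint import checkpoint

vocab_size, hidden_size = 1000, 64
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)


def head_and_loss(h, y):
    # Per-sample logits only ever exist inside this recomputed region.
    logits = lm_head(h).float()
    return torch.nn.functional.cross_entropy(logits.view(-1, vocab_size), y.view(-1))


hidden_states = torch.randn(4, 16, hidden_size, requires_grad=True)
labels = torch.randint(0, vocab_size, (4, 16))

loss = torch.zeros(())
for batch_idx in range(hidden_states.shape[0]):
    loss = loss + checkpoint(
        head_and_loss,
        hidden_states[batch_idx : batch_idx + 1],
        labels[batch_idx : batch_idx + 1],
        use_reentrant=False,
    )
loss = loss / hidden_states.shape[0]
loss.backward()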