[shardformer, pipeline] add gradient_checkpointing_ratio and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layers` and `get_stage_index` to `PipelineStageManager`

* fix: add optional args for `distribute_layers` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
Author: Wenhao Chen
Date: 2024-04-01 11:34:58 +08:00
Committed by: GitHub
Parent: df5e9c53cf
Commit: e614aa34f3
28 changed files with 396 additions and 213 deletions
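The bullets above introduce `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` and apply them to the llama policy and forward. Below is a minimal usage sketch; only the class name and the `gradient_checkpointing_ratio` idea come from this commit, while the import paths and the `gradient_checkpoint_config` plugin argument are assumptions about how the config is wired into `HybridParallelPlugin`.

```python
# Hypothetical usage sketch -- the import paths and the `gradient_checkpoint_config`
# keyword are assumptions; only `PipelineGradientCheckpointConfig` and
# `gradient_checkpointing_ratio` come from this commit.
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.shard import PipelineGradientCheckpointConfig

# Checkpoint roughly half of the transformer layers on each pipeline stage
# instead of all of them, trading some activation memory for less recompute.
gc_config = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    gradient_checkpoint_config=gc_config,  # assumed keyword, see note above
)
```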


@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union
@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
class OpenMoePolicy(Policy):
def config_sanity_check(self):
pass
@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
else:
module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
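The two hunks above show the API move described in the commit message: layer distribution and stage indexing now live on `PipelineStageManager` instead of on each policy. Consolidated into one readable snippet (the surrounding policy code is abbreviated, but the two `stage_manager` calls and their arguments are exactly as they appear in the diff):

```python
def get_held_layers(self):
    """Sketch of the new call pattern, abbreviated from the hunks above."""
    module = self.model.model
    stage_manager = self.pipeline_stage_manager

    # Old: self.distribute_layers(len(module.layers), stage_manager.num_stages)
    layers_per_stage = stage_manager.distribute_layers(len(module.layers))

    held_layers = []
    if stage_manager.is_first_stage():
        held_layers.append(module.embed_tokens)

    # Old: self.get_stage_index(layers_per_stage, stage_manager.stage)
    start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
    held_layers.extend(module.layers[start_idx:end_idx])

    if stage_manager.is_last_stage():
        held_layers.append(module.norm)
    return held_layers
```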
@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):
class OpenMoeModelPolicy(OpenMoePolicy):
def __init__(self) -> None:
super().__init__()
@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
new_item = {
OpenMoeForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
OpenMoeForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
])
]
)
}
policy.update(new_item)
@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1):
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}]
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
@@ -247,12 +249,13 @@ class OpenMoePipelineForwards:
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
@@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (past_key_values[idx] if past_key_values is not None else None)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
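The hunk above keeps the usual per-layer `torch.utils.checkpoint` pattern; the `gradient_checkpointing_ratio` from the commit title controls how many layers receive this treatment via the patched llama forward and `PipelineGradientCheckpointConfig`, which are not part of this file. A purely illustrative sketch of the idea, where every name except `torch.utils.checkpoint.checkpoint` is hypothetical:

```python
import torch.utils.checkpoint

def run_decoder_layers(layers, hidden_states, gradient_checkpointing_ratio=1.0, training=True):
    # Checkpoint only the first `ratio` fraction of layers; the rest run normally.
    # How the real config selects layers is defined in PipelineGradientCheckpointConfig
    # and is not reproduced here.
    num_ckpt_layers = int(len(layers) * gradient_checkpointing_ratio)
    for idx, layer in enumerate(layers):
        if training and idx < num_ckpt_layers:
            # Activations of this layer are recomputed during the backward pass.
            hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
        else:
            hidden_states = layer(hidden_states)
    return hidden_states
```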
@@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
router_z_loss = past_router_z_loss + router_z_loss
if stage_manager.is_last_stage():
return tuple([
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
])
return tuple(
[
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
]
)
# always return dict for imediate stage
return {
"hidden_states": hidden_states,
@@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
@@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
if chunk_head == True:
def create_custom_forward(module):
def custom_forward(*inputs):
logits = module(inputs[0])
logits = logits.float()
@@ -522,8 +527,8 @@ class OpenMoePipelineForwards:
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx:batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :],
hidden_states[batch_idx : batch_idx + 1, :],
labels[batch_idx : batch_idx + 1, :],
)
logits = None
else: