[pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision
This commit is contained in:
Jianghai
2023-07-11 15:23:33 +08:00
committed by Hongxin Liu
parent 1622031058
commit 31bcf867ae
4 changed files with 109 additions and 21 deletions

View File

@@ -162,6 +162,24 @@ class Policy(ABC):
return policy
def append_or_create_method_replacement(
self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
Append or create a new method replacement description to the policy for the given key.
Args:
description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
target_key (Union[str, nn.Module]): the key of the policy to be updated
"""
if target_key in policy:
policy[target_key].method_replacement.update(description)
else:
policy[target_key] = ModulePolicyDescription(method_replacement=description)
return policy
def get_held_layers(self) -> List[Module]:
"""Get layers that should be held in current stage. This method should be implemented by subclass.

View File

@@ -131,17 +131,20 @@ class LlamaModelPolicy(LlamaPolicy):
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel
if self.pipeline_stage_manager:
# set None as default
stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={
method_replacement = {
'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index)
})
return module_policy
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaModel)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
@@ -158,7 +161,7 @@ class LlamaModelPolicy(LlamaPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in bert model"""
"""No shared params in llama model"""
return []
@@ -179,8 +182,43 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
])
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaForCausalLM)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
module = self.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.model.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.model.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.model.norm)
held_layers.append(module.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
llama_model = self.model.model
if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
# tie weights
return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
return []
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
@@ -199,8 +237,42 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
])
}
policy.update(new_item)
# to be confirmed
if self.pipeline_stage_manager:
# set None as default
stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(llama_for_sequence_classification_forward,
stage_manager=stage_manager,
stage_index=stage_index)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaForSequenceClassification)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
module = self.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.model.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.model.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.model.norm)
held_layers.append(module.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama for sequence classification model"""
return []
def llama_model_forward(
self: LlamaModel,