[pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision
This commit is contained in:
Jianghai
2023-07-11 15:23:33 +08:00
committed by Hongxin Liu
parent 1622031058
commit 31bcf867ae
4 changed files with 109 additions and 21 deletions

View File

@@ -162,6 +162,24 @@ class Policy(ABC):
return policy
def append_or_create_method_replacement(
self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
Append or create a new method replacement description to the policy for the given key.
Args:
description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
target_key (Union[str, nn.Module]): the key of the policy to be updated
"""
if target_key in policy:
policy[target_key].method_replacement.update(description)
else:
policy[target_key] = ModulePolicyDescription(method_replacement=description)
return policy
def get_held_layers(self) -> List[Module]:
"""Get layers that should be held in current stage. This method should be implemented by subclass.