mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 12:14:02 +00:00
[pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2
.
This policy should be revert and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
This commit is contained in:
@@ -162,6 +162,24 @@ class Policy(ABC):
|
||||
|
||||
return policy
|
||||
|
||||
def append_or_create_method_replacement(
|
||||
self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
|
||||
target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
r"""
|
||||
Append or create a new method replacement description to the policy for the given key.
|
||||
|
||||
Args:
|
||||
description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended
|
||||
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
|
||||
target_key (Union[str, nn.Module]): the key of the policy to be updated
|
||||
"""
|
||||
if target_key in policy:
|
||||
policy[target_key].method_replacement.update(description)
|
||||
else:
|
||||
policy[target_key] = ModulePolicyDescription(method_replacement=description)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get layers that should be held in current stage. This method should be implemented by subclass.
|
||||
|
||||
|
Reference in New Issue
Block a user