[shardformer] import huggingface implicitly (#4101)

This commit is contained in:
Frank Lee
2023-06-30 10:56:29 +08:00
parent 6a88bae4ec
commit 44a190e6ac
9 changed files with 91 additions and 38 deletions
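The pattern applied across the touched policy files is the same: top-level `transformers` imports are removed and re-introduced as local imports inside each policy's `module_policy` method, so huggingface transformers is only imported when a policy is actually built. A minimal, self-contained sketch of that lazy-import pattern (the `ShardPolicy` base class and `LlamaLikePolicy` are illustrative stand-ins, not ColossalAI's actual classes; only the `transformers` import path comes from the diff below):

from typing import Dict, Type

import torch.nn as nn


class ShardPolicy:
    """Illustrative base class: subclasses map module classes to replacement specs."""

    def module_policy(self) -> Dict[Type[nn.Module], dict]:
        return {}


class LlamaLikePolicy(ShardPolicy):

    def module_policy(self) -> Dict[Type[nn.Module], dict]:
        # Import transformers lazily, only when the policy is materialized;
        # importing this module alone no longer pulls in `transformers`.
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer

        return {
            LlamaDecoderLayer: {"attribute_replacement": {}, "sub_module_replacement": []},
        }

With the import deferred like this, importing the policy module succeeds even when `transformers` is not installed; the dependency is only needed once `module_policy()` is called.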


@@ -1,13 +1,13 @@
from typing import Dict, Union
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
class LlamaPolicy(Policy):
@@ -26,7 +26,9 @@ class LlamaPolicy(Policy):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        base_policy = {
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

        return {
            LlamaDecoderLayer:
                ModulePolicyDescription(
                    attribute_replacement={
@@ -109,6 +111,8 @@ class LlamaPolicy(Policy):
class LlamaForCausalLMPolicy(LlamaPolicy):

    def module_policy(self):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()
        # add a new item for causal lm
        new_item = {
@@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    def module_policy(self):
        from transformers import LlamaForSequenceClassification

        policy = super().module_policy()
        # add a new item for sequence classification