Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 04:33:04 +00:00)

Commit: [shardformer] import huggingface implicitly (#4101)
@@ -1,13 +1,13 @@
 from typing import Dict, Union

 import torch.nn as nn
-from transformers import LlamaForCausalLM, LlamaForSequenceClassification
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']


 class LlamaPolicy(Policy):

@@ -26,7 +26,9 @@ class LlamaPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        base_policy = {
+        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+
+        return {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={

@@ -109,6 +111,8 @@ class LlamaPolicy(Policy):
 class LlamaForCausalLMPolicy(LlamaPolicy):

     def module_policy(self):
+        from transformers import LlamaForCausalLM
+
         policy = super().module_policy()
         # add a new item for casual lm
         new_item = {

@@ -128,6 +132,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):

     def module_policy(self):
+        from transformers import LlamaForSequenceClassification
+
         policy = super().module_policy()

         # add a new item for sequence classification
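A note on the pattern this commit applies: the transformers imports move from module scope into the module_policy() bodies, which is the standard deferred (lazy) import idiom in Python. Importing the llama policy module then no longer requires transformers to be importable; the dependency is only touched when a policy is actually used. A minimal sketch of the idiom follows; the class name and the empty dict value are illustrative placeholders, not part of ColossalAI:

    class LazyPolicy:
        """Illustrative only: shows the deferred-import shape used in this diff."""

        def module_policy(self):
            # The heavy third-party import runs when the method is called,
            # not when this module is imported. A missing transformers
            # installation therefore surfaces here, at first use, instead
            # of breaking import of the whole policies package.
            from transformers.models.llama.modeling_llama import LlamaDecoderLayer

            # Real code maps the layer class to a ModulePolicyDescription;
            # an empty dict stands in for it here.
            return {LlamaDecoderLayer: {}}

The trade-off: an ImportError from a missing or incompatible transformers install now appears at the first call to module_policy() rather than at import time, so failures move from startup to first use.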