mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 20:23:26 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -18,6 +18,7 @@ class PolicyLocation:
|
||||
file_name (str): The file name of the policy under colossalai.shardformer.policies
|
||||
class_name (str): The class name of the policy class
|
||||
"""
|
||||
|
||||
file_name: str
|
||||
class_name: str
|
||||
|
||||
@@ -27,121 +28,142 @@ class PolicyLocation:
|
||||
# we will allow the user to only import the policy file needed
|
||||
_POLICY_LIST = {
|
||||
# BERT
|
||||
"transformers.models.bert.modeling_bert.BertModel":
|
||||
PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForPreTraining":
|
||||
PolicyLocation(file_name="bert", class_name="BertForPreTrainingPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertLMHeadModel":
|
||||
PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForMaskedLM":
|
||||
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForSequenceClassification":
|
||||
PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForTokenClassification":
|
||||
PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
|
||||
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
|
||||
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForQuestionAnswering":
|
||||
PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
|
||||
|
||||
"transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
|
||||
"transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation(
|
||||
file_name="bert", class_name="BertForPreTrainingPolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(
|
||||
file_name="bert", class_name="BertLMHeadModelPolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation(
|
||||
file_name="bert", class_name="BertForMaskedLMPolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation(
|
||||
file_name="bert", class_name="BertForSequenceClassificationPolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation(
|
||||
file_name="bert", class_name="BertForTokenClassificationPolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation(
|
||||
file_name="bert", class_name="BertForNextSentencePredictionPolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(
|
||||
file_name="bert", class_name="BertForMultipleChoicePolicy"
|
||||
),
|
||||
"transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation(
|
||||
file_name="bert", class_name="BertForQuestionAnsweringPolicy"
|
||||
),
|
||||
# LLaMA
|
||||
"transformers.models.llama.modeling_llama.LlamaModel":
|
||||
PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"),
|
||||
"transformers.models.llama.modeling_llama.LlamaForCausalLM":
|
||||
PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
|
||||
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
|
||||
PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),
|
||||
|
||||
"transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaModelPolicy"
|
||||
),
|
||||
"transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaForCausalLMPolicy"
|
||||
),
|
||||
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaForSequenceClassificationPolicy"
|
||||
),
|
||||
# T5
|
||||
"transformers.models.t5.modeling_t5.T5Model":
|
||||
PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
|
||||
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration":
|
||||
PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"),
|
||||
"transformers.models.t5.modeling_t5.T5EncoderModel":
|
||||
PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
|
||||
|
||||
"transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
|
||||
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation(
|
||||
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
|
||||
),
|
||||
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
|
||||
# GPT2
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2Model":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
|
||||
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(
|
||||
file_name="gpt2", class_name="GPT2LMHeadModelPolicy"
|
||||
),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation(
|
||||
file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"
|
||||
),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation(
|
||||
file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"
|
||||
),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation(
|
||||
file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"
|
||||
),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(
|
||||
file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"
|
||||
),
|
||||
# ViT
|
||||
"transformers.models.vit.modeling_vit.ViTModel":
|
||||
PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
|
||||
"transformers.models.vit.modeling_vit.ViTForImageClassification":
|
||||
PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
|
||||
"transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
|
||||
PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),
|
||||
|
||||
"transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
|
||||
"transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(
|
||||
file_name="vit", class_name="ViTForImageClassificationPolicy"
|
||||
),
|
||||
"transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation(
|
||||
file_name="vit", class_name="ViTForMaskedImageModelingPolicy"
|
||||
),
|
||||
# OPT
|
||||
"transformers.models.opt.modeling_opt.OPTModel":
|
||||
PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForCausalLM":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForSequenceClassification":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
|
||||
|
||||
"transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation(
|
||||
file_name="opt", class_name="OPTForCausalLMPolicy"
|
||||
),
|
||||
"transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation(
|
||||
file_name="opt", class_name="OPTForSequenceClassificationPolicy"
|
||||
),
|
||||
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(
|
||||
file_name="opt", class_name="OPTForQuestionAnsweringPolicy"
|
||||
),
|
||||
# Bloom
|
||||
"transformers.models.bloom.modeling_bloom.BloomModel":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForSequenceClassification":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForTokenClassification":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),
|
||||
|
||||
"transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomModelPolicy"
|
||||
),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomForCausalLMPolicy"
|
||||
),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomForSequenceClassificationPolicy"
|
||||
),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomForTokenClassificationPolicy"
|
||||
),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"
|
||||
),
|
||||
# Whisper
|
||||
"transformers.models.whisper.modeling_whisper.WhisperModel":
|
||||
PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"),
|
||||
"transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration":
|
||||
PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"),
|
||||
"transformers.models.whisper.modeling_whisper.WhisperForAudioClassification":
|
||||
PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"),
|
||||
|
||||
"transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation(
|
||||
file_name="whisper", class_name="WhisperModelPolicy"
|
||||
),
|
||||
"transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation(
|
||||
file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"
|
||||
),
|
||||
"transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation(
|
||||
file_name="whisper", class_name="WhisperForAudioClassificationPolicy"
|
||||
),
|
||||
# Sam
|
||||
"transformers.models.sam.modeling_sam.SamModel":
|
||||
PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
|
||||
|
||||
"transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
|
||||
# Blip2
|
||||
"transformers.models.blip_2.modeling_blip_2.Blip2Model":
|
||||
PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
|
||||
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
|
||||
PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
|
||||
|
||||
"transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation(
|
||||
file_name="blip2", class_name="Blip2ModelPolicy"
|
||||
),
|
||||
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation(
|
||||
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
|
||||
),
|
||||
# ChatGLM
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
|
||||
PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
|
||||
PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMModelPolicy"
|
||||
),
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
|
||||
),
|
||||
}
|
||||
|
||||
_INFER_POLICY_LIST = {
|
||||
# LlaMa
|
||||
"transformers.models.llama.modeling_llama.LlamaModel":
|
||||
PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
|
||||
"transformers.models.llama.modeling_llama.LlamaForCausalLM":
|
||||
PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
|
||||
"transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaModelInferPolicy"
|
||||
),
|
||||
"transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
|
||||
file_name="llama", class_name="LlamaModelInferPolicy"
|
||||
),
|
||||
# Bloom
|
||||
"transformers.models.bloom.modeling_bloom.BloomModel":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomModelInferPolicy"
|
||||
),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
|
||||
file_name="bloom", class_name="BloomModelInferPolicy"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -163,9 +185,9 @@ def _fullname(obj):
|
||||
"""
|
||||
klass = obj.__class__
|
||||
module = klass.__module__
|
||||
if module == 'builtins':
|
||||
return klass.__qualname__ # avoid outputs like 'builtins.str'
|
||||
return module + '.' + klass.__qualname__
|
||||
if module == "builtins":
|
||||
return klass.__qualname__ # avoid outputs like 'builtins.str'
|
||||
return module + "." + klass.__qualname__
|
||||
|
||||
|
||||
def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
|
||||
|
Reference in New Issue
Block a user