mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 04:03:58 +00:00
[shardformer] import huggingface implicitly (#4101)
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['ViTPolicy']
|
||||
|
||||
|
||||
class ViTPolicy(Policy):
|
||||
|
||||
@@ -25,7 +26,9 @@ class ViTPolicy(Policy):
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
base_policy = {
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
|
||||
|
||||
return {
|
||||
ViTEmbeddings:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
|
Reference in New Issue
Block a user