mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code
This commit is contained in:
@@ -1,25 +1,57 @@
|
||||
import re
|
||||
|
||||
|
||||
def get_obj_list_element(obj, a):
|
||||
def get_obj_list_element(obj, attr: str):
|
||||
r"""
|
||||
Get the element of the list in the object
|
||||
|
||||
If the attr is a normal attribute, return the attribute of the object.
|
||||
If the attr is a index type, return the element of the index in the list, like `layers[0]`.
|
||||
|
||||
Args:
|
||||
obj (Object): The object to get
|
||||
attr (str): The suffix of the attribute to get
|
||||
|
||||
"""
|
||||
re_pattern = r'\[\d+\]'
|
||||
prog = re.compile(re_pattern)
|
||||
result = prog.search(a)
|
||||
result = prog.search(attr)
|
||||
if result:
|
||||
matched_brackets = result.group()
|
||||
matched_index = matched_brackets.replace('[', '')
|
||||
matched_index = matched_index.replace(']', '')
|
||||
a_ = a.replace(matched_brackets, '')
|
||||
container_obj = getattr(obj, a_)
|
||||
attr_ = attr.replace(matched_brackets, '')
|
||||
container_obj = getattr(obj, attr_)
|
||||
obj = container_obj[int(matched_index)]
|
||||
else:
|
||||
obj = getattr(obj, a)
|
||||
obj = getattr(obj, attr)
|
||||
return obj
|
||||
|
||||
|
||||
def set_obj_list_element(obj, attr: str, value):
|
||||
r"""
|
||||
Set the element to value of a list object
|
||||
|
||||
It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value
|
||||
|
||||
Args:
|
||||
obj (object): The object to set
|
||||
attr (str): the string including a list index like `layers[0]`
|
||||
"""
|
||||
re_pattern = r'\[\d+\]'
|
||||
prog = re.compile(re_pattern)
|
||||
result = prog.search(attr)
|
||||
if result:
|
||||
matched_brackets = result.group()
|
||||
matched_index = matched_brackets.replace('[', '')
|
||||
matched_index = matched_index.replace(']', '')
|
||||
attr_ = attr.replace(matched_brackets, '')
|
||||
container_obj = getattr(obj, attr_)
|
||||
container_obj[int(matched_index)] = value
|
||||
else:
|
||||
setattr(obj, attr, value)
|
||||
|
||||
|
||||
def hasattr_(obj, attr: str):
|
||||
r"""
|
||||
Check whether the object has the multi sublevel attr
|
||||
@@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
|
||||
if ignore:
|
||||
return
|
||||
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
|
||||
setattr(obj, attrs[-1], value)
|
||||
set_obj_list_element(obj, attrs[-1], value)
|
||||
|
||||
|
||||
def getattr_(obj, attr: str, ignore: bool = False):
|
||||
|
@@ -3,11 +3,10 @@ from .embedding import Embedding1D, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row
|
||||
from .loss import cross_entropy_1d
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
|
||||
__all__ = [
|
||||
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
|
||||
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
|
||||
'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule'
|
||||
'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col'
|
||||
]
|
||||
|
@@ -25,6 +25,7 @@ from colossalai.tensor.d_tensor.api import (
|
||||
|
||||
from ._operation import (
|
||||
gather_forward_split_backward,
|
||||
linear_with_async_comm,
|
||||
matmul_with_async_comm,
|
||||
reduce_backward,
|
||||
reduce_forward,
|
||||
@@ -33,7 +34,7 @@ from ._operation import (
|
||||
from .parallel_module import ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
|
||||
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row']
|
||||
|
||||
# ====================================
|
||||
# For GPT Only
|
||||
@@ -490,3 +491,175 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
|
||||
|
||||
# ====================================
|
||||
# For Fused torch.nn.Linear
|
||||
# ====================================
|
||||
|
||||
|
||||
class FusedLinear1D_Col(ParallelModule):
|
||||
r"""Fused Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
n_fused (int): The number items fused, defaults to 3 (QKV).
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (`typing.Callable`):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (`typing.Callable`):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.gather_output = gather_output
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.n_fused = n_fused
|
||||
self.process_group = process_group
|
||||
self.async_communication = async_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
|
||||
|
||||
def shard_fn(tensor):
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
||||
|
||||
def gather_fn(tensor):
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
|
||||
|
||||
with torch.no_grad():
|
||||
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
|
||||
self.weight = customized_distributed_tensor_to_param(sharded_weight)
|
||||
|
||||
if bias:
|
||||
bias = torch.empty(self.out_features, **factory_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
|
||||
self.bias = customized_distributed_tensor_to_param(sharded_bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
|
||||
*args, **kwargs) -> ParallelModule:
|
||||
r"""
|
||||
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
|
||||
|
||||
Args:
|
||||
module (`nn.Linear`): The module to be converted.
|
||||
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
||||
n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
|
||||
"""
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
process_group = process_group[0]
|
||||
|
||||
linear_1d = FusedLinear1D_Col(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
||||
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
||||
n_fused=n_fused,
|
||||
process_group=process_group,
|
||||
is_transposed=False)
|
||||
linear_1d.weight.data.copy_(sharded_weight.data)
|
||||
|
||||
if bias:
|
||||
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
||||
n_fused=n_fused,
|
||||
process_group=process_group,
|
||||
is_transposed=False)
|
||||
linear_1d.bias.data.copy_(sharded_bias.data)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
# Set up backprop all-reduce.
|
||||
# input_parallel = reduce_backward(input_, self.process_group)
|
||||
input_parallel = input_
|
||||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if self.skip_bias_add:
|
||||
return output, self.bias
|
||||
else:
|
||||
return output
|
||||
|
41
colossalai/shardformer/modeling/sam.py
Normal file
41
colossalai/shardformer/modeling/sam.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
def forward_fn():
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
||||
batch_size, height, width, _ = hidden_states.shape
|
||||
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
||||
qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
|
||||
-1).permute(2, 0, 3, 1, 4))
|
||||
# q, k, v with shape (batch_size * nHead, height * width, channel)
|
||||
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
|
||||
|
||||
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w,
|
||||
(height, width), (height, width))
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
||||
|
||||
# replace dropout process with added DropoutForParallelInput layer
|
||||
# origin code:
|
||||
# attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_probs = self.dropout_layer(attn_weights)
|
||||
|
||||
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
||||
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
|
||||
|
||||
attn_output = self.proj(attn_output)
|
||||
|
||||
if output_attentions:
|
||||
outputs = (attn_output, attn_weights)
|
||||
else:
|
||||
outputs = (attn_output, None)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
@@ -104,6 +104,10 @@ _POLICY_LIST = {
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
|
||||
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
|
||||
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),
|
||||
|
||||
# Sam
|
||||
"transformers.models.sam.modeling_sam.SamModel":
|
||||
PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
|
||||
}
|
||||
|
||||
|
||||
|
209
colossalai/shardformer/policies/sam.py
Normal file
209
colossalai/shardformer/policies/sam.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.sam import forward_fn
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['SamPolicy', 'SamModelPolicy']
|
||||
|
||||
|
||||
class SamPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.sam.modeling_sam import (
|
||||
SamFeedForward,
|
||||
SamTwoWayAttentionBlock,
|
||||
SamTwoWayTransformer,
|
||||
SamVisionAttention,
|
||||
SamVisionLayer,
|
||||
)
|
||||
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[SamVisionLayer] = ModulePolicyDescription(attribute_replacement={
|
||||
"attn.num_attention_heads":
|
||||
self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.qkv",
|
||||
target_module=col_nn.FusedLinear1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
)
|
||||
])
|
||||
policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attn.num_attention_heads":
|
||||
self.model.config.mask_decoder_config.num_attention_heads //
|
||||
self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
])
|
||||
policy[SamTwoWayTransformer] = ModulePolicyDescription(attribute_replacement={
|
||||
"final_attn_token_to_image.num_attention_heads":
|
||||
self.model.config.mask_decoder_config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
)
|
||||
])
|
||||
|
||||
# add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout`
|
||||
policy[SamVisionAttention] = ModulePolicyDescription(attribute_replacement={
|
||||
"dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
|
||||
},
|
||||
method_replacement={"forward": forward_fn()},
|
||||
sub_module_replacement=[])
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# Handle SamVisionLayer
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm1",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm2",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=SamVisionLayer)
|
||||
|
||||
# Handle SamTwoWayAttentionBlock
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm1",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm2",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm3",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm4",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=SamTwoWayAttentionBlock)
|
||||
|
||||
# Handle SamTwoWayTransformer
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm_final_attn",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=SamTwoWayTransformer)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
# SamModel
|
||||
class SamModelPolicy(SamPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
Reference in New Issue
Block a user