Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[Feat]Tensor Model Parallel Support For Inference (#5563)
* tensor parallel support naive source
* [fix] precision, model load and refactor the framework
* add tp unit test
* docstring
* fix do_sample
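The heart of this change is that the hand-rolled torch.mm projections are replaced by sharded modules: Linear1D_Col splits q_proj/k_proj/v_proj, gate_proj/up_proj and lm_head along their output dimension, while Linear1D_Row splits o_proj and down_proj along their input dimension (see the policy diff below). A minimal single-process sketch in plain PyTorch, with illustrative shapes, of why that column-then-row split reproduces the unsharded result; the gating product of the real Llama MLP is omitted for brevity, and sum() stands in for the all-reduce a row-parallel layer performs:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2                                 # illustrative tensor-parallel degree
x = torch.randn(4, 8)                  # [tokens, hidden]
w_up = torch.randn(16, 8)              # [intermediate, hidden], column-parallel: split along dim 0
w_down = torch.randn(8, 16)            # [hidden, intermediate], row-parallel: split along dim 1

# Unsharded reference: up-projection, elementwise activation, down-projection.
ref = F.silu(x @ w_up.T) @ w_down.T

# Each rank keeps one slice of each weight and computes a partial output;
# summing the partials (dist.all_reduce on real ranks) gives the full result.
up_shards = w_up.chunk(tp, dim=0)      # what Linear1D_Col holds per rank
down_shards = w_down.chunk(tp, dim=1)  # what Linear1D_Row holds per rank
partials = [F.silu(x @ up_shards[r].T) @ down_shards[r].T for r in range(tp)]
out = sum(partials)

assert torch.allclose(ref, out, atol=1e-5)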
@@ -1,8 +1,11 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple
import itertools
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed import ProcessGroup
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -26,6 +29,8 @@ from colossalai.kernel.triton import (
rotary_embedding,
)
from colossalai.logging import get_dist_logger
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor

inference_ops = InferenceOpsLoader().load()

@@ -68,7 +73,8 @@ def llama_causal_lm_forward(
use_cuda_kernel=inputmetadata.use_cuda_kernel, # Note currently the cuda kernel of layernorm, rotary_embedding_and_cache_copy couldn't pass the unit test but triton kernel could
high_precision=inputmetadata.high_precision,
)
logits = torch.mm(hidden_states, self.lm_head.weight)

logits = self.lm_head(hidden_states)
return logits


@@ -109,6 +115,7 @@ def llama_model_forward(
logger.warning("CUDA kernel is disabled for speculative-decoding.")

hidden_states = self.embed_tokens(input_tokens_ids)

cu_seqlens = None

# NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
@@ -126,7 +133,7 @@ def llama_model_forward(
cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

elif use_cuda_kernel:
if inputmetadata != torch.float32 and use_flash_attn2:
if inputmetadata.dtype != torch.float32 and use_flash_attn2:
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))

hidden_dim = self._cos_cached.size(-1)
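The cu_seqlens computed above packs the per-sequence lengths into cumulative offsets, which is the format flash-attention's variable-length kernels expect. A small illustration with made-up lengths:

import torch
import torch.nn.functional as F

sequence_lengths = torch.tensor([3, 5, 2])  # example lengths of three packed sequences
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
# tokens of sequence i occupy positions cu_seqlens[i] : cu_seqlens[i + 1] in the packed batch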
@@ -270,7 +277,129 @@ def llama_rmsnorm_forward(
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)


class NopadLlamaAttention(LlamaAttention):
class NopadLlamaMLP(ParallelModule, LlamaMLP):
def __init__(
self,
config: LlamaConfig,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj: ParallelModule = None,
process_group: ProcessGroup = None,
):
"""A Unified Layer for

Args:
config (LlamaConfig): Holding the Llama model config.
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj (Linear1D_Row, optional): The Linear1D_Row mlp_dproj weight. Defaults to None.
"""
ParallelModule.__init__(self)
self.config = config
assert is_distributed_tensor(
mlp_gproj_w
), "mlp_gproj_w must be dtensor so we could get the layout of the weight"
self.helper_layout = (
mlp_gproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of gate_up_weight(used in _load_from_state_dict)
self.gate_up_weight = nn.Parameter(
torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
)
self.down_proj = mlp_dproj
self.process_group = process_group

@staticmethod
def from_native_module(
module: LlamaMLP, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.

Args:
module (LlamaMLP): The origin LlamaMLP layer.
"""

config = module.config

mlp_gproj_w = module.gate_proj.weight
assert is_distributed_tensor(
module.gate_proj.weight
), "gate_proj.weight must be dtensor so we could get the layout of the weight"
mlp_uproj_w = module.up_proj.weight
mlp_dproj = module.down_proj

mlp_layer = NopadLlamaMLP(
config=config,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj=mlp_dproj,
process_group=process_group,
)

return mlp_layer

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
# NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)

for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}

key = "gate_up_weight"
k1 = "gate_proj.weight"
k2 = "up_proj.weight"

gate_w = state_dict[prefix + k1]
up_w = state_dict[prefix + k2]

device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
up_w = distribute_tensor(up_w, device_mesh, sharding_spec)

gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)

input_param = nn.Parameter(
gate_up_w
) # NOTE gate_up_weight doesn't have to be a dtensor, e.g. input_param = sharded_tensor_to_param(input_param)
param = local_state[key]

try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)

strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)

return self.down_proj(act_out)

def extra_repr(self) -> str:
return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
class NopadLlamaAttention(ParallelModule, LlamaAttention):
def __init__(
self,
config: LlamaConfig,
@@ -278,7 +407,11 @@ class NopadLlamaAttention(LlamaAttention):
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
attn_oproj: ParallelModule = None,
process_group: ProcessGroup = None,
num_heads: int = None,
hidden_size: int = None,
num_key_value_heads: int = None,
):
"""This layer will replace the LlamaAttention.

@@ -288,36 +421,54 @@ class NopadLlamaAttention(LlamaAttention):
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj_weight = attn_qproj_w
self.k_proj_weight = attn_kproj_w
self.v_proj_weight = attn_vproj_w
self.o_proj_weight = attn_oproj_w
ParallelModule.__init__(self)
self.config = config
self.layer_idx = layer_idx

self.o_proj = attn_oproj
self.process_group = process_group

self.attention_dropout = config.attention_dropout
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True

if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)

self.q_proj = None
self.k_proj = None
self.v_proj = None
qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
self.helper_layout = (
attn_qproj_w.dist_layout
) # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
else:
self.q_proj_weight = attn_qproj_w
self.k_proj_weight = attn_kproj_w
self.v_proj_weight = attn_vproj_w

@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
def from_native_module(
module: LlamaAttention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.

Args:
module (LlamaAttention): The origin LlamaAttention layer.
"""

config = module.config
layer_idx = module.layer_idx

attn_qproj_w = module.q_proj.weight.transpose(0, 1)
attn_kproj_w = module.k_proj.weight.transpose(0, 1)
attn_vproj_w = module.v_proj.weight.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
attn_qproj_w = module.q_proj.weight
attn_kproj_w = module.k_proj.weight
attn_vproj_w = module.v_proj.weight
assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
attn_oproj = module.o_proj

attn_layer = NopadLlamaAttention(
config=config,
@@ -325,7 +476,11 @@ class NopadLlamaAttention(LlamaAttention):
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
attn_oproj=attn_oproj,
process_group=process_group,
num_heads=module.num_heads,
hidden_size=module.hidden_size,
num_key_value_heads=module.num_key_value_heads,
)

return attn_layer
@@ -487,63 +642,57 @@ class NopadLlamaAttention(LlamaAttention):
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)

attn_output = self.o_proj(attn_output)
return attn_output


# NOTE This will cause difference as out length increases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,
config: LlamaConfig,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""This layer will replace the LlamaAttention.
# NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

Args:
config (LlamaConfig): Holding the Llama model config.
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
self.down_proj_weight = mlp_dproj_w
self.gate_proj = None
self.up_proj = None
self.down_proj = None
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}

@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
key = "qkv_weight"
k1 = "q_proj.weight"
k2 = "k_proj.weight"
k3 = "v_proj.weight"
q_w = state_dict[prefix + k1]
k_w = state_dict[prefix + k2]
v_w = state_dict[prefix + k3]

Args:
module (LlamaMLP): The origin LlamaMLP layer.
"""
config = module.config
device_mesh = self.helper_layout.device_mesh
sharding_spec = self.helper_layout.sharding_spec
q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
v_w = distribute_tensor(v_w, device_mesh, sharding_spec)

mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)

mlp_layer = NopadLlamaMLP(
config=config,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
input_param = nn.Parameter(
qkv_w
) # NOTE qkv_weight doesn't have to be a dtensor, e.g. input_param = sharded_tensor_to_param(input_param)

param = local_state[key]

try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)

strict = False # to avoid unexpected_keys
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

return mlp_layer

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
"""
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = inference_ops.silu_and_mul(gate_up_proj_out)
return torch.mm(act_out, self.down_proj_weight)
def extra_repr(self) -> str:
return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
@@ -1,4 +1,3 @@
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm

from colossalai.inference.modeling.models.nopadding_llama import (
@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

@@ -21,26 +21,69 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
def module_policy(self):
policy = super().module_policy()

decoder_attribute_replacement = {
"lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
}
policy[LlamaForCausalLM] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
)
else:
decoder_attribute_replacement = None

policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadLlamaMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadLlamaAttention,
),
]
],
)

policy[LlamaForCausalLM] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
)
],
)

# self.shard_config._infer()
self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
)
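Finally, lm_head is replaced with Linear1D_Col and gather_output=True: each rank computes logits for its slice of the vocabulary, and the slices are gathered back into the full logits before sampling. A minimal single-process sketch of that behaviour in plain PyTorch, with torch.cat standing in for the all-gather collective:

import torch

tp = 2
hidden, vocab, tokens = 8, 12, 4
x = torch.randn(tokens, hidden)
lm_head_w = torch.randn(vocab, hidden)                 # full [vocab, hidden] weight

# Column-parallel split: each rank keeps vocab/tp output rows.
shards = lm_head_w.chunk(tp, dim=0)
local_logits = [x @ shards[r].T for r in range(tp)]    # [tokens, vocab/tp] per rank

# gather_output=True: concatenate the per-rank slices back to the full vocab dim.
logits = torch.cat(local_logits, dim=-1)
assert torch.allclose(logits, x @ lm_head_w.T, atol=1e-5)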