From 4bb5d8923a6e85a0f89a483f15933698635a9f9c Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Tue, 2 Apr 2024 14:16:59 +0800
Subject: [PATCH] [Fix/Inference] Remove unused and non-functional functions (#5543)

* [fix] remove unused func

* rm non-functional partial
---
 .../modeling/policy/nopadding_llama.py        | 29 +++++-------------
 colossalai/shardformer/shard/shard_config.py  |  8 -----
 2 files changed, 8 insertions(+), 29 deletions(-)

diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py
index bb9a22b41..292a6e5ff 100644
--- a/colossalai/inference/modeling/policy/nopadding_llama.py
+++ b/colossalai/inference/modeling/policy/nopadding_llama.py
@@ -1,5 +1,3 @@
-from functools import partial
-
 from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
@@ -13,8 +11,6 @@ from colossalai.inference.modeling.models.nopadding_llama import (
 )
 from colossalai.inference.utils import init_to_get_rotary
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
-
-# import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 
@@ -45,27 +41,18 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             ]
         )
 
-        self.shard_config._infer()
-
-        infer_forward = llama_causal_lm_forward
-        method_replacement = {"forward": partial(infer_forward)}
         self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaForCausalLM
+            description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
         )
-
-        infer_forward = llama_model_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
-
-        infer_forward = llama_decoder_layer_forward
-        method_replacement = {"forward": partial(infer_forward)}
         self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+            description={"forward": llama_model_forward}, policy=policy, target_key=LlamaModel
+        )
+        self.append_or_create_method_replacement(
+            description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=LlamaDecoderLayer
+        )
+        self.append_or_create_method_replacement(
+            description={"forward": llama_rmsnorm_forward}, policy=policy, target_key=LlamaRMSNorm
         )
-
-        infer_forward = llama_rmsnorm_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
 
         return policy
 
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 415fc6dd5..ad79394a9 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -36,8 +36,6 @@ class ShardConfig:
     enable_sequence_overlap: bool = False
     parallel_output = True
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
-    # pipeline_parallel_size: int
-    # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
 
     @property
@@ -70,9 +68,3 @@ class ShardConfig:
         self.enable_jit_fused = True
         self.enable_sequence_parallelism = True
         self.enable_sequence_overlap = True
-
-    def _infer(self):
-        """
-        Set default params for inference.
-        """
-        # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"