mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[Feat]Inference RPC Server Support (#5705)
* RPC support source * KV cache logical/physical disaggregation * sampler refactor * colossalai launch built in * Unit test * RPyC support --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
||||
|
||||
from colossalai.inference.config import RPC_PARAM
|
||||
from colossalai.inference.modeling.models.nopadding_llama import (
|
||||
NopadLlamaAttention,
|
||||
NopadLlamaMLP,
|
||||
@@ -14,7 +15,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
|
||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
|
||||
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -102,3 +103,10 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||
def postprocess(self):
    """Finish preparing the wrapped model and hand it back.

    Passes the inner ``model`` attribute of ``self.model`` together with
    ``rope_theta`` from its config to ``init_to_get_rotary`` (presumably
    to set up rotary-embedding state -- confirm against that helper),
    then returns ``self.model``.
    """
    wrapped = self.model
    rope_theta = wrapped.config.rope_theta
    init_to_get_rotary(wrapped.model, rope_theta)
    return wrapped
|
||||
|
||||
def to_rpc_param(self) -> str:
    """Return the string form of this policy used as its RPC parameter.

    The value is the name of the class in which this method is defined.
    """
    # `__class__` is the implicit closure reference to the defining
    # class, so a subclass instance still reports this class's name.
    policy_name: str = __class__.__name__
    return policy_name
|
||||
|
||||
@staticmethod
def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy":
    """Build a fresh policy instance on the receiving side of an RPC call.

    Takes no payload: the policy is constructed with no arguments.
    """
    policy = NoPaddingLlamaModelInferPolicy()
    return policy
|
||||
|
Reference in New Issue
Block a user