[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* rpyc support (see the worker sketch below)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Runyu Lu
2024-05-14 10:00:55 +08:00
committed by GitHub
parent de4bf3dedf
commit 18d67d0e8e
15 changed files with 1032 additions and 63 deletions
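For orientation, the "rpyc support" bullet refers to the rpyc RPC library used for the inference server. A minimal sketch of what an rpyc-based inference worker looks like is shown below; the service name, the exposed methods, and the port are hypothetical illustrations, not code from this commit.

import rpyc
from rpyc.utils.server import ThreadedServer


class InferenceService(rpyc.Service):
    # rpyc convention: methods prefixed with exposed_ become callable
    # from the client side via conn.root.<name>.

    def exposed_init_model(self, policy_name: str) -> str:
        # A real worker would map policy_name back to a policy class
        # (compare from_rpc_param in the diff below) and load the model.
        self._policy_name = policy_name
        return f"initialized with {policy_name}"

    def exposed_generate(self, prompt: str) -> str:
        # Placeholder for the actual prefill/decode loop.
        return f"echo: {prompt}"


if __name__ == "__main__":
    # ThreadedServer blocks on start(); run one worker per process.
    ThreadedServer(InferenceService, port=18861).start()

A client would then connect with rpyc.connect("localhost", 18861) and invoke conn.root.init_model(...) and conn.root.generate(...).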


@@ -1,5 +1,6 @@
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
+from colossalai.inference.config import RPC_PARAM
 from colossalai.inference.modeling.models.nopadding_llama import (
     NopadLlamaAttention,
     NopadLlamaMLP,
@@ -14,7 +15,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 
-class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
+class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
     def __init__(self) -> None:
         super().__init__()
 
@@ -102,3 +103,10 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def postprocess(self):
         init_to_get_rotary(self.model.model, self.model.config.rope_theta)
         return self.model
+
+    def to_rpc_param(self) -> str:
+        return __class__.__name__
+
+    @staticmethod
+    def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy":
+        return NoPaddingLlamaModelInferPolicy()
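The to_rpc_param/from_rpc_param pair added above serializes a policy as its bare class name, so the RPC payload stays a plain string and the worker rebuilds the instance locally rather than receiving pickled model internals. The round trip could look like the sketch below; the registry dict is an illustrative assumption (the commit's actual name-to-class lookup is not part of this hunk), and the import path is inferred from the file being modified.

# Import path inferred from this diff; treat it as an assumption.
from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaModelInferPolicy

# Hypothetical registry mapping RPC strings back to policy classes.
POLICY_REGISTRY = {
    "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy,
}

# Sender side: the policy crosses the wire as a plain string.
payload = NoPaddingLlamaModelInferPolicy().to_rpc_param()
assert payload == "NoPaddingLlamaModelInferPolicy"

# Receiver side: rebuild a fresh policy instance from the name.
policy = POLICY_REGISTRY[payload].from_rpc_param()
assert isinstance(policy, NoPaddingLlamaModelInferPolicy)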