mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-27 12:29:29 +00:00

feat(model): Support OpenAI-Compatible RESTful APIs

This commit is contained in:
parent c7cad041d5
commit 2c9c539404
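
With this change, any OpenAI client can talk to a DB-GPT cluster. A minimal client sketch (assumes the apiserver is running on the default 0.0.0.0:8100, that a worker serves vicuna-13b-v1.5, and that no --api_keys were configured, in which case a placeholder key such as "EMPTY" is accepted):

import openai

openai.api_key = "EMPTY"  # any value works when --api_keys is not set
openai.api_base = "http://127.0.0.1:8100/api/v1"

completion = openai.ChatCompletion.create(
    model="vicuna-13b-v1.5",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)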
@ -7,6 +7,16 @@ services:
    restart: unless-stopped
    networks:
      - dbgptnet
  api-server:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start apiserver --controller_addr http://controller:8000
    restart: unless-stopped
    depends_on:
      - controller
    networks:
      - dbgptnet
    ports:
      - 8100:8100/tcp
  llm-worker:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
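
With the compose stack above running, the published port can be smoke-tested from the host; a sketch using httpx (assumes port 8100 is mapped to localhost as configured above):

import httpx

resp = httpx.get("http://localhost:8100/api/v1/models")
print(resp.json())  # OpenAI-style model list: {"object": "list", "data": [...]}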
pilot/base_modules/agent/db/__init__.py (new file, 0 lines)
@ -46,6 +46,8 @@ class ComponentType(str, Enum):
    WORKER_MANAGER = "dbgpt_worker_manager"
    WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
    MODEL_CONTROLLER = "dbgpt_model_controller"
    MODEL_REGISTRY = "dbgpt_model_registry"
    MODEL_API_SERVER = "dbgpt_model_api_server"
    AGENT_HUB = "dbgpt_agent_hub"
    EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
    TRACER = "dbgpt_tracer"
@ -68,7 +70,6 @@ class BaseComponent(LifeCycle, ABC):
        This method needs to be implemented by every component to define how it integrates
        with the main system app.
        """
        pass


T = TypeVar("T", bound=BaseComponent)

@ -90,13 +91,28 @@ class SystemApp(LifeCycle):
        """Returns the internal ASGI app."""
        return self._asgi_app

    def register(self, component: Type[BaseComponent], *args, **kwargs):
        """Register a new component by its type."""
    def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
        """Register a new component by its type.

        Args:
            component (Type[BaseComponent]): The component class to register

        Returns:
            T: The instance of the registered component
        """
        instance = component(self, *args, **kwargs)
        self.register_instance(instance)
        return instance

    def register_instance(self, instance: T):
        """Register an already initialized component."""
    def register_instance(self, instance: T) -> T:
        """Register an already initialized component.

        Args:
            instance (T): The component instance to register

        Returns:
            T: The instance of the registered component
        """
        name = instance.name
        if isinstance(name, ComponentType):
            name = name.value

@ -107,18 +123,34 @@ class SystemApp(LifeCycle):
        logger.info(f"Register component with name {name} and instance: {instance}")
        self.components[name] = instance
        instance.init_app(self)
        return instance

    def get_component(
        self,
        name: Union[str, ComponentType],
        component_type: Type[T],
        default_component=_EMPTY_DEFAULT_COMPONENT,
        or_register_component: Type[BaseComponent] = None,
        *args,
        **kwargs,
    ) -> T:
        """Retrieve a registered component by its name and type."""
        """Retrieve a registered component by its name and type.

        Args:
            name (Union[str, ComponentType]): The component name
            component_type (Type[T]): The type of the component to retrieve
            default_component: The default component instance to return if none is found by name
            or_register_component (Type[BaseComponent]): A component class to register and return if none is found by name

        Returns:
            T: The instance retrieved by component name
        """
        if isinstance(name, ComponentType):
            name = name.value
        component = self.components.get(name)
        if not component:
            if or_register_component:
                return self.register(or_register_component, *args, **kwargs)
            if default_component != _EMPTY_DEFAULT_COMPONENT:
                return default_component
            raise ValueError(f"No component found with name {name}")
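
The new or_register_component argument turns get_component into a get-or-create call; a small sketch of the pattern (mirroring how the apiserver code later in this commit uses it):

# Returns the registered APIServer instance, registering one on first use
api_server = system_app.get_component(
    ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
)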
@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

from enum import Enum
from typing import TypedDict, Optional, Dict, List
from typing import TypedDict, Optional, Dict, List, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription

@ -52,6 +52,8 @@ class ModelOutput:
    text: str
    error_code: int
    model_context: Dict = None
    finish_reason: str = None
    usage: Dict[str, Any] = None

    def to_dict(self) -> Dict:
        return asdict(self)
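
The new finish_reason and usage fields carry OpenAI-style completion metadata from the worker to the apiserver; for example (field values illustrative):

out = ModelOutput(
    text="Hello world.",
    error_code=0,
    finish_reason="stop",
    usage={"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7},
)
print(out.to_dict())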
@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
    ModelControllerParameters,
    ModelAPIServerParameters,
    ModelWorkerParameters,
    ModelParameters,
    BaseParameters,

@ -441,15 +442,27 @@ def stop_model_worker(port: int):


@click.command(name="apiserver")
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
    """Start apiserver(TODO)"""
    raise NotImplementedError
    """Start apiserver"""

    if kwargs["daemon"]:
        log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
        _run_current_with_daemon("ModelAPIServer", log_file)
    else:
        from pilot.model.cluster import run_apiserver

        run_apiserver()


@click.command(name="apiserver")
def stop_apiserver(**kwargs):
    """Start apiserver(TODO)"""
    raise NotImplementedError
@add_stop_server_options
def stop_apiserver(port: int):
    """Stop apiserver"""
    name = "ModelAPIServer"
    if port:
        name = f"{name}-{port}"
    _stop_service("apiserver", name, port=port)


def _stop_all_model_server(**kwargs):
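
Taken together, the new subcommands are driven like the existing controller and worker commands (a sketch; the flags follow ModelAPIServerParameters defined later in this commit):

dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --port 8100
dbgpt stop apiserver --port 8100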
@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import (
    run_model_controller,
    BaseModelController,
)
from pilot.model.cluster.apiserver.api import run_apiserver

from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager

@ -40,4 +41,5 @@ __all__ = [
    "ModelRegistryClient",
    "RemoteWorkerManager",
    "run_model_controller",
    "run_apiserver",
]
pilot/model/cluster/apiserver/__init__.py (new file, 0 lines)
pilot/model/cluster/apiserver/api.py (new file, 443 lines)
@ -0,0 +1,443 @@
"""A server that provides OpenAI-compatible RESTful APIs. It supports:
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)

Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
from typing import Optional, List, Dict, Any, Generator

import logging
import asyncio
import shortuuid
import json
from fastapi import APIRouter, FastAPI
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from pydantic import BaseSettings

from fastchat.protocol.openai_api_protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    DeltaMessage,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ErrorResponse,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)
from fastchat.protocol.api_protocol import (
    APIChatCompletionRequest,
    APITokenCheckRequest,
    APITokenCheckResponse,
    APITokenCheckResponseItem,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.constants import ErrorCode

from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.base import ModelInstance, ModelOutput
from pilot.model.parameter import ModelAPIServerParameters, WorkerType
from pilot.model.cluster import ModelRegistry, ModelRegistryClient
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.utils.utils import setup_logging

logger = logging.getLogger(__name__)


class APIServerException(Exception):
    def __init__(self, code: int, message: str):
        self.code = code
        self.message = message


class APISettings(BaseSettings):
    api_keys: Optional[List[str]] = None


api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)


async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> str:
    if api_settings.api_keys:
        if auth is None or (token := auth.credentials) not in api_settings.api_keys:
            raise HTTPException(
                status_code=401,
                detail={
                    "error": {
                        "message": "",
                        "type": "invalid_request_error",
                        "param": None,
                        "code": "invalid_api_key",
                    }
                },
            )
        return token
    else:
        # api_keys not set; allow all
        return None


class APIServer(BaseComponent):
    name = ComponentType.MODEL_API_SERVER

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    def get_worker_manager(self) -> WorkerManager:
        """Get the worker manager component instance

        Raises:
            APIServerException: If the worker manager component instance can't be obtained
        """
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        if not worker_manager:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
            )
        return worker_manager

    def get_model_registry(self) -> ModelRegistry:
        """Get the model registry component instance

        Raises:
            APIServerException: If the model registry component instance can't be obtained
        """
        controller = self.system_app.get_component(
            ComponentType.MODEL_REGISTRY, ModelRegistry
        )
        if not controller:
            raise APIServerException(
                ErrorCode.INTERNAL_ERROR,
                f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
            )
        return controller

    async def get_model_instances_or_raise(
        self, model_name: str
    ) -> List[ModelInstance]:
        """Get healthy model instances with the requested model name

        Args:
            model_name (str): Model name

        Raises:
            APIServerException: If no healthy model instances match the requested model name
        """
        registry = self.get_model_registry()
        registry_model_name = f"{model_name}@llm"
        model_instances = await registry.get_all_instances(
            registry_model_name, healthy_only=True
        )
        if not model_instances:
            all_instances = await registry.get_all_model_instances(healthy_only=True)
            models = [
                ins.model_name.split("@llm")[0]
                for ins in all_instances
                if ins.model_name.endswith("@llm")
            ]
            if models:
                models = "&&".join(models)
                message = f"Only {models} allowed now, your model {model_name}"
            else:
                message = f"No models allowed now, your model {model_name}"
            raise APIServerException(ErrorCode.INVALID_MODEL, message)
        return model_instances

    async def get_available_models(self) -> ModelList:
        """Return the available models

        Only LLM and embedding models are included.

        Returns:
            ModelList: The list of models.
        """
        registry = self.get_model_registry()
        model_instances = await registry.get_all_model_instances(healthy_only=True)
        model_name_set = set()
        for inst in model_instances:
            name, worker_type = WorkerType.parse_worker_key(inst.model_name)
            if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
                model_name_set.add(name)
        models = list(model_name_set)
        models.sort()
        # TODO: return real model permission details
        model_cards = []
        for m in models:
            model_cards.append(
                ModelCard(
                    id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
                )
            )
        return ModelList(data=model_cards)

    async def chat_completion_stream_generator(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> Generator[str, Any, None]:
        """Chat stream completion generator

        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager = self.get_worker_manager()
        id = f"chatcmpl-{shortuuid.random()}"
        finish_stream_events = []
        for i in range(n):
            # First chunk with role
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(role="assistant"),
                finish_reason=None,
            )
            chunk = ChatCompletionStreamResponse(
                id=id, choices=[choice_data], model=model_name
            )
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

            previous_text = ""
            async for model_output in worker_manager.generate_stream(params):
                model_output: ModelOutput = model_output
                if model_output.error_code != 0:
                    yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                decoded_unicode = model_output.text.replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )

                if len(delta_text) == 0:
                    delta_text = None
                choice_data = ChatCompletionResponseStreamChoice(
                    index=i,
                    delta=DeltaMessage(content=delta_text),
                    finish_reason=model_output.finish_reason,
                )
                chunk = ChatCompletionStreamResponse(
                    id=id, choices=[choice_data], model=model_name
                )
                if delta_text is None:
                    if model_output.finish_reason is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # There is no "content" field in the last delta message, so use exclude_none to exclude the "content" field.
        for finish_chunk in finish_stream_events:
            yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    async def chat_completion_generate(
        self, model_name: str, params: Dict[str, Any], n: int
    ) -> ChatCompletionResponse:
        """Generate completion
        Args:
            model_name (str): Model name
            params (Dict[str, Any]): The parameters passed to the model worker
            n (int): How many completions to generate for each prompt.
        """
        worker_manager: WorkerManager = self.get_worker_manager()
        choices = []
        chat_completions = []
        for i in range(n):
            model_output = asyncio.create_task(worker_manager.generate(params))
            chat_completions.append(model_output)
        try:
            all_tasks = await asyncio.gather(*chat_completions)
        except Exception as e:
            return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
        usage = UsageInfo()
        for i, model_output in enumerate(all_tasks):
            model_output: ModelOutput = model_output
            if model_output.error_code != 0:
                return create_error_response(model_output.error_code, model_output.text)
            choices.append(
                ChatCompletionResponseChoice(
                    index=i,
                    message=ChatMessage(role="assistant", content=model_output.text),
                    finish_reason=model_output.finish_reason or "stop",
                )
            )
            if model_output.usage:
                task_usage = UsageInfo.parse_obj(model_output.usage)
                for usage_key, usage_value in task_usage.dict().items():
                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

        return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)


def get_api_server() -> APIServer:
    api_server = global_system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, default_component=None
    )
    if not api_server:
        global_system_app.register(APIServer)
    return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)


router = APIRouter()


@router.get("/v1/models", dependencies=[Depends(check_api_key)])
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
    return await api_server.get_available_models()


@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
    await api_server.get_model_instances_or_raise(request.model)
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    params = {
        "model": request.model,
        "messages": ModelMessage.to_dict_list(
            ModelMessage.from_openai_messages(request.messages)
        ),
        "echo": False,
    }
    if request.temperature:
        params["temperature"] = request.temperature
    if request.top_p:
        params["top_p"] = request.top_p
    if request.max_tokens:
        params["max_new_tokens"] = request.max_tokens
    if request.stop:
        params["stop"] = request.stop
    if request.user:
        params["user"] = request.user

    # TODO: check token length
    if request.stream:
        generator = api_server.chat_completion_stream_generator(
            request.model, params, request.n
        )
        return StreamingResponse(generator, media_type="text/event-stream")
    return await api_server.chat_completion_generate(request.model, params, request.n)


def _initialize_all(controller_addr: str, system_app: SystemApp):
    from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient
    from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory

    if not system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    ):
        # Register the model registry if it does not exist
        registry = ModelRegistryClient(controller_addr)
        registry.name = ComponentType.MODEL_REGISTRY.value
        system_app.register_instance(registry)

    registry = system_app.get_component(
        ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
    )
    worker_manager = RemoteWorkerManager(registry)

    # Register the worker manager component if it does not exist
    system_app.get_component(
        ComponentType.WORKER_MANAGER_FACTORY,
        WorkerManagerFactory,
        or_register_component=_DefaultWorkerManagerFactory,
        worker_manager=worker_manager,
    )
    # Register the api server component if it does not exist
    system_app.get_component(
        ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
    )


def initialize_apiserver(
    controller_addr: str,
    app=None,
    system_app: SystemApp = None,
    host: str = None,
    port: int = None,
    api_keys: List[str] = None,
):
    global global_system_app
    global api_settings
    embedded_mod = True
    if not app:
        embedded_mod = False
        app = FastAPI()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
            allow_headers=["*"],
        )

    if not system_app:
        system_app = SystemApp(app)
    global_system_app = system_app

    if api_keys:
        api_settings.api_keys = api_keys

    app.include_router(router, prefix="/api", tags=["APIServer"])

    @app.exception_handler(APIServerException)
    async def validation_apiserver_exception_handler(request, exc: APIServerException):
        return create_error_response(exc.code, exc.message)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))

    _initialize_all(controller_addr, system_app)

    if not embedded_mod:
        import uvicorn

        uvicorn.run(app, host=host, port=port, log_level="info")


def run_apiserver():
    parser = EnvArgumentParser()
    env_prefix = "apiserver_"
    apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
        ModelAPIServerParameters,
        env_prefixes=[env_prefix],
    )
    setup_logging(
        "pilot",
        logging_level=apiserver_params.log_level,
        logger_filename=apiserver_params.log_file,
    )
    api_keys = None
    if apiserver_params.api_keys:
        api_keys = apiserver_params.api_keys.strip().split(",")

    initialize_apiserver(
        apiserver_params.controller_addr,
        host=apiserver_params.host,
        port=apiserver_params.port,
        api_keys=api_keys,
    )


if __name__ == "__main__":
    run_apiserver()
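
For orientation, the streaming endpoint above emits Server-Sent Events in the OpenAI chunk format; an abbreviated, illustrative exchange (ids shortened, some fields omitted) looks like:

data: {"id": "chatcmpl-...", "model": "vicuna-13b-v1.5", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}

data: {"id": "chatcmpl-...", "model": "vicuna-13b-v1.5", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}

data: [DONE]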
pilot/model/cluster/apiserver/tests/__init__.py (new file, 0 lines)
pilot/model/cluster/apiserver/tests/test_api.py (new file, 248 lines)
@ -0,0 +1,248 @@
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError

from pilot.component import SystemApp
from pilot.utils.openai_utils import chat_completion_stream, chat_completion

from pilot.model.cluster.apiserver.api import (
    api_settings,
    initialize_apiserver,
    ModelList,
    UsageInfo,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    DeltaMessage,
)
from pilot.model.cluster.tests.conftest import _new_cluster

from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)


@pytest_asyncio.fixture
async def system_app():
    return SystemApp(app)


@pytest_asyncio.fixture
async def client(request, system_app: SystemApp):
    param = getattr(request, "param", {})
    api_keys = param.get("api_keys", [])
    client_api_key = param.get("client_api_key")
    if "num_workers" not in param:
        param["num_workers"] = 2
    if "api_keys" in param:
        del param["api_keys"]
    headers = {}
    if client_api_key:
        headers["Authorization"] = "Bearer " + client_api_key
    print(f"param: {param}")
    if api_settings:
        # Clear global api keys
        api_settings.api_keys = []
    async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
        async with _new_cluster(**param) as cluster:
            worker_manager, model_registry = cluster
            system_app.register(_DefaultWorkerManagerFactory, worker_manager)
            system_app.register_instance(model_registry)
            # print(f"Instances {model_registry.registry}")
            initialize_apiserver(None, app, system_app, api_keys=api_keys)
            yield client


@pytest.mark.asyncio
async def test_get_all_models(client: AsyncClient):
    res = await client.get("/api/v1/models")
    assert res.status_code == 200
    model_lists = ModelList.parse_obj(res.json())
    print(f"model list json: {res.json()}")
    assert model_lists.object == "list"
    assert len(model_lists.data) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["client"],
)
async def test_chat_completions(client: AsyncClient, expected_messages):
    chat_data = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    full_text = ""
    async for text in chat_completion_stream(
        "/api/v1/chat/completions", chat_data, client
    ):
        full_text += text
    assert full_text == expected_messages

    assert (
        await chat_completion("/api/v1/chat/completions", chat_data, client)
        == expected_messages
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_no_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    import openai

    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"

    model_name = "test-model-name-0"

    with aioresponses() as mocked:
        mock_message = {"text": expected_messages}
        one_res = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=expected_messages),
            finish_reason="stop",
        )
        data = ChatCompletionResponse(
            model=model_name, choices=[one_res], usage=UsageInfo()
        )
        mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        # Mock the http request
        mocked.post(
            "http://test/api/v1/chat/completions", status=200, body=mock_message
        )
        completion = await openai.ChatCompletion.acreate(
            model=model_name,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
        )
        assert completion.choices[0].message.content == expected_messages


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, client_api_key",
    [
        (
            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
            "Hello world.",
            "abc",
        ),
        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_stream(
    client: AsyncClient, expected_messages: str, client_api_key: str
):
    import openai

    openai.api_key = client_api_key
    openai.api_base = "http://test/api/v1"

    model_name = "test-model-name-0"

    with aioresponses() as mocked:
        mock_message = {"text": expected_messages}
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=expected_messages),
            finish_reason="stop",
        )
        chunk = ChatCompletionStreamResponse(
            id=0, choices=[choice_data], model=model_name
        )
        mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
        mocked.post(
            "http://test/api/v1/chat/completions",
            status=200,
            body=mock_message,
            content_type="text/event-stream",
        )

        stream_stream_resp = ""
        async for stream_resp in await openai.ChatCompletion.acreate(
            model=model_name,
            messages=[{"role": "user", "content": "Hello! What is your name?"}],
            stream=True,
        ):
            stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
        assert stream_stream_resp == expected_messages


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client, expected_messages, api_key_is_error",
    [
        (
            {
                "stream_messags": ["Hello", " world."],
                "api_keys": ["abc", "xx"],
                "client_api_key": "abc",
            },
            "Hello world.",
            False,
        ),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
        (
            {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
            "你好,我是张三。",
            True,
        ),
        (
            {
                "stream_messags": ["你好,我是", "张三。"],
                "api_keys": ["abc", "xx"],
                "client_api_key": "error_api_key",
            },
            "你好,我是张三。",
            True,
        ),
    ],
    indirect=["client"],
)
async def test_chat_completions_with_api_keys(
    client: AsyncClient, expected_messages: str, api_key_is_error: bool
):
    chat_data = {
        "model": "test-model-name-0",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    if api_key_is_error:
        with pytest.raises(HTTPError):
            await chat_completion("/api/v1/chat/completions", chat_data, client)
    else:
        assert (
            await chat_completion("/api/v1/chat/completions", chat_data, client)
            == expected_messages
        )
@ -66,7 +66,9 @@ class LocalModelController(BaseModelController):
            f"Get all instances with {model_name}, healthy_only: {healthy_only}"
        )
        if not model_name:
            return await self.registry.get_all_model_instances()
            return await self.registry.get_all_model_instances(
                healthy_only=healthy_only
            )
        else:
            return await self.registry.get_all_instances(model_name, healthy_only)

@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController):


class ModelRegistryClient(_RemoteModelController, ModelRegistry):
    async def get_all_model_instances(self) -> List[ModelInstance]:
        return await self.get_all_instances()
    async def get_all_model_instances(
        self, healthy_only: bool = False
    ) -> List[ModelInstance]:
        return await self.get_all_instances(healthy_only=healthy_only)

    @sync_api_remote(path="/api/controller/models")
    def sync_get_all_instances(
@ -1,22 +1,37 @@
import random
import threading
import time
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import itertools

from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance


class ModelRegistry(ABC):
logger = logging.getLogger(__name__)


class ModelRegistry(BaseComponent, ABC):
    """
    Abstract base class for a model registry. It provides an interface
    for registering, deregistering, fetching instances, and sending heartbeats
    for instances.
    """

    name = ComponentType.MODEL_REGISTRY

    def __init__(self, system_app: SystemApp | None = None):
        self.system_app = system_app
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        """Initialize the component with the main application."""
        self.system_app = system_app

    @abstractmethod
    async def register_instance(self, instance: ModelInstance) -> bool:
        """

@ -65,9 +80,11 @@ class ModelRegistry(ABC):
        """Fetch all instances of a given model. Optionally, fetch only the healthy instances."""

    @abstractmethod
    async def get_all_model_instances(self) -> List[ModelInstance]:
    async def get_all_model_instances(
        self, healthy_only: bool = False
    ) -> List[ModelInstance]:
        """
        Fetch all instances of all models
        Fetch all instances of all models. Optionally, fetch only the healthy instances.

        Returns:
        - List[ModelInstance]: A list of instances for all models.

@ -105,8 +122,12 @@ class ModelRegistry(ABC):

class EmbeddedModelRegistry(ModelRegistry):
    def __init__(
        self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
        self,
        system_app: SystemApp | None = None,
        heartbeat_interval_secs: int = 60,
        heartbeat_timeout_secs: int = 120,
    ):
        super().__init__(system_app)
        self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
        self.heartbeat_interval_secs = heartbeat_interval_secs
        self.heartbeat_timeout_secs = heartbeat_timeout_secs

@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry):
        instances = [ins for ins in instances if ins.healthy == True]
        return instances

    async def get_all_model_instances(self) -> List[ModelInstance]:
        print(self.registry)
        return list(itertools.chain(*self.registry.values()))
    async def get_all_model_instances(
        self, healthy_only: bool = False
    ) -> List[ModelInstance]:
        logger.debug(f"Current registry metadata:\n{self.registry}")
        instances = list(itertools.chain(*self.registry.values()))
        if healthy_only:
            instances = [ins for ins in instances if ins.healthy == True]
        return instances

    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        _, exist_ins = self._get_instances(
pilot/model/cluster/tests/__init__.py (new file, 0 lines)
@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, Worker
from pilot.model.base import ModelOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.manager import (
    WorkerManager,
    LocalWorkerManager,
    RegisterFunc,
    DeregisterFunc,

@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import (
    ApplyFunction,
)

from pilot.model.base import ModelInstance
from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry


@pytest.fixture
def model_registry(request):
    return EmbeddedModelRegistry()


@pytest.fixture
def model_instance():
    return ModelInstance(
        model_name="test_model",
        host="192.168.1.1",
        port=5000,
    )


class MockModelWorker(ModelWorker):
    def __init__(

@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker):
            raise Exception("Stop worker error for mock")

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        full_text = ""
        for msg in self.stream_messags:
            yield ModelOutput(text=msg, error_code=0)
            full_text += msg
            yield ModelOutput(text=full_text, error_code=0)

    def generate(self, params: Dict) -> ModelOutput:
        output = None

@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker):
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"

ClusterType = Tuple[WorkerManager, ModelRegistry]


def _new_worker_params(
    model_name: str = _TEST_MODEL_NAME,

@ -85,7 +107,9 @@ def _create_workers(
    worker_type: str = WorkerType.LLM.value,
    stream_messags: List[str] = None,
    embeddings: List[List[float]] = None,
) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
    host: str = "127.0.0.1",
    start_port=8001,
) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
    workers = []
    for i in range(num_workers):
        model_name = f"test-model-name-{i}"

@ -98,10 +122,16 @@ def _create_workers(
            stream_messags=stream_messags,
            embeddings=embeddings,
        )
        model_instance = ModelInstance(
            model_name=WorkerType.to_worker_key(model_name, worker_type),
            host=host,
            port=start_port + i,
            healthy=True,
        )
        worker_params = _new_worker_params(
            model_name, model_path, worker_type=worker_type
        )
        workers.append((worker, worker_params))
        workers.append((worker, worker_params, model_instance))
    return workers


@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs):
        model_registry=model_registry,
    )

    for worker, worker_params in _create_workers(
    for worker, worker_params, model_instance in _create_workers(
        num_workers, error_worker, stop_error, stream_messags, embeddings
    ):
        worker_manager.add_worker(worker, worker_params)
    if workers:
        for worker, worker_params in workers:
        for worker, worker_params, model_instance in workers:
            worker_manager.add_worker(worker, worker_params)

    if start:

@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs):
    await worker_manager.stop()


async def _create_model_registry(
    workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
) -> ModelRegistry:
    registry = EmbeddedModelRegistry()
    for _, _, inst in workers:
        assert await registry.register_instance(inst) == True
    return registry


@pytest_asyncio.fixture
async def manager_2_workers(request):
    param = getattr(request, "param", {})

@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request):
    )
    async with _start_worker_manager(workers=workers, **param) as worker_manager:
        yield (worker_manager, workers)


@asynccontextmanager
async def _new_cluster(**kwargs) -> ClusterType:
    num_workers = kwargs.get("num_workers", 0)
    workers = _create_workers(
        num_workers, stream_messags=kwargs.get("stream_messags", [])
    )
    if "num_workers" in kwargs:
        del kwargs["num_workers"]
    registry = await _create_model_registry(
        workers,
    )
    async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
        yield (worker_manager, registry)


@pytest_asyncio.fixture
async def cluster_2_workers(request):
    param = getattr(request, "param", {})
    workers = _create_workers(2)
    registry = await _create_model_registry(workers)
    async with _start_worker_manager(workers=workers, **param) as worker_manager:
        yield (worker_manager, registry)
@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker):
        return params, model_context, generate_stream_func, model_span

    def _handle_output(self, output, previous_response, model_context):
        finish_reason = None
        usage = None
        if isinstance(output, dict):
            finish_reason = output.get("finish_reason")
            usage = output.get("usage")
            output = output["text"]
            if finish_reason is not None:
                logger.info(f"finish_reason: {finish_reason}")
        incremental_output = output[len(previous_response) :]
        print(incremental_output, end="", flush=True)
        model_output = ModelOutput(
            text=output, error_code=0, model_context=model_context
            text=output,
            error_code=0,
            model_context=model_context,
            finish_reason=finish_reason,
            usage=usage,
        )
        return model_output, incremental_output, output


@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager):
        )

    def _worker_key(self, worker_type: str, model_name: str) -> str:
        if isinstance(worker_type, WorkerType):
            worker_type = worker_type.value
        return f"{model_name}@{worker_type}"
        return WorkerType.to_worker_key(model_name, worker_type)

    async def run_blocking_func(self, func, *args):
        if asyncio.iscoroutinefunction(func):
@ -3,7 +3,7 @@ import pytest
from typing import List, Iterator, Dict, Tuple
from dataclasses import asdict
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.base import ModelOutput, WorkerApplyType
from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerRunData

@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import (
    SendHeartbeatFunc,
    ApplyFunction,
)
from pilot.model.cluster.worker.tests.base_tests import (
from pilot.model.cluster.tests.conftest import (
    MockModelWorker,
    manager_2_workers,
    manager_with_2_workers,

@ -216,7 +216,7 @@ async def test__remove_worker():
    workers = _create_workers(3)
    async with _start_worker_manager(workers=workers, stop=False) as manager:
        assert len(manager.workers) == 3
        for _, worker_params in workers:
        for _, worker_params, _ in workers:
            manager._remove_worker(worker_params)
        not_exist_parmas = _new_worker_params(
            model_name="this is a not exist worker params"

@ -229,7 +229,7 @@ async def test_model_startup(mock_build_worker):
    async with _start_worker_manager() as manager:
        workers = _create_workers(1)
        worker, worker_params = workers[0]
        worker, worker_params, model_instance = workers[0]
        mock_build_worker.return_value = worker

        req = WorkerStartupRequest(

@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker):

    async with _start_worker_manager() as manager:
        workers = _create_workers(1, error_worker=True)
        worker, worker_params = workers[0]
        worker, worker_params, model_instance = workers[0]
        mock_build_worker.return_value = worker
        req = WorkerStartupRequest(
            host="127.0.0.1",

@ -263,7 +263,7 @@ async def test_model_shutdown(mock_build_worker):
    async with _start_worker_manager(start=False, stop=False) as manager:
        workers = _create_workers(1)
        worker, worker_params = workers[0]
        worker, worker_params, model_instance = workers[0]
        mock_build_worker.return_value = worker

        req = WorkerStartupRequest(

@ -298,7 +298,7 @@ async def test_get_model_instances(is_async):
    workers = _create_workers(3)
    async with _start_worker_manager(workers=workers, stop=False) as manager:
        assert len(manager.workers) == 3
        for _, worker_params in workers:
        for _, worker_params, _ in workers:
            model_name = worker_params.model_name
            worker_type = worker_params.worker_type
            if is_async:

@ -326,7 +326,7 @@ async def test__simple_select(
    ]
):
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        instances = await manager.get_model_instances(worker_type, model_name)

@ -351,7 +351,7 @@ async def test_select_one_instance(
    ],
):
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        if is_async:

@ -376,7 +376,7 @@ async def test__get_model(
    ],
):
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = {"model": model_name}

@ -403,13 +403,13 @@ async def test_generate_stream(
    expected_messages: str,
):
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = {"model": model_name}
        text = ""
        async for out in manager.generate_stream(params):
            text += out.text
            text = out.text
        assert text == expected_messages


@ -417,8 +417,8 @@ async def test_generate_stream(
@pytest.mark.parametrize(
    "manager_with_2_workers, expected_messages",
    [
        ({"stream_messags": ["Hello", " world."]}, " world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
    ],
    indirect=["manager_with_2_workers"],
)

@ -429,7 +429,7 @@ async def test_generate(
    expected_messages: str,
):
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = {"model": model_name}

@ -454,7 +454,7 @@ async def test_embeddings(
    is_async: bool,
):
    manager, workers = manager_2_embedding_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = {"model": model_name, "input": ["hello", "world"]}

@ -472,7 +472,7 @@ async def test_parameter_descriptions(
    ]
):
    manager, workers = manager_with_2_workers
    for _, worker_params in workers:
    for _, worker_params, _ in workers:
        model_name = worker_params.model_name
        worker_type = worker_params.worker_type
        params = await manager.parameter_descriptions(worker_type, model_name)
@ -467,7 +467,8 @@ register_conv_template(
        sep="\n",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    )
    ),
    override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(

@ -482,7 +483,8 @@ register_conv_template(
        sep="###",
        sep2="</s>",
        stop_str=["</s>", "[UNK]"],
    )
    ),
    override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(

@ -495,5 +497,6 @@ register_conv_template(
        sep="",
        sep2="</s>",
        stop_str=["</s>", "<|endoftext|>"],
    )
    ),
    override=True,
)
@ -1,9 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, Union, Tuple

from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters

@ -19,6 +20,35 @@ class WorkerType(str, Enum):
    def values():
        return [item.value for item in WorkerType]

    @staticmethod
    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
        """Generate a worker key from a worker name and worker type

        Args:
            worker_name (str): Worker name (e.g., chatglm2-6b)
            worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm', or [`WorkerType.LLM`])

        Returns:
            str: Generated worker key
        """
        if "@" in worker_name:
            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
        if isinstance(worker_type, WorkerType):
            worker_type = worker_type.value
        return f"{worker_name}@{worker_type}"

    @staticmethod
    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
        """Parse the worker name and worker type from a worker key

        Args:
            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]

        Returns:
            Tuple[str, str]: Worker name and worker type
        """
        return tuple(worker_key.split("@"))
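
A quick round-trip of the two helpers above (values illustrative):

key = WorkerType.to_worker_key("vicuna-13b-v1.5", WorkerType.LLM)
# key == "vicuna-13b-v1.5@llm"
name, worker_type = WorkerType.parse_worker_key(key)
# name == "vicuna-13b-v1.5", worker_type == "llm"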

@dataclass
class ModelControllerParameters(BaseParameters):

@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
    )


@dataclass
class ModelAPIServerParameters(BaseParameters):
    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "Model API server deploy host"}
    )
    port: Optional[int] = field(
        default=8100, metadata={"help": "Model API server deploy port"}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run the Model API server in the background"}
    )
    controller_addr: Optional[str] = field(
        default="http://127.0.0.1:8000",
        metadata={"help": "The model controller address to connect to"},
    )

    api_keys: Optional[str] = field(
        default=None,
        metadata={"help": "Optional list of comma separated API keys"},
    )

    log_level: Optional[str] = field(
        default=None,
        metadata={
            "help": "Logging level",
            "valid_values": [
                "FATAL",
                "ERROR",
                "WARNING",
                "INFO",
                "DEBUG",
                "NOTSET",
            ],
        },
    )
    log_file: Optional[str] = field(
        default="dbgpt_model_apiserver.log",
        metadata={
            "help": "The filename to store logs",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_apiserver_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
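
The api_keys field is a single comma-separated string; run_apiserver (in api.py above) splits it into a list, and clients must then present one of the keys as a bearer token. A client-side sketch (key and model name are illustrative):

import httpx

resp = httpx.post(
    "http://127.0.0.1:8100/api/v1/chat/completions",
    headers={"Authorization": "Bearer sk-key1"},  # must match one of api_keys
    json={"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "Hi"}]},
)
print(resp.json())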

@dataclass
class BaseModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Tuple, Optional, Union

from pydantic import BaseModel, Field, root_validator

@ -70,14 +70,6 @@ class SystemMessage(BaseMessage):
        return "system"


class ModelMessage(BaseModel):
    """Type of message that interaction between dbgpt-server and llm-server"""

    """Similar to openai's message format"""
    role: str
    content: str


class ModelMessageRoleType:
    """Type of ModelMessage role"""

@ -87,6 +79,45 @@ class ModelMessageRoleType:
    VIEW = "view"


class ModelMessage(BaseModel):
    """Type of message used in the interaction between dbgpt-server and llm-server,
    similar to OpenAI's message format."""

    role: str
    content: str

    @staticmethod
    def from_openai_messages(
        messages: Union[str, List[Dict[str, str]]]
    ) -> List["ModelMessage"]:
        """Convert the OpenAI message format to the ModelMessage format"""
        if isinstance(messages, str):
            return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
        result = []
        for message in messages:
            msg_role = message["role"]
            content = message["content"]
            if msg_role == "system":
                result.append(
                    ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
                )
            elif msg_role == "user":
                result.append(
                    ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
                )
            elif msg_role == "assistant":
                result.append(
                    ModelMessage(role=ModelMessageRoleType.AI, content=content)
                )
            else:
                raise ValueError(f"Unknown role: {msg_role}")
        return result

    @staticmethod
    def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
        return list(map(lambda m: m.dict(), messages))
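
A small illustration of the two conversion helpers (values illustrative):

msgs = ModelMessage.from_openai_messages(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
)
payload = ModelMessage.to_dict_list(msgs)
# [{"role": "system", ...}, {"role": "human", "content": "Hello"}]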

class Generation(BaseModel):
    """Output of a single generation."""
pilot/utils/openai_utils.py (new file, 99 lines)
@ -0,0 +1,99 @@
from typing import Dict, Any, Awaitable, Callable, Optional, Iterator
import httpx
import asyncio
import logging
import json

logger = logging.getLogger(__name__)
MessageCaller = Callable[[str], Awaitable[None]]


async def _do_chat_completion(
    url: str,
    chat_data: Dict[str, Any],
    client: httpx.AsyncClient,
    headers: Dict[str, Any] = {},
    timeout: int = 60,
    caller: Optional[MessageCaller] = None,
) -> Iterator[str]:
    async with client.stream(
        "POST",
        url,
        headers=headers,
        json=chat_data,
        timeout=timeout,
    ) as res:
        if res.status_code != 200:
            error_message = await res.aread()
            if error_message:
                error_message = error_message.decode("utf-8")
            logger.error(
                f"Request failed with status {res.status_code}. Error: {error_message}"
            )
            raise httpx.RequestError(
                f"Request failed with status {res.status_code}",
                request=res.request,
            )
        async for line in res.aiter_lines():
            if line:
                if not line.startswith("data: "):
                    if caller:
                        await caller(line)
                    yield line
                else:
                    decoded_line = line.split("data: ", 1)[1]
                    if decoded_line.lower().strip() != "[DONE]".lower():
                        obj = json.loads(decoded_line)
                        if obj["choices"][0]["delta"].get("content") is not None:
                            text = obj["choices"][0]["delta"].get("content")
                            if caller:
                                await caller(text)
                            yield text
            await asyncio.sleep(0.02)


async def chat_completion_stream(
    url: str,
    chat_data: Dict[str, Any],
    client: Optional[httpx.AsyncClient] = None,
    headers: Dict[str, Any] = {},
    timeout: int = 60,
    caller: Optional[MessageCaller] = None,
) -> Iterator[str]:
    if client:
        async for text in _do_chat_completion(
            url,
            chat_data,
            client=client,
            headers=headers,
            timeout=timeout,
            caller=caller,
        ):
            yield text
    else:
        async with httpx.AsyncClient() as client:
            async for text in _do_chat_completion(
                url,
                chat_data,
                client=client,
                headers=headers,
                timeout=timeout,
                caller=caller,
            ):
                yield text


async def chat_completion(
    url: str,
    chat_data: Dict[str, Any],
    client: Optional[httpx.AsyncClient] = None,
    headers: Dict[str, Any] = {},
    timeout: int = 60,
    caller: Optional[MessageCaller] = None,
) -> str:
    full_text = ""
    async for text in chat_completion_stream(
        url, chat_data, client, headers=headers, timeout=timeout, caller=caller
    ):
        full_text += text
    return full_text
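
A short sketch of driving these helpers against a running apiserver (the URL and model name are assumptions):

import asyncio
from pilot.utils.openai_utils import chat_completion_stream

async def main():
    chat_data = {
        "model": "vicuna-13b-v1.5",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    async for text in chat_completion_stream(
        "http://127.0.0.1:8100/api/v1/chat/completions", chat_data
    ):
        print(text, end="", flush=True)

asyncio.run(main())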
@ -8,6 +8,7 @@ pytest-integration
pytest-mock
pytest-recording
pytesseract==0.3.10
aioresponses
# python code format
black
# for git hooks