From 2c9c539404256e55aff6f479ca91afcd0349e923 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Thu, 2 Nov 2023 20:38:58 +0800 Subject: [PATCH] feat(model): Support OpenAI-Compatible RESTful APIs --- .../cluster-docker-compose.yml | 10 + pilot/base_modules/agent/db/__init__.py | 0 pilot/component.py | 44 +- pilot/model/base.py | 4 +- pilot/model/cli.py | 23 +- pilot/model/cluster/__init__.py | 2 + pilot/model/cluster/apiserver/__init__.py | 0 pilot/model/cluster/apiserver/api.py | 443 ++++++++++++++++++ .../model/cluster/apiserver/tests/__init__.py | 0 .../model/cluster/apiserver/tests/test_api.py | 248 ++++++++++ pilot/model/cluster/controller/controller.py | 10 +- pilot/model/cluster/registry.py | 42 +- pilot/model/cluster/tests/__init__.py | 0 .../tests/base_tests.py => tests/conftest.py} | 73 ++- pilot/model/cluster/worker/default_worker.py | 9 +- pilot/model/cluster/worker/manager.py | 4 +- .../cluster/worker/tests/test_manager.py | 34 +- pilot/model/model_adapter.py | 9 +- pilot/model/parameter.py | 82 +++- pilot/scene/base_message.py | 49 +- pilot/utils/openai_utils.py | 99 ++++ requirements/dev-requirements.txt | 1 + 22 files changed, 1124 insertions(+), 62 deletions(-) create mode 100644 pilot/base_modules/agent/db/__init__.py create mode 100644 pilot/model/cluster/apiserver/__init__.py create mode 100644 pilot/model/cluster/apiserver/api.py create mode 100644 pilot/model/cluster/apiserver/tests/__init__.py create mode 100644 pilot/model/cluster/apiserver/tests/test_api.py create mode 100644 pilot/model/cluster/tests/__init__.py rename pilot/model/cluster/{worker/tests/base_tests.py => tests/conftest.py} (71%) create mode 100644 pilot/utils/openai_utils.py diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml index b41033458..0ad6be9ae 100644 --- a/docker/compose_examples/cluster-docker-compose.yml +++ b/docker/compose_examples/cluster-docker-compose.yml @@ -7,6 +7,16 @@ services: restart: unless-stopped networks: - dbgptnet + api-server: + image: eosphorosai/dbgpt:latest + command: dbgpt start apiserver --controller_addr http://controller:8000 + restart: unless-stopped + depends_on: + - controller + networks: + - dbgptnet + ports: + - 8100:8100/tcp llm-worker: image: eosphorosai/dbgpt:latest command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000 diff --git a/pilot/base_modules/agent/db/__init__.py b/pilot/base_modules/agent/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/component.py b/pilot/component.py index 8f8c8c5a4..8182f3435 100644 --- a/pilot/component.py +++ b/pilot/component.py @@ -46,6 +46,8 @@ class ComponentType(str, Enum): WORKER_MANAGER = "dbgpt_worker_manager" WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory" MODEL_CONTROLLER = "dbgpt_model_controller" + MODEL_REGISTRY = "dbgpt_model_registry" + MODEL_API_SERVER = "dbgpt_model_api_server" AGENT_HUB = "dbgpt_agent_hub" EXECUTOR_DEFAULT = "dbgpt_thread_pool_default" TRACER = "dbgpt_tracer" @@ -68,7 +70,6 @@ class BaseComponent(LifeCycle, ABC): This method needs to be implemented by every component to define how it integrates with the main system app. 
""" - pass T = TypeVar("T", bound=BaseComponent) @@ -90,13 +91,28 @@ class SystemApp(LifeCycle): """Returns the internal ASGI app.""" return self._asgi_app - def register(self, component: Type[BaseComponent], *args, **kwargs): - """Register a new component by its type.""" + def register(self, component: Type[BaseComponent], *args, **kwargs) -> T: + """Register a new component by its type. + + Args: + component (Type[BaseComponent]): The component class to register + + Returns: + T: The instance of registered component + """ instance = component(self, *args, **kwargs) self.register_instance(instance) + return instance - def register_instance(self, instance: T): - """Register an already initialized component.""" + def register_instance(self, instance: T) -> T: + """Register an already initialized component. + + Args: + instance (T): The component instance to register + + Returns: + T: The instance of registered component + """ name = instance.name if isinstance(name, ComponentType): name = name.value @@ -107,18 +123,34 @@ class SystemApp(LifeCycle): logger.info(f"Register component with name {name} and instance: {instance}") self.components[name] = instance instance.init_app(self) + return instance def get_component( self, name: Union[str, ComponentType], component_type: Type[T], default_component=_EMPTY_DEFAULT_COMPONENT, + or_register_component: Type[BaseComponent] = None, + *args, + **kwargs, ) -> T: - """Retrieve a registered component by its name and type.""" + """Retrieve a registered component by its name and type. + + Args: + name (Union[str, ComponentType]): Component name + component_type (Type[T]): The type of current retrieve component + default_component : The default component instance if not retrieve by name + or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name + + Returns: + T: The instance retrieved by component name + """ if isinstance(name, ComponentType): name = name.value component = self.components.get(name) if not component: + if or_register_component: + return self.register(or_register_component, *args, **kwargs) if default_component != _EMPTY_DEFAULT_COMPONENT: return default_component raise ValueError(f"No component found with name {name}") diff --git a/pilot/model/base.py b/pilot/model/base.py index e89b243c9..48480b94b 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from enum import Enum -from typing import TypedDict, Optional, Dict, List +from typing import TypedDict, Optional, Dict, List, Any from dataclasses import dataclass, asdict from datetime import datetime from pilot.utils.parameter_utils import ParameterDescription @@ -52,6 +52,8 @@ class ModelOutput: text: str error_code: int model_context: Dict = None + finish_reason: str = None + usage: Dict[str, Any] = None def to_dict(self) -> Dict: return asdict(self) diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 1030adfc2..79b47db82 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR from pilot.model.base import WorkerApplyType from pilot.model.parameter import ( ModelControllerParameters, + ModelAPIServerParameters, ModelWorkerParameters, ModelParameters, BaseParameters, @@ -441,15 +442,27 @@ def stop_model_worker(port: int): @click.command(name="apiserver") +@EnvArgumentParser.create_click_option(ModelAPIServerParameters) def start_apiserver(**kwargs): - """Start apiserver(TODO)""" - raise NotImplementedError + """Start apiserver""" + + 
if kwargs["daemon"]: + log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log") + _run_current_with_daemon("ModelAPIServer", log_file) + else: + from pilot.model.cluster import run_apiserver + + run_apiserver() @click.command(name="apiserver") -def stop_apiserver(**kwargs): - """Start apiserver(TODO)""" - raise NotImplementedError +@add_stop_server_options +def stop_apiserver(port: int): + """Stop apiserver""" + name = "ModelAPIServer" + if port: + name = f"{name}-{port}" + _stop_service("apiserver", name, port=port) def _stop_all_model_server(**kwargs): diff --git a/pilot/model/cluster/__init__.py b/pilot/model/cluster/__init__.py index 9937ffa0b..a777a8d4b 100644 --- a/pilot/model/cluster/__init__.py +++ b/pilot/model/cluster/__init__.py @@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import ( run_model_controller, BaseModelController, ) +from pilot.model.cluster.apiserver.api import run_apiserver from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager @@ -40,4 +41,5 @@ __all__ = [ "ModelRegistryClient", "RemoteWorkerManager", "run_model_controller", + "run_apiserver", ] diff --git a/pilot/model/cluster/apiserver/__init__.py b/pilot/model/cluster/apiserver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/apiserver/api.py b/pilot/model/cluster/apiserver/api.py new file mode 100644 index 000000000..148a51eed --- /dev/null +++ b/pilot/model/cluster/apiserver/api.py @@ -0,0 +1,443 @@ +"""A server that provides OpenAI-compatible RESTful APIs. It supports: +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) + +Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py +""" +from typing import Optional, List, Dict, Any, Generator + +import logging +import asyncio +import shortuuid +import json +from fastapi import APIRouter, FastAPI +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + +from pydantic import BaseSettings + +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + DeltaMessage, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.serve.openai_api_server import create_error_response, check_requests +from fastchat.constants import ErrorCode + +from pilot.component import BaseComponent, ComponentType, SystemApp +from pilot.utils.parameter_utils import EnvArgumentParser +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType +from pilot.model.base import ModelInstance, ModelOutput +from pilot.model.parameter import ModelAPIServerParameters, WorkerType +from pilot.model.cluster import ModelRegistry, ModelRegistryClient +from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory +from pilot.utils.utils import setup_logging + +logger = logging.getLogger(__name__) + + +class APIServerException(Exception): + def __init__(self, code: int, message: str): + self.code = code + self.message = message + + +class 
APISettings(BaseSettings): + api_keys: Optional[List[str]] = None + + +api_settings = APISettings() +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if api_settings.api_keys: + if auth is None or (token := auth.credentials) not in api_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +class APIServer(BaseComponent): + name = ComponentType.MODEL_API_SERVER + + def init_app(self, system_app: SystemApp): + self.system_app = system_app + + def get_worker_manager(self) -> WorkerManager: + """Get the worker manager component instance + + Raises: + APIServerException: If can't get worker manager component instance + """ + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + if not worker_manager: + raise APIServerException( + ErrorCode.INTERNAL_ERROR, + f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app", + ) + return worker_manager + + def get_model_registry(self) -> ModelRegistry: + """Get the model registry component instance + + Raises: + APIServerException: If can't get model registry component instance + """ + + controller = self.system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry + ) + if not controller: + raise APIServerException( + ErrorCode.INTERNAL_ERROR, + f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app", + ) + return controller + + async def get_model_instances_or_raise( + self, model_name: str + ) -> List[ModelInstance]: + """Get healthy model instances with request model name + + Args: + model_name (str): Model name + + Raises: + APIServerException: If can't get healthy model instances with request model name + """ + registry = self.get_model_registry() + registry_model_name = f"{model_name}@llm" + model_instances = await registry.get_all_instances( + registry_model_name, healthy_only=True + ) + if not model_instances: + all_instances = await registry.get_all_model_instances(healthy_only=True) + models = [ + ins.model_name.split("@llm")[0] + for ins in all_instances + if ins.model_name.endswith("@llm") + ] + if models: + models = "&&".join(models) + message = f"Only {models} allowed now, your model {model_name}" + else: + message = f"No models allowed now, your model {model_name}" + raise APIServerException(ErrorCode.INVALID_MODEL, message) + return model_instances + + async def get_available_models(self) -> ModelList: + """Return available models + + Just include LLM and embedding models. + + Returns: + List[ModelList]: The list of models. 
+ """ + registry = self.get_model_registry() + model_instances = await registry.get_all_model_instances(healthy_only=True) + model_name_set = set() + for inst in model_instances: + name, worker_type = WorkerType.parse_worker_key(inst.model_name) + if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC: + model_name_set.add(name) + models = list(model_name_set) + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append( + ModelCard( + id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()] + ) + ) + return ModelList(data=model_cards) + + async def chat_completion_stream_generator( + self, model_name: str, params: Dict[str, Any], n: int + ) -> Generator[str, Any, None]: + """Chat stream completion generator + + Args: + model_name (str): Model name + params (Dict[str, Any]): The parameters pass to model worker + n (int): How many completions to generate for each prompt. + """ + worker_manager = self.get_worker_manager() + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + async for model_output in worker_manager.generate_stream(params): + model_output: ModelOutput = model_output + if model_output.error_code != 0: + yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = model_output.text.replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=model_output.finish_reason, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if model_output.finish_reason is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + async def chat_completion_generate( + self, model_name: str, params: Dict[str, Any], n: int + ) -> ChatCompletionResponse: + """Generate completion + Args: + model_name (str): Model name + params (Dict[str, Any]): The parameters pass to model worker + n (int): How many completions to generate for each prompt. 
+ """ + worker_manager: WorkerManager = self.get_worker_manager() + choices = [] + chat_completions = [] + for i in range(n): + model_output = asyncio.create_task(worker_manager.generate(params)) + chat_completions.append(model_output) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, model_output in enumerate(all_tasks): + model_output: ModelOutput = model_output + if model_output.error_code != 0: + return create_error_response(model_output.error_code, model_output.text) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=model_output.text), + finish_reason=model_output.finish_reason or "stop", + ) + ) + if model_output.usage: + task_usage = UsageInfo.parse_obj(model_output.usage) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=model_name, choices=choices, usage=usage) + + +def get_api_server() -> APIServer: + api_server = global_system_app.get_component( + ComponentType.MODEL_API_SERVER, APIServer, default_component=None + ) + if not api_server: + global_system_app.register(APIServer) + return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer) + + +router = APIRouter() + + +@router.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def get_available_models(api_server: APIServer = Depends(get_api_server)): + return await api_server.get_available_models() + + +@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion( + request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server) +): + await api_server.get_model_instances_or_raise(request.model) + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + params = { + "model": request.model, + "messages": ModelMessage.to_dict_list( + ModelMessage.from_openai_messages(request.messages) + ), + "echo": False, + } + if request.temperature: + params["temperature"] = request.temperature + if request.top_p: + params["top_p"] = request.top_p + if request.max_tokens: + params["max_new_tokens"] = request.max_tokens + if request.stop: + params["stop"] = request.stop + if request.user: + params["user"] = request.user + + # TODO check token length + if request.stream: + generator = api_server.chat_completion_stream_generator( + request.model, params, request.n + ) + return StreamingResponse(generator, media_type="text/event-stream") + return await api_server.chat_completion_generate(request.model, params, request.n) + + +def _initialize_all(controller_addr: str, system_app: SystemApp): + from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient + from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory + + if not system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None + ): + # Register model registry if not exist + registry = ModelRegistryClient(controller_addr) + registry.name = ComponentType.MODEL_REGISTRY.value + system_app.register_instance(registry) + + registry = system_app.get_component( + ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None + ) + worker_manager = RemoteWorkerManager(registry) + + # Register worker manager component if not exist + system_app.get_component( + 
ComponentType.WORKER_MANAGER_FACTORY, + WorkerManagerFactory, + or_register_component=_DefaultWorkerManagerFactory, + worker_manager=worker_manager, + ) + # Register api server component if not exist + system_app.get_component( + ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer + ) + + +def initialize_apiserver( + controller_addr: str, + app=None, + system_app: SystemApp = None, + host: str = None, + port: int = None, + api_keys: List[str] = None, +): + global global_system_app + global api_settings + embedded_mod = True + if not app: + embedded_mod = False + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], + ) + + if not system_app: + system_app = SystemApp(app) + global_system_app = system_app + + if api_keys: + api_settings.api_keys = api_keys + + app.include_router(router, prefix="/api", tags=["APIServer"]) + + @app.exception_handler(APIServerException) + async def validation_apiserver_exception_handler(request, exc: APIServerException): + return create_error_response(exc.code, exc.message) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + _initialize_all(controller_addr, system_app) + + if not embedded_mod: + import uvicorn + + uvicorn.run(app, host=host, port=port, log_level="info") + + +def run_apiserver(): + parser = EnvArgumentParser() + env_prefix = "apiserver_" + apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass( + ModelAPIServerParameters, + env_prefixes=[env_prefix], + ) + setup_logging( + "pilot", + logging_level=apiserver_params.log_level, + logger_filename=apiserver_params.log_file, + ) + api_keys = None + if apiserver_params.api_keys: + api_keys = apiserver_params.api_keys.strip().split(",") + + initialize_apiserver( + apiserver_params.controller_addr, + host=apiserver_params.host, + port=apiserver_params.port, + api_keys=api_keys, + ) + + +if __name__ == "__main__": + run_apiserver() diff --git a/pilot/model/cluster/apiserver/tests/__init__.py b/pilot/model/cluster/apiserver/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/cluster/apiserver/tests/test_api.py b/pilot/model/cluster/apiserver/tests/test_api.py new file mode 100644 index 000000000..281a8aff6 --- /dev/null +++ b/pilot/model/cluster/apiserver/tests/test_api.py @@ -0,0 +1,248 @@ +import pytest +import pytest_asyncio +from aioresponses import aioresponses +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from httpx import AsyncClient, HTTPError + +from pilot.component import SystemApp +from pilot.utils.openai_utils import chat_completion_stream, chat_completion + +from pilot.model.cluster.apiserver.api import ( + api_settings, + initialize_apiserver, + ModelList, + UsageInfo, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + DeltaMessage, +) +from pilot.model.cluster.tests.conftest import _new_cluster + +from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + + +@pytest_asyncio.fixture +async 
def system_app():
+    return SystemApp(app)
+
+
+@pytest_asyncio.fixture
+async def client(request, system_app: SystemApp):
+    param = getattr(request, "param", {})
+    api_keys = param.get("api_keys", [])
+    client_api_key = param.get("client_api_key")
+    if "num_workers" not in param:
+        param["num_workers"] = 2
+    if "api_keys" in param:
+        del param["api_keys"]
+    headers = {}
+    if client_api_key:
+        headers["Authorization"] = "Bearer " + client_api_key
+    print(f"param: {param}")
+    if api_settings:
+        # Clear global api keys
+        api_settings.api_keys = []
+    async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
+        async with _new_cluster(**param) as cluster:
+            worker_manager, model_registry = cluster
+            system_app.register(_DefaultWorkerManagerFactory, worker_manager)
+            system_app.register_instance(model_registry)
+            # print(f"Instances {model_registry.registry}")
+            initialize_apiserver(None, app, system_app, api_keys=api_keys)
+            yield client
+
+
+@pytest.mark.asyncio
+async def test_get_all_models(client: AsyncClient):
+    res = await client.get("/api/v1/models")
+    assert res.status_code == 200
+    model_lists = ModelList.parse_obj(res.json())
+    print(f"model list json: {res.json()}")
+    assert model_lists.object == "list"
+    assert len(model_lists.data) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client, expected_messages",
+    [
+        ({"stream_messags": ["Hello", " world."]}, "Hello world."),
+        ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
+    ],
+    indirect=["client"],
+)
+async def test_chat_completions(client: AsyncClient, expected_messages):
+    chat_data = {
+        "model": "test-model-name-0",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "stream": True,
+    }
+    full_text = ""
+    async for text in chat_completion_stream(
+        "/api/v1/chat/completions", chat_data, client
+    ):
+        full_text += text
+    assert full_text == expected_messages
+
+    assert (
+        await chat_completion("/api/v1/chat/completions", chat_data, client)
+        == expected_messages
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client, expected_messages, client_api_key",
+    [
+        (
+            {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
+            "Hello world.",
+            "abc",
+        ),
+        ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
+    ],
+    indirect=["client"],
+)
+async def test_chat_completions_with_openai_lib_async_no_stream(
+    client: AsyncClient, expected_messages: str, client_api_key: str
+):
+    import openai
+
+    openai.api_key = client_api_key
+    openai.api_base = "http://test/api/v1"
+
+    model_name = "test-model-name-0"
+
+    with aioresponses() as mocked:
+        mock_message = {"text": expected_messages}
+        one_res = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content=expected_messages),
+            finish_reason="stop",
+        )
+        data = ChatCompletionResponse(
+            model=model_name, choices=[one_res], usage=UsageInfo()
+        )
+        mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+        # Mock http request
+        mocked.post(
+            "http://test/api/v1/chat/completions", status=200, body=mock_message
+        )
+        completion = await openai.ChatCompletion.acreate(
+            model=model_name,
+            messages=[{"role": "user", "content": "Hello! 
What is your name?"}], + ) + assert completion.choices[0].message.content == expected_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, expected_messages, client_api_key", + [ + ( + {"stream_messags": ["Hello", " world."], "api_keys": ["abc"]}, + "Hello world.", + "abc", + ), + ({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"), + ], + indirect=["client"], +) +async def test_chat_completions_with_openai_lib_async_stream( + client: AsyncClient, expected_messages: str, client_api_key: str +): + import openai + + openai.api_key = client_api_key + openai.api_base = "http://test/api/v1" + + model_name = "test-model-name-0" + + with aioresponses() as mocked: + mock_message = {"text": expected_messages} + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=expected_messages), + finish_reason="stop", + ) + chunk = ChatCompletionStreamResponse( + id=0, choices=[choice_data], model=model_name + ) + mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + mocked.post( + "http://test/api/v1/chat/completions", + status=200, + body=mock_message, + content_type="text/event-stream", + ) + + stream_stream_resp = "" + async for stream_resp in await openai.ChatCompletion.acreate( + model=model_name, + messages=[{"role": "user", "content": "Hello! What is your name?"}], + stream=True, + ): + stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "") + assert stream_stream_resp == expected_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, expected_messages, api_key_is_error", + [ + ( + { + "stream_messags": ["Hello", " world."], + "api_keys": ["abc", "xx"], + "client_api_key": "abc", + }, + "Hello world.", + False, + ), + ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False), + ( + {"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]}, + "你好,我是张三。", + True, + ), + ( + { + "stream_messags": ["你好,我是", "张三。"], + "api_keys": ["abc", "xx"], + "client_api_key": "error_api_key", + }, + "你好,我是张三。", + True, + ), + ], + indirect=["client"], +) +async def test_chat_completions_with_api_keys( + client: AsyncClient, expected_messages: str, api_key_is_error: bool +): + chat_data = { + "model": "test-model-name-0", + "messages": [{"role": "user", "content": "Hello"}], + "stream": True, + } + if api_key_is_error: + with pytest.raises(HTTPError): + await chat_completion("/api/v1/chat/completions", chat_data, client) + else: + assert ( + await chat_completion("/api/v1/chat/completions", chat_data, client) + == expected_messages + ) diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index 173c8c019..0006d91a0 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -66,7 +66,9 @@ class LocalModelController(BaseModelController): f"Get all instances with {model_name}, healthy_only: {healthy_only}" ) if not model_name: - return await self.registry.get_all_model_instances() + return await self.registry.get_all_model_instances( + healthy_only=healthy_only + ) else: return await self.registry.get_all_instances(model_name, healthy_only) @@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController): class ModelRegistryClient(_RemoteModelController, ModelRegistry): - async def get_all_model_instances(self) -> List[ModelInstance]: - return await self.get_all_instances() + async def get_all_model_instances( + self, healthy_only: bool = False + ) -> 
List[ModelInstance]: + return await self.get_all_instances(healthy_only=healthy_only) @sync_api_remote(path="/api/controller/models") def sync_get_all_instances( diff --git a/pilot/model/cluster/registry.py b/pilot/model/cluster/registry.py index 398882eb9..eb5f1e415 100644 --- a/pilot/model/cluster/registry.py +++ b/pilot/model/cluster/registry.py @@ -1,22 +1,37 @@ import random import threading import time +import logging from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime, timedelta -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import itertools +from pilot.component import BaseComponent, ComponentType, SystemApp from pilot.model.base import ModelInstance -class ModelRegistry(ABC): +logger = logging.getLogger(__name__) + + +class ModelRegistry(BaseComponent, ABC): """ Abstract base class for a model registry. It provides an interface for registering, deregistering, fetching instances, and sending heartbeats for instances. """ + name = ComponentType.MODEL_REGISTRY + + def __init__(self, system_app: SystemApp | None = None): + self.system_app = system_app + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Initialize the component with the main application.""" + self.system_app = system_app + @abstractmethod async def register_instance(self, instance: ModelInstance) -> bool: """ @@ -65,9 +80,11 @@ class ModelRegistry(ABC): """Fetch all instances of a given model. Optionally, fetch only the healthy instances.""" @abstractmethod - async def get_all_model_instances(self) -> List[ModelInstance]: + async def get_all_model_instances( + self, healthy_only: bool = False + ) -> List[ModelInstance]: """ - Fetch all instances of all models + Fetch all instances of all models, Optionally, fetch only the healthy instances. Returns: - List[ModelInstance]: A list of instances for the all models. 
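
Reviewer note: a minimal sketch of how the new healthy_only flag is meant to be used,
based only on the EmbeddedModelRegistry and ModelInstance shown in this patch (the
worker key, host, and port values below are made-up placeholders):

    import asyncio

    from pilot.model.base import ModelInstance
    from pilot.model.cluster.registry import EmbeddedModelRegistry

    async def main():
        registry = EmbeddedModelRegistry()
        instance = ModelInstance(
            model_name="vicuna-13b-v1.5@llm",  # worker key: <model name>@<worker type>
            host="10.0.0.1",
            port=8001,
            healthy=True,
        )
        await registry.register_instance(instance)
        # With healthy_only=True, instances whose heartbeat has timed out are filtered out.
        instances = await registry.get_all_model_instances(healthy_only=True)
        print([ins.model_name for ins in instances])

    asyncio.run(main())
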
@@ -105,8 +122,12 @@ class ModelRegistry(ABC):


 class EmbeddedModelRegistry(ModelRegistry):
     def __init__(
-        self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
+        self,
+        system_app: SystemApp | None = None,
+        heartbeat_interval_secs: int = 60,
+        heartbeat_timeout_secs: int = 120,
     ):
+        super().__init__(system_app)
         self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
         self.heartbeat_interval_secs = heartbeat_interval_secs
         self.heartbeat_timeout_secs = heartbeat_timeout_secs
@@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry):
             instances = [ins for ins in instances if ins.healthy == True]
         return instances

-    async def get_all_model_instances(self) -> List[ModelInstance]:
-        print(self.registry)
-        return list(itertools.chain(*self.registry.values()))
+    async def get_all_model_instances(
+        self, healthy_only: bool = False
+    ) -> List[ModelInstance]:
+        logger.debug(f"Current registry metadata:\n{self.registry}")
+        instances = list(itertools.chain(*self.registry.values()))
+        if healthy_only:
+            instances = [ins for ins in instances if ins.healthy]
+        return instances

     async def send_heartbeat(self, instance: ModelInstance) -> bool:
         _, exist_ins = self._get_instances(
diff --git a/pilot/model/cluster/tests/__init__.py b/pilot/model/cluster/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/cluster/worker/tests/base_tests.py b/pilot/model/cluster/tests/conftest.py
similarity index 71%
rename from pilot/model/cluster/worker/tests/base_tests.py
rename to pilot/model/cluster/tests/conftest.py
index 21821d9f9..f614387ac 100644
--- a/pilot/model/cluster/worker/tests/base_tests.py
+++ b/pilot/model/cluster/tests/conftest.py
@@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, Worker
 from pilot.model.base import ModelOutput
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.worker.manager import (
+    WorkerManager,
     LocalWorkerManager,
     RegisterFunc,
     DeregisterFunc,
@@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import (
     ApplyFunction,
 )

+from pilot.model.base import ModelInstance
+from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
+
+
+@pytest.fixture
+def model_registry(request):
+    return EmbeddedModelRegistry()
+
+
+@pytest.fixture
+def model_instance():
+    return ModelInstance(
+        model_name="test_model",
+        host="192.168.1.1",
+        port=5000,
+    )
+

 class MockModelWorker(ModelWorker):
     def __init__(
@@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker):
         raise Exception("Stop worker error for mock")

     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
+        full_text = ""
         for msg in self.stream_messags:
-            yield ModelOutput(text=msg, error_code=0)
+            full_text += msg
+            yield ModelOutput(text=full_text, error_code=0)

     def generate(self, params: Dict) -> ModelOutput:
         output = None
@@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker):
 _TEST_MODEL_NAME = "vicuna-13b-v1.5"
 _TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"

+ClusterType = Tuple[WorkerManager, ModelRegistry]
+

 def _new_worker_params(
     model_name: str = _TEST_MODEL_NAME,
@@ -85,7 +107,9 @@ def _create_workers(
     worker_type: str = WorkerType.LLM.value,
     stream_messags: List[str] = None,
     embeddings: List[List[float]] = None,
-) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
+    host: str = "127.0.0.1",
+    start_port=8001,
+) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
     workers = []
     for i in range(num_workers):
         model_name = 
f"test-model-name-{i}" @@ -98,10 +122,16 @@ def _create_workers( stream_messags=stream_messags, embeddings=embeddings, ) + model_instance = ModelInstance( + model_name=WorkerType.to_worker_key(model_name, worker_type), + host=host, + port=start_port + i, + healthy=True, + ) worker_params = _new_worker_params( model_name, model_path, worker_type=worker_type ) - workers.append((worker, worker_params)) + workers.append((worker, worker_params, model_instance)) return workers @@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs): model_registry=model_registry, ) - for worker, worker_params in _create_workers( + for worker, worker_params, model_instance in _create_workers( num_workers, error_worker, stop_error, stream_messags, embeddings ): worker_manager.add_worker(worker, worker_params) if workers: - for worker, worker_params in workers: + for worker, worker_params, model_instance in workers: worker_manager.add_worker(worker, worker_params) if start: @@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs): await worker_manager.stop() +async def _create_model_registry( + workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]] +) -> ModelRegistry: + registry = EmbeddedModelRegistry() + for _, _, inst in workers: + assert await registry.register_instance(inst) == True + return registry + + @pytest_asyncio.fixture async def manager_2_workers(request): param = getattr(request, "param", {}) @@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request): ) async with _start_worker_manager(workers=workers, **param) as worker_manager: yield (worker_manager, workers) + + +@asynccontextmanager +async def _new_cluster(**kwargs) -> ClusterType: + num_workers = kwargs.get("num_workers", 0) + workers = _create_workers( + num_workers, stream_messags=kwargs.get("stream_messags", []) + ) + if "num_workers" in kwargs: + del kwargs["num_workers"] + registry = await _create_model_registry( + workers, + ) + async with _start_worker_manager(workers=workers, **kwargs) as worker_manager: + yield (worker_manager, registry) + + +@pytest_asyncio.fixture +async def cluster_2_workers(request): + param = getattr(request, "param", {}) + workers = _create_workers(2) + registry = await _create_model_registry(workers) + async with _start_worker_manager(workers=workers, **param) as worker_manager: + yield (worker_manager, registry) diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py index 5caa2ee7e..44a476f20 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker): return params, model_context, generate_stream_func, model_span def _handle_output(self, output, previous_response, model_context): + finish_reason = None + usage = None if isinstance(output, dict): finish_reason = output.get("finish_reason") + usage = output.get("usage") output = output["text"] if finish_reason is not None: logger.info(f"finish_reason: {finish_reason}") incremental_output = output[len(previous_response) :] print(incremental_output, end="", flush=True) model_output = ModelOutput( - text=output, error_code=0, model_context=model_context + text=output, + error_code=0, + model_context=model_context, + finish_reason=finish_reason, + usage=usage, ) return model_output, incremental_output, output diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index a76fa6685..2dcfb086e 100644 --- 
a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager): ) def _worker_key(self, worker_type: str, model_name: str) -> str: - if isinstance(worker_type, WorkerType): - worker_type = worker_type.value - return f"{model_name}@{worker_type}" + return WorkerType.to_worker_key(model_name, worker_type) async def run_blocking_func(self, func, *args): if asyncio.iscoroutinefunction(func): diff --git a/pilot/model/cluster/worker/tests/test_manager.py b/pilot/model/cluster/worker/tests/test_manager.py index 919e64f99..681fb49a3 100644 --- a/pilot/model/cluster/worker/tests/test_manager.py +++ b/pilot/model/cluster/worker/tests/test_manager.py @@ -3,7 +3,7 @@ import pytest from typing import List, Iterator, Dict, Tuple from dataclasses import asdict from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType -from pilot.model.base import ModelOutput, WorkerApplyType +from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest from pilot.model.cluster.worker_base import ModelWorker from pilot.model.cluster.manager_base import WorkerRunData @@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import ( SendHeartbeatFunc, ApplyFunction, ) -from pilot.model.cluster.worker.tests.base_tests import ( +from pilot.model.cluster.tests.conftest import ( MockModelWorker, manager_2_workers, manager_with_2_workers, @@ -216,7 +216,7 @@ async def test__remove_worker(): workers = _create_workers(3) async with _start_worker_manager(workers=workers, stop=False) as manager: assert len(manager.workers) == 3 - for _, worker_params in workers: + for _, worker_params, _ in workers: manager._remove_worker(worker_params) not_exist_parmas = _new_worker_params( model_name="this is a not exist worker params" @@ -229,7 +229,7 @@ async def test__remove_worker(): async def test_model_startup(mock_build_worker): async with _start_worker_manager() as manager: workers = _create_workers(1) - worker, worker_params = workers[0] + worker, worker_params, model_instance = workers[0] mock_build_worker.return_value = worker req = WorkerStartupRequest( @@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker): async with _start_worker_manager() as manager: workers = _create_workers(1, error_worker=True) - worker, worker_params = workers[0] + worker, worker_params, model_instance = workers[0] mock_build_worker.return_value = worker req = WorkerStartupRequest( host="127.0.0.1", @@ -263,7 +263,7 @@ async def test_model_startup(mock_build_worker): async def test_model_shutdown(mock_build_worker): async with _start_worker_manager(start=False, stop=False) as manager: workers = _create_workers(1) - worker, worker_params = workers[0] + worker, worker_params, model_instance = workers[0] mock_build_worker.return_value = worker req = WorkerStartupRequest( @@ -298,7 +298,7 @@ async def test_get_model_instances(is_async): workers = _create_workers(3) async with _start_worker_manager(workers=workers, stop=False) as manager: assert len(manager.workers) == 3 - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type if is_async: @@ -326,7 +326,7 @@ async def test__simple_select( ] ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = 
worker_params.worker_type instances = await manager.get_model_instances(worker_type, model_name) @@ -351,7 +351,7 @@ async def test_select_one_instance( ], ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type if is_async: @@ -376,7 +376,7 @@ async def test__get_model( ], ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name} @@ -403,13 +403,13 @@ async def test_generate_stream( expected_messages: str, ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name} text = "" async for out in manager.generate_stream(params): - text += out.text + text = out.text assert text == expected_messages @@ -417,8 +417,8 @@ async def test_generate_stream( @pytest.mark.parametrize( "manager_with_2_workers, expected_messages", [ - ({"stream_messags": ["Hello", " world."]}, " world."), - ({"stream_messags": ["你好,我是", "张三。"]}, "张三。"), + ({"stream_messags": ["Hello", " world."]}, "Hello world."), + ({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"), ], indirect=["manager_with_2_workers"], ) @@ -429,7 +429,7 @@ async def test_generate( expected_messages: str, ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name} @@ -454,7 +454,7 @@ async def test_embeddings( is_async: bool, ): manager, workers = manager_2_embedding_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = {"model": model_name, "input": ["hello", "world"]} @@ -472,7 +472,7 @@ async def test_parameter_descriptions( ] ): manager, workers = manager_with_2_workers - for _, worker_params in workers: + for _, worker_params, _ in workers: model_name = worker_params.model_name worker_type = worker_params.worker_type params = await manager.parameter_descriptions(worker_type, model_name) diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py index e09b868e7..e2deeaa02 100644 --- a/pilot/model/model_adapter.py +++ b/pilot/model/model_adapter.py @@ -467,7 +467,8 @@ register_conv_template( sep="\n", sep2="", stop_str=["", "[UNK]"], - ) + ), + override=True, ) # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 register_conv_template( @@ -482,7 +483,8 @@ register_conv_template( sep="###", sep2="", stop_str=["", "[UNK]"], - ) + ), + override=True, ) # source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 register_conv_template( @@ -495,5 +497,6 @@ register_conv_template( sep="", sep2="", stop_str=["", "<|endoftext|>"], - ) + ), + override=True, ) diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index ea81ec091..e21de1c42 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- + import os from dataclasses import dataclass, field from enum import Enum -from typing import Dict, 
Optional
+from typing import Dict, Optional, Union, Tuple

 from pilot.model.conversation import conv_templates
 from pilot.utils.parameter_utils import BaseParameters
@@ -19,6 +20,35 @@ class WorkerType(str, Enum):
     def values():
         return [item.value for item in WorkerType]

+    @staticmethod
+    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
+        """Generate worker key from worker name and worker type
+
+        Args:
+            worker_name (str): Worker name (e.g., chatglm2-6b)
+            worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm', or [`WorkerType.LLM`])
+
+        Returns:
+            str: Generated worker key
+        """
+        if "@" in worker_name:
+            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
+        if isinstance(worker_type, WorkerType):
+            worker_type = worker_type.value
+        return f"{worker_name}@{worker_type}"
+
+    @staticmethod
+    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
+        """Parse worker name and worker type from worker key
+
+        Args:
+            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]
+
+        Returns:
+            Tuple[str, str]: Worker name and worker type
+        """
+        return tuple(worker_key.split("@"))
+

 @dataclass
 class ModelControllerParameters(BaseParameters):
@@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
     )


+@dataclass
+class ModelAPIServerParameters(BaseParameters):
+    host: Optional[str] = field(
+        default="0.0.0.0", metadata={"help": "Model API server deploy host"}
+    )
+    port: Optional[int] = field(
+        default=8100, metadata={"help": "Model API server deploy port"}
+    )
+    daemon: Optional[bool] = field(
+        default=False, metadata={"help": "Run Model API server in background"}
+    )
+    controller_addr: Optional[str] = field(
+        default="http://127.0.0.1:8000",
+        metadata={"help": "The Model controller address to connect"},
+    )
+
+    api_keys: Optional[str] = field(
+        default=None,
+        metadata={"help": "Optional list of comma separated API keys"},
+    )
+
+    log_level: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Logging level",
+            "valid_values": [
+                "FATAL",
+                "ERROR",
+                "WARNING",
+                "INFO",
+                "DEBUG",
+                "NOTSET",
+            ],
+        },
+    )
+    log_file: Optional[str] = field(
+        default="dbgpt_model_apiserver.log",
+        metadata={
+            "help": "The filename to store log",
+        },
+    )
+    tracer_file: Optional[str] = field(
+        default="dbgpt_model_apiserver_tracer.jsonl",
+        metadata={
+            "help": "The filename to store tracer span records",
+        },
+    )
+
+
 @dataclass
 class BaseModelParameters(BaseParameters):
     model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
index eeb42a285..12a72e909 100644
--- a/pilot/scene/base_message.py
+++ b/pilot/scene/base_message.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any, Dict, List, Tuple, Optional, Union

 from pydantic import BaseModel, Field, root_validator

@@ -70,14 +70,6 @@ class SystemMessage(BaseMessage):
         return "system"


-class ModelMessage(BaseModel):
-    """Type of message that interaction between dbgpt-server and llm-server"""
-
-    """Similar to openai's message format"""
-    role: str
-    content: str
-
-
 class ModelMessageRoleType:
     """ "Type of ModelMessage role"""

@@ -87,6 +79,45 @@ class ModelMessageRoleType:
     VIEW = "view"


+class ModelMessage(BaseModel):
+    """Type of message for interaction between dbgpt-server and llm-server"""
+
+    """Similar to openai's message format"""
+    role: str
+    
content: str + + @staticmethod + def from_openai_messages( + messages: Union[str, List[Dict[str, str]]] + ) -> List["ModelMessage"]: + """Openai message format to current ModelMessage format""" + if isinstance(messages, str): + return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)] + result = [] + for message in messages: + msg_role = message["role"] + content = message["content"] + if msg_role == "system": + result.append( + ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content) + ) + elif msg_role == "user": + result.append( + ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) + ) + elif msg_role == "assistant": + result.append( + ModelMessage(role=ModelMessageRoleType.AI, content=content) + ) + else: + raise ValueError(f"Unknown role: {msg_role}") + return result + + @staticmethod + def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]: + return list(map(lambda m: m.dict(), messages)) + + class Generation(BaseModel): """Output of a single generation.""" diff --git a/pilot/utils/openai_utils.py b/pilot/utils/openai_utils.py new file mode 100644 index 000000000..6577d3abf --- /dev/null +++ b/pilot/utils/openai_utils.py @@ -0,0 +1,99 @@ +from typing import Dict, Any, Awaitable, Callable, Optional, Iterator +import httpx +import asyncio +import logging +import json + +logger = logging.getLogger(__name__) +MessageCaller = Callable[[str], Awaitable[None]] + + +async def _do_chat_completion( + url: str, + chat_data: Dict[str, Any], + client: httpx.AsyncClient, + headers: Dict[str, Any] = {}, + timeout: int = 60, + caller: Optional[MessageCaller] = None, +) -> Iterator[str]: + async with client.stream( + "POST", + url, + headers=headers, + json=chat_data, + timeout=timeout, + ) as res: + if res.status_code != 200: + error_message = await res.aread() + if error_message: + error_message = error_message.decode("utf-8") + logger.error( + f"Request failed with status {res.status_code}. 
Error: {error_message}" + ) + raise httpx.RequestError( + f"Request failed with status {res.status_code}", + request=res.request, + ) + async for line in res.aiter_lines(): + if line: + if not line.startswith("data: "): + if caller: + await caller(line) + yield line + else: + decoded_line = line.split("data: ", 1)[1] + if decoded_line.lower().strip() != "[DONE]".lower(): + obj = json.loads(decoded_line) + if obj["choices"][0]["delta"].get("content") is not None: + text = obj["choices"][0]["delta"].get("content") + if caller: + await caller(text) + yield text + await asyncio.sleep(0.02) + + +async def chat_completion_stream( + url: str, + chat_data: Dict[str, Any], + client: Optional[httpx.AsyncClient] = None, + headers: Dict[str, Any] = {}, + timeout: int = 60, + caller: Optional[MessageCaller] = None, +) -> Iterator[str]: + if client: + async for text in _do_chat_completion( + url, + chat_data, + client=client, + headers=headers, + timeout=timeout, + caller=caller, + ): + yield text + else: + async with httpx.AsyncClient() as client: + async for text in _do_chat_completion( + url, + chat_data, + client=client, + headers=headers, + timeout=timeout, + caller=caller, + ): + yield text + + +async def chat_completion( + url: str, + chat_data: Dict[str, Any], + client: Optional[httpx.AsyncClient] = None, + headers: Dict[str, Any] = {}, + timeout: int = 60, + caller: Optional[MessageCaller] = None, +) -> str: + full_text = "" + async for text in chat_completion_stream( + url, chat_data, client, headers=headers, timeout=timeout, caller=caller + ): + full_text += text + return full_text diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 072b527f1..d1a98ed49 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -8,6 +8,7 @@ pytest-integration pytest-mock pytest-recording pytesseract==0.3.10 +aioresponses # python code format black # for git hooks