feat(model): Support OpenAI-Compatible RESTful APIs

FangYin Cheng 2023-11-02 20:38:58 +08:00
parent c7cad041d5
commit 2c9c539404
22 changed files with 1124 additions and 62 deletions

View File

@ -7,6 +7,16 @@ services:
restart: unless-stopped
networks:
- dbgptnet
api-server:
image: eosphorosai/dbgpt:latest
command: dbgpt start apiserver --controller_addr http://controller:8000
restart: unless-stopped
depends_on:
- controller
networks:
- dbgptnet
ports:
- 8100:8100/tcp
llm-worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
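With this stack up, the api-server exposes the OpenAI-compatible routes under /api/v1 on port 8100. A minimal client sketch using the openai Python package (the 0.x interface, as in the tests added below); the host and the "EMPTY" key are assumptions for a local deployment:

import openai

openai.api_key = "EMPTY"  # any non-empty string works when --api_keys is not set
openai.api_base = "http://127.0.0.1:8100/api/v1"

completion = openai.ChatCompletion.create(
    model="vicuna-13b-v1.5",  # the model served by llm-worker above
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)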

View File

View File

@ -46,6 +46,8 @@ class ComponentType(str, Enum):
WORKER_MANAGER = "dbgpt_worker_manager"
WORKER_MANAGER_FACTORY = "dbgpt_worker_manager_factory"
MODEL_CONTROLLER = "dbgpt_model_controller"
MODEL_REGISTRY = "dbgpt_model_registry"
MODEL_API_SERVER = "dbgpt_model_api_server"
AGENT_HUB = "dbgpt_agent_hub"
EXECUTOR_DEFAULT = "dbgpt_thread_pool_default"
TRACER = "dbgpt_tracer"
@ -68,7 +70,6 @@ class BaseComponent(LifeCycle, ABC):
This method needs to be implemented by every component to define how it integrates
with the main system app.
"""
pass
T = TypeVar("T", bound=BaseComponent)
@ -90,13 +91,28 @@ class SystemApp(LifeCycle):
"""Returns the internal ASGI app."""
return self._asgi_app
def register(self, component: Type[BaseComponent], *args, **kwargs):
"""Register a new component by its type."""
def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
"""Register a new component by its type.
Args:
component (Type[BaseComponent]): The component class to register
Returns:
T: The instance of the registered component
"""
instance = component(self, *args, **kwargs)
self.register_instance(instance)
return instance
def register_instance(self, instance: T):
"""Register an already initialized component."""
def register_instance(self, instance: T) -> T:
"""Register an already initialized component.
Args:
instance (T): The component instance to register
Returns:
T: The instance of the registered component
"""
name = instance.name
if isinstance(name, ComponentType):
name = name.value
@ -107,18 +123,34 @@ class SystemApp(LifeCycle):
logger.info(f"Register component with name {name} and instance: {instance}")
self.components[name] = instance
instance.init_app(self)
return instance
def get_component(
self,
name: Union[str, ComponentType],
component_type: Type[T],
default_component=_EMPTY_DEFAULT_COMPONENT,
or_register_component: Type[BaseComponent] = None,
*args,
**kwargs,
) -> T:
"""Retrieve a registered component by its name and type."""
"""Retrieve a registered component by its name and type.
Args:
name (Union[str, ComponentType]): Component name
component_type (Type[T]): The type of the component to retrieve
default_component: The default component instance to return if no component is found by name
or_register_component (Type[BaseComponent]): The component type to register and return if no component is found by name
Returns:
T: The component instance retrieved by name
"""
if isinstance(name, ComponentType):
name = name.value
component = self.components.get(name)
if not component:
if or_register_component:
return self.register(or_register_component, *args, **kwargs)
if default_component != _EMPTY_DEFAULT_COMPONENT:
return default_component
raise ValueError(f"No component found with name {name}")

View File

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from enum import Enum
from typing import TypedDict, Optional, Dict, List
from typing import TypedDict, Optional, Dict, List, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
@ -52,6 +52,8 @@ class ModelOutput:
text: str
error_code: int
model_context: Dict = None
finish_reason: str = None
usage: Dict[str, Any] = None
def to_dict(self) -> Dict:
return asdict(self)
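The new finish_reason and usage fields flow through to the OpenAI response objects in the API server below; a quick sketch of a worker result carrying them:

output = ModelOutput(
    text="Hello world.",
    error_code=0,
    finish_reason="stop",
    usage={"prompt_tokens": 4, "completion_tokens": 3, "total_tokens": 7},
)
assert output.to_dict()["finish_reason"] == "stop"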

View File

@ -8,6 +8,7 @@ from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
ModelControllerParameters,
ModelAPIServerParameters,
ModelWorkerParameters,
ModelParameters,
BaseParameters,
@ -441,15 +442,27 @@ def stop_model_worker(port: int):
@click.command(name="apiserver")
@EnvArgumentParser.create_click_option(ModelAPIServerParameters)
def start_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
"""Start apiserver"""
if kwargs["daemon"]:
log_file = os.path.join(LOGDIR, "model_apiserver_uvicorn.log")
_run_current_with_daemon("ModelAPIServer", log_file)
else:
from pilot.model.cluster import run_apiserver
run_apiserver()
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
@add_stop_server_options
def stop_apiserver(port: int):
"""Stop apiserver"""
name = "ModelAPIServer"
if port:
name = f"{name}-{port}"
_stop_service("apiserver", name, port=port)
def _stop_all_model_server(**kwargs):

View File

@ -21,6 +21,7 @@ from pilot.model.cluster.controller.controller import (
run_model_controller,
BaseModelController,
)
from pilot.model.cluster.apiserver.api import run_apiserver
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
@ -40,4 +41,5 @@ __all__ = [
"ModelRegistryClient",
"RemoteWorkerManager",
"run_model_controller",
"run_apiserver",
]

View File

@ -0,0 +1,443 @@
"""A server that provides OpenAI-compatible RESTful APIs. It supports:
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
from typing import Optional, List, Dict, Any, AsyncIterator
import logging
import asyncio
import shortuuid
import json
from fastapi import APIRouter, FastAPI
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseSettings
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponse,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
ChatCompletionResponseChoice,
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
from fastchat.protocol.api_protocol import (
APIChatCompletionRequest,
APITokenCheckRequest,
APITokenCheckResponse,
APITokenCheckResponseItem,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.constants import ErrorCode
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.model.base import ModelInstance, ModelOutput
from pilot.model.parameter import ModelAPIServerParameters, WorkerType
from pilot.model.cluster import ModelRegistry, ModelRegistryClient
from pilot.model.cluster.manager_base import WorkerManager, WorkerManagerFactory
from pilot.utils.utils import setup_logging
logger = logging.getLogger(__name__)
class APIServerException(Exception):
def __init__(self, code: int, message: str):
self.code = code
self.message = message
class APISettings(BaseSettings):
api_keys: Optional[List[str]] = None
api_settings = APISettings()
get_bearer_token = HTTPBearer(auto_error=False)
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> str:
if api_settings.api_keys:
if auth is None or (token := auth.credentials) not in api_settings.api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
class APIServer(BaseComponent):
name = ComponentType.MODEL_API_SERVER
def init_app(self, system_app: SystemApp):
self.system_app = system_app
def get_worker_manager(self) -> WorkerManager:
"""Get the worker manager component instance
Raises:
APIServerException: If the worker manager component instance cannot be obtained
"""
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
if not worker_manager:
raise APIServerException(
ErrorCode.INTERNAL_ERROR,
f"Could not get component {ComponentType.WORKER_MANAGER_FACTORY} from system_app",
)
return worker_manager
def get_model_registry(self) -> ModelRegistry:
"""Get the model registry component instance
Raises:
APIServerException: If the model registry component instance cannot be obtained
"""
controller = self.system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry
)
if not controller:
raise APIServerException(
ErrorCode.INTERNAL_ERROR,
f"Could not get component {ComponentType.MODEL_REGISTRY} from system_app",
)
return controller
async def get_model_instances_or_raise(
self, model_name: str
) -> List[ModelInstance]:
"""Get healthy model instances with request model name
Args:
model_name (str): Model name
Raises:
APIServerException: If no healthy model instances exist for the requested model name
"""
registry = self.get_model_registry()
registry_model_name = f"{model_name}@llm"
model_instances = await registry.get_all_instances(
registry_model_name, healthy_only=True
)
if not model_instances:
all_instances = await registry.get_all_model_instances(healthy_only=True)
models = [
ins.model_name.split("@llm")[0]
for ins in all_instances
if ins.model_name.endswith("@llm")
]
if models:
models = "&&".join(models)
message = f"Only {models} allowed now, your model {model_name}"
else:
message = f"No models allowed now, your model {model_name}"
raise APIServerException(ErrorCode.INVALID_MODEL, message)
return model_instances
async def get_available_models(self) -> ModelList:
"""Return available models
Just include LLM and embedding models.
Returns:
List[ModelList]: The list of models.
"""
registry = self.get_model_registry()
model_instances = await registry.get_all_model_instances(healthy_only=True)
model_name_set = set()
for inst in model_instances:
name, worker_type = WorkerType.parse_worker_key(inst.model_name)
if worker_type == WorkerType.LLM or worker_type == WorkerType.TEXT2VEC:
model_name_set.add(name)
models = list(model_name_set)
models.sort()
# TODO: return real model permission details
model_cards = []
for m in models:
model_cards.append(
ModelCard(
id=m, root=m, owned_by="DB-GPT", permission=[ModelPermission()]
)
)
return ModelList(data=model_cards)
async def chat_completion_stream_generator(
self, model_name: str, params: Dict[str, Any], n: int
) -> AsyncIterator[str]:
"""Chat stream completion generator
Args:
model_name (str): Model name
params (Dict[str, Any]): The parameters passed to the model worker
n (int): How many completions to generate for each prompt.
"""
worker_manager = self.get_worker_manager()
id = f"chatcmpl-{shortuuid.random()}"
finish_stream_events = []
for i in range(n):
# First chunk with role
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_text = ""
async for model_output in worker_manager.generate_stream(params):
model_output: ModelOutput = model_output
if model_output.error_code != 0:
yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return
decoded_unicode = model_output.text.replace("\ufffd", "")
delta_text = decoded_unicode[len(previous_text) :]
previous_text = (
decoded_unicode
if len(decoded_unicode) > len(previous_text)
else previous_text
)
if len(delta_text) == 0:
delta_text = None
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=model_output.finish_reason,
)
chunk = ChatCompletionStreamResponse(
id=id, choices=[choice_data], model=model_name
)
if delta_text is None:
if model_output.finish_reason is not None:
finish_stream_events.append(chunk)
continue
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
# There is not "content" field in the last delta message, so exclude_none to exclude field "content".
for finish_chunk in finish_stream_events:
yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
async def chat_completion_generate(
self, model_name: str, params: Dict[str, Any], n: int
) -> ChatCompletionResponse:
"""Generate completion
Args:
model_name (str): Model name
params (Dict[str, Any]): The parameters passed to the model worker
n (int): How many completions to generate for each prompt.
"""
worker_manager: WorkerManager = self.get_worker_manager()
choices = []
chat_completions = []
for i in range(n):
model_output = asyncio.create_task(worker_manager.generate(params))
chat_completions.append(model_output)
try:
all_tasks = await asyncio.gather(*chat_completions)
except Exception as e:
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
usage = UsageInfo()
for i, model_output in enumerate(all_tasks):
model_output: ModelOutput = model_output
if model_output.error_code != 0:
return create_error_response(model_output.error_code, model_output.text)
choices.append(
ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role="assistant", content=model_output.text),
finish_reason=model_output.finish_reason or "stop",
)
)
if model_output.usage:
task_usage = UsageInfo.parse_obj(model_output.usage)
for usage_key, usage_value in task_usage.dict().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
ComponentType.MODEL_API_SERVER, APIServer, default_component=None
)
if not api_server:
global_system_app.register(APIServer)
return global_system_app.get_component(ComponentType.MODEL_API_SERVER, APIServer)
router = APIRouter()
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
async def get_available_models(api_server: APIServer = Depends(get_api_server)):
return await api_server.get_available_models()
@router.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
request: APIChatCompletionRequest, api_server: APIServer = Depends(get_api_server)
):
await api_server.get_model_instances_or_raise(request.model)
error_check_ret = check_requests(request)
if error_check_ret is not None:
return error_check_ret
params = {
"model": request.model,
"messages": ModelMessage.to_dict_list(
ModelMessage.from_openai_messages(request.messages)
),
"echo": False,
}
if request.temperature:
params["temperature"] = request.temperature
if request.top_p:
params["top_p"] = request.top_p
if request.max_tokens:
params["max_new_tokens"] = request.max_tokens
if request.stop:
params["stop"] = request.stop
if request.user:
params["user"] = request.user
# TODO check token length
if request.stream:
generator = api_server.chat_completion_stream_generator(
request.model, params, request.n
)
return StreamingResponse(generator, media_type="text/event-stream")
return await api_server.chat_completion_generate(request.model, params, request.n)
def _initialize_all(controller_addr: str, system_app: SystemApp):
from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient
from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
if not system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
):
# Register model registry if not exist
registry = ModelRegistryClient(controller_addr)
registry.name = ComponentType.MODEL_REGISTRY.value
system_app.register_instance(registry)
registry = system_app.get_component(
ComponentType.MODEL_REGISTRY, ModelRegistry, default_component=None
)
worker_manager = RemoteWorkerManager(registry)
# Register worker manager component if not exist
system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
or_register_component=_DefaultWorkerManagerFactory,
worker_manager=worker_manager,
)
# Register api server component if not exist
system_app.get_component(
ComponentType.MODEL_API_SERVER, APIServer, or_register_component=APIServer
)
def initialize_apiserver(
controller_addr: str,
app=None,
system_app: SystemApp = None,
host: str = None,
port: int = None,
api_keys: List[str] = None,
):
global global_system_app
global api_settings
embedded_mode = True
if not app:
embedded_mode = False
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
if not system_app:
system_app = SystemApp(app)
global_system_app = system_app
if api_keys:
api_settings.api_keys = api_keys
app.include_router(router, prefix="/api", tags=["APIServer"])
@app.exception_handler(APIServerException)
async def validation_apiserver_exception_handler(request, exc: APIServerException):
return create_error_response(exc.code, exc.message)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
_initialize_all(controller_addr, system_app)
if not embedded_mode:
import uvicorn
uvicorn.run(app, host=host, port=port, log_level="info")
def run_apiserver():
parser = EnvArgumentParser()
env_prefix = "apiserver_"
apiserver_params: ModelAPIServerParameters = parser.parse_args_into_dataclass(
ModelAPIServerParameters,
env_prefixes=[env_prefix],
)
setup_logging(
"pilot",
logging_level=apiserver_params.log_level,
logger_filename=apiserver_params.log_file,
)
api_keys = None
if apiserver_params.api_keys:
api_keys = apiserver_params.api_keys.strip().split(",")
initialize_apiserver(
apiserver_params.controller_addr,
host=apiserver_params.host,
port=apiserver_params.port,
api_keys=api_keys,
)
if __name__ == "__main__":
run_apiserver()
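Besides the standalone mode above, initialize_apiserver can mount the same routes into an existing FastAPI app (the embedded mode used by the tests below). A minimal sketch:

from fastapi import FastAPI

app = FastAPI()
# Embedded mode: pass the app so no uvicorn server is started here.
initialize_apiserver("http://127.0.0.1:8000", app=app)
# The OpenAI-compatible routes are now mounted under /api/v1 on `app`.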

View File

@ -0,0 +1,248 @@
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient, HTTPError
from pilot.component import SystemApp
from pilot.utils.openai_utils import chat_completion_stream, chat_completion
from pilot.model.cluster.apiserver.api import (
api_settings,
initialize_apiserver,
ModelList,
UsageInfo,
ChatCompletionResponse,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
ChatCompletionResponseChoice,
DeltaMessage,
)
from pilot.model.cluster.tests.conftest import _new_cluster
from pilot.model.cluster.worker.manager import _DefaultWorkerManagerFactory
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
@pytest_asyncio.fixture
async def system_app():
return SystemApp(app)
@pytest_asyncio.fixture
async def client(request, system_app: SystemApp):
param = getattr(request, "param", {})
api_keys = param.get("api_keys", [])
client_api_key = param.get("client_api_key")
if "num_workers" not in param:
param["num_workers"] = 2
if "api_keys" in param:
del param["api_keys"]
headers = {}
if client_api_key:
headers["Authorization"] = "Bearer " + client_api_key
print(f"param: {param}")
if api_settings:
# Clear global api keys
api_settings.api_keys = []
async with AsyncClient(app=app, base_url="http://test", headers=headers) as client:
async with _new_cluster(**param) as cluster:
worker_manager, model_registry = cluster
system_app.register(_DefaultWorkerManagerFactory, worker_manager)
system_app.register_instance(model_registry)
# print(f"Instances {model_registry.registry}")
initialize_apiserver(None, app, system_app, api_keys=api_keys)
yield client
@pytest.mark.asyncio
async def test_get_all_models(client: AsyncClient):
res = await client.get("/api/v1/models")
assert res.status_code == 200
model_lists = ModelList.parse_obj(res.json())
print(f"model list json: {res.json()}")
assert model_lists.object == "list"
assert len(model_lists.data) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages",
[
({"stream_messags": ["Hello", " world."]}, "Hello world."),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
],
indirect=["client"],
)
async def test_chat_completions(client: AsyncClient, expected_messages):
chat_data = {
"model": "test-model-name-0",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
full_text = ""
async for text in chat_completion_stream(
"/api/v1/chat/completions", chat_data, client
):
full_text += text
assert full_text == expected_messages
assert (
await chat_completion("/api/v1/chat/completions", chat_data, client)
== expected_messages
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages, client_api_key",
[
(
{"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
],
indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_no_stream(
client: AsyncClient, expected_messages: str, client_api_key: str
):
import openai
openai.api_key = client_api_key
openai.api_base = "http://test/api/v1"
model_name = "test-model-name-0"
with aioresponses() as mocked:
mock_message = {"text": expected_messages}
one_res = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=expected_messages),
finish_reason="stop",
)
data = ChatCompletionResponse(
model=model_name, choices=[one_res], usage=UsageInfo()
)
mock_message = f"{data.json(exclude_unset=True, ensure_ascii=False)}\n\n"
# Mock http request
mocked.post(
"http://test/api/v1/chat/completions", status=200, body=mock_message
)
completion = await openai.ChatCompletion.acreate(
model=model_name,
messages=[{"role": "user", "content": "Hello! What is your name?"}],
)
assert completion.choices[0].message.content == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages, client_api_key",
[
(
{"stream_messags": ["Hello", " world."], "api_keys": ["abc"]},
"Hello world.",
"abc",
),
({"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc"]}, "你好,我是张三。", "abc"),
],
indirect=["client"],
)
async def test_chat_completions_with_openai_lib_async_stream(
client: AsyncClient, expected_messages: str, client_api_key: str
):
import openai
openai.api_key = client_api_key
openai.api_base = "http://test/api/v1"
model_name = "test-model-name-0"
with aioresponses() as mocked:
mock_message = {"text": expected_messages}
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=expected_messages),
finish_reason="stop",
)
chunk = ChatCompletionStreamResponse(
id=0, choices=[choice_data], model=model_name
)
mock_message = f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
mocked.post(
"http://test/api/v1/chat/completions",
status=200,
body=mock_message,
content_type="text/event-stream",
)
stream_stream_resp = ""
async for stream_resp in await openai.ChatCompletion.acreate(
model=model_name,
messages=[{"role": "user", "content": "Hello! What is your name?"}],
stream=True,
):
stream_stream_resp = stream_resp.choices[0]["delta"].get("content", "")
assert stream_stream_resp == expected_messages
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, expected_messages, api_key_is_error",
[
(
{
"stream_messags": ["Hello", " world."],
"api_keys": ["abc", "xx"],
"client_api_key": "abc",
},
"Hello world.",
False,
),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。", False),
(
{"stream_messags": ["你好,我是", "张三。"], "api_keys": ["abc", "xx"]},
"你好,我是张三。",
True,
),
(
{
"stream_messags": ["你好,我是", "张三。"],
"api_keys": ["abc", "xx"],
"client_api_key": "error_api_key",
},
"你好,我是张三。",
True,
),
],
indirect=["client"],
)
async def test_chat_completions_with_api_keys(
client: AsyncClient, expected_messages: str, api_key_is_error: bool
):
chat_data = {
"model": "test-model-name-0",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
if api_key_is_error:
with pytest.raises(HTTPError):
await chat_completion("/api/v1/chat/completions", chat_data, client)
else:
assert (
await chat_completion("/api/v1/chat/completions", chat_data, client)
== expected_messages
)

View File

@ -66,7 +66,9 @@ class LocalModelController(BaseModelController):
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
if not model_name:
return await self.registry.get_all_model_instances()
return await self.registry.get_all_model_instances(
healthy_only=healthy_only
)
else:
return await self.registry.get_all_instances(model_name, healthy_only)
@ -98,8 +100,10 @@ class _RemoteModelController(BaseModelController):
class ModelRegistryClient(_RemoteModelController, ModelRegistry):
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.get_all_instances()
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
return await self.get_all_instances(healthy_only=healthy_only)
@sync_api_remote(path="/api/controller/models")
def sync_get_all_instances(

View File

@ -1,22 +1,37 @@
import random
import threading
import time
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import itertools
from pilot.component import BaseComponent, ComponentType, SystemApp
from pilot.model.base import ModelInstance
class ModelRegistry(ABC):
logger = logging.getLogger(__name__)
class ModelRegistry(BaseComponent, ABC):
"""
Abstract base class for a model registry. It provides an interface
for registering, deregistering, fetching instances, and sending heartbeats
for instances.
"""
name = ComponentType.MODEL_REGISTRY
def __init__(self, system_app: Optional[SystemApp] = None):
self.system_app = system_app
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize the component with the main application."""
self.system_app = system_app
@abstractmethod
async def register_instance(self, instance: ModelInstance) -> bool:
"""
@ -65,9 +80,11 @@ class ModelRegistry(ABC):
"""Fetch all instances of a given model. Optionally, fetch only the healthy instances."""
@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
"""
Fetch all instances of all models
Fetch all instances of all models. Optionally, fetch only the healthy instances.
Returns:
- List[ModelInstance]: A list of instances across all models.
@ -105,8 +122,12 @@ class ModelRegistry(ABC):
class EmbeddedModelRegistry(ModelRegistry):
def __init__(
self, heartbeat_interval_secs: int = 60, heartbeat_timeout_secs: int = 120
self,
system_app: Optional[SystemApp] = None,
heartbeat_interval_secs: int = 60,
heartbeat_timeout_secs: int = 120,
):
super().__init__(system_app)
self.registry: Dict[str, List[ModelInstance]] = defaultdict(list)
self.heartbeat_interval_secs = heartbeat_interval_secs
self.heartbeat_timeout_secs = heartbeat_timeout_secs
@ -180,9 +201,14 @@ class EmbeddedModelRegistry(ModelRegistry):
instances = [ins for ins in instances if ins.healthy]
return instances
async def get_all_model_instances(self) -> List[ModelInstance]:
print(self.registry)
return list(itertools.chain(*self.registry.values()))
async def get_all_model_instances(
self, healthy_only: bool = False
) -> List[ModelInstance]:
logger.debug("Current registry metadata:\n{self.registry}")
instances = list(itertools.chain(*self.registry.values()))
if healthy_only:
instances = [ins for ins in instances if ins.healthy]
return instances
async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(
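A quick sketch of the new healthy_only filter on the embedded registry:

import asyncio
from pilot.model.base import ModelInstance
from pilot.model.cluster.registry import EmbeddedModelRegistry

async def main():
    registry = EmbeddedModelRegistry()
    await registry.register_instance(
        ModelInstance(
            model_name="vicuna-13b-v1.5@llm", host="127.0.0.1", port=8001, healthy=True
        )
    )
    healthy = await registry.get_all_model_instances(healthy_only=True)
    print([ins.model_name for ins in healthy])

asyncio.run(main())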

View File

View File

@ -6,6 +6,7 @@ from pilot.model.parameter import ModelParameters, ModelWorkerParameters, Worker
from pilot.model.base import ModelOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.worker.manager import (
WorkerManager,
LocalWorkerManager,
RegisterFunc,
DeregisterFunc,
@ -13,6 +14,23 @@ from pilot.model.cluster.worker.manager import (
ApplyFunction,
)
from pilot.model.base import ModelInstance
from pilot.model.cluster.registry import ModelRegistry, EmbeddedModelRegistry
@pytest.fixture
def model_registry(request):
return EmbeddedModelRegistry()
@pytest.fixture
def model_instance():
return ModelInstance(
model_name="test_model",
host="192.168.1.1",
port=5000,
)
class MockModelWorker(ModelWorker):
def __init__(
@ -51,8 +69,10 @@ class MockModelWorker(ModelWorker):
raise Exception("Stop worker error for mock")
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
full_text = ""
for msg in self.stream_messags:
yield ModelOutput(text=msg, error_code=0)
full_text += msg
yield ModelOutput(text=full_text, error_code=0)
def generate(self, params: Dict) -> ModelOutput:
output = None
@ -67,6 +87,8 @@ class MockModelWorker(ModelWorker):
_TEST_MODEL_NAME = "vicuna-13b-v1.5"
_TEST_MODEL_PATH = "/app/models/vicuna-13b-v1.5"
ClusterType = Tuple[WorkerManager, ModelRegistry]
def _new_worker_params(
model_name: str = _TEST_MODEL_NAME,
@ -85,7 +107,9 @@ def _create_workers(
worker_type: str = WorkerType.LLM.value,
stream_messags: List[str] = None,
embeddings: List[List[float]] = None,
) -> List[Tuple[ModelWorker, ModelWorkerParameters]]:
host: str = "127.0.0.1",
start_port=8001,
) -> List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]:
workers = []
for i in range(num_workers):
model_name = f"test-model-name-{i}"
@ -98,10 +122,16 @@ def _create_workers(
stream_messags=stream_messags,
embeddings=embeddings,
)
model_instance = ModelInstance(
model_name=WorkerType.to_worker_key(model_name, worker_type),
host=host,
port=start_port + i,
healthy=True,
)
worker_params = _new_worker_params(
model_name, model_path, worker_type=worker_type
)
workers.append((worker, worker_params))
workers.append((worker, worker_params, model_instance))
return workers
@ -127,12 +157,12 @@ async def _start_worker_manager(**kwargs):
model_registry=model_registry,
)
for worker, worker_params in _create_workers(
for worker, worker_params, model_instance in _create_workers(
num_workers, error_worker, stop_error, stream_messags, embeddings
):
worker_manager.add_worker(worker, worker_params)
if workers:
for worker, worker_params in workers:
for worker, worker_params, model_instance in workers:
worker_manager.add_worker(worker, worker_params)
if start:
@ -143,6 +173,15 @@ async def _start_worker_manager(**kwargs):
await worker_manager.stop()
async def _create_model_registry(
workers: List[Tuple[ModelWorker, ModelWorkerParameters, ModelInstance]]
) -> ModelRegistry:
registry = EmbeddedModelRegistry()
for _, _, inst in workers:
assert await registry.register_instance(inst)
return registry
@pytest_asyncio.fixture
async def manager_2_workers(request):
param = getattr(request, "param", {})
@ -166,3 +205,27 @@ async def manager_2_embedding_workers(request):
)
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, workers)
@asynccontextmanager
async def _new_cluster(**kwargs) -> ClusterType:
num_workers = kwargs.get("num_workers", 0)
workers = _create_workers(
num_workers, stream_messags=kwargs.get("stream_messags", [])
)
if "num_workers" in kwargs:
del kwargs["num_workers"]
registry = await _create_model_registry(
workers,
)
async with _start_worker_manager(workers=workers, **kwargs) as worker_manager:
yield (worker_manager, registry)
@pytest_asyncio.fixture
async def cluster_2_workers(request):
param = getattr(request, "param", {})
workers = _create_workers(2)
registry = await _create_model_registry(workers)
async with _start_worker_manager(workers=workers, **param) as worker_manager:
yield (worker_manager, registry)

View File

@ -256,15 +256,22 @@ class DefaultModelWorker(ModelWorker):
return params, model_context, generate_stream_func, model_span
def _handle_output(self, output, previous_response, model_context):
finish_reason = None
usage = None
if isinstance(output, dict):
finish_reason = output.get("finish_reason")
usage = output.get("usage")
output = output["text"]
if finish_reason is not None:
logger.info(f"finish_reason: {finish_reason}")
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)
model_output = ModelOutput(
text=output, error_code=0, model_context=model_context
text=output,
error_code=0,
model_context=model_context,
finish_reason=finish_reason,
usage=usage,
)
return model_output, incremental_output, output

View File

@ -99,9 +99,7 @@ class LocalWorkerManager(WorkerManager):
)
def _worker_key(self, worker_type: str, model_name: str) -> str:
if isinstance(worker_type, WorkerType):
worker_type = worker_type.value
return f"{model_name}@{worker_type}"
return WorkerType.to_worker_key(model_name, worker_type)
async def run_blocking_func(self, func, *args):
if asyncio.iscoroutinefunction(func):

View File

@ -3,7 +3,7 @@ import pytest
from typing import List, Iterator, Dict, Tuple
from dataclasses import asdict
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.base import ModelOutput, WorkerApplyType
from pilot.model.base import ModelOutput, WorkerApplyType, ModelInstance
from pilot.model.cluster.base import WorkerApplyRequest, WorkerStartupRequest
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerRunData
@ -14,7 +14,7 @@ from pilot.model.cluster.worker.manager import (
SendHeartbeatFunc,
ApplyFunction,
)
from pilot.model.cluster.worker.tests.base_tests import (
from pilot.model.cluster.tests.conftest import (
MockModelWorker,
manager_2_workers,
manager_with_2_workers,
@ -216,7 +216,7 @@ async def test__remove_worker():
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
for _, worker_params in workers:
for _, worker_params, _ in workers:
manager._remove_worker(worker_params)
not_exist_parmas = _new_worker_params(
model_name="this is a not exist worker params"
@ -229,7 +229,7 @@ async def test__remove_worker():
async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1)
worker, worker_params = workers[0]
worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
@ -245,7 +245,7 @@ async def test_model_startup(mock_build_worker):
async with _start_worker_manager() as manager:
workers = _create_workers(1, error_worker=True)
worker, worker_params = workers[0]
worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
host="127.0.0.1",
@ -263,7 +263,7 @@ async def test_model_startup(mock_build_worker):
async def test_model_shutdown(mock_build_worker):
async with _start_worker_manager(start=False, stop=False) as manager:
workers = _create_workers(1)
worker, worker_params = workers[0]
worker, worker_params, model_instance = workers[0]
mock_build_worker.return_value = worker
req = WorkerStartupRequest(
@ -298,7 +298,7 @@ async def test_get_model_instances(is_async):
workers = _create_workers(3)
async with _start_worker_manager(workers=workers, stop=False) as manager:
assert len(manager.workers) == 3
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
@ -326,7 +326,7 @@ async def test__simple_select(
]
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
instances = await manager.get_model_instances(worker_type, model_name)
@ -351,7 +351,7 @@ async def test_select_one_instance(
],
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
if is_async:
@ -376,7 +376,7 @@ async def test__get_model(
],
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
@ -403,13 +403,13 @@ async def test_generate_stream(
expected_messages: str,
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
text = ""
async for out in manager.generate_stream(params):
text += out.text
text = out.text
assert text == expected_messages
@ -417,8 +417,8 @@ async def test_generate_stream(
@pytest.mark.parametrize(
"manager_with_2_workers, expected_messages",
[
({"stream_messags": ["Hello", " world."]}, " world."),
({"stream_messags": ["你好,我是", "张三。"]}, "张三。"),
({"stream_messags": ["Hello", " world."]}, "Hello world."),
({"stream_messags": ["你好,我是", "张三。"]}, "你好,我是张三。"),
],
indirect=["manager_with_2_workers"],
)
@ -429,7 +429,7 @@ async def test_generate(
expected_messages: str,
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name}
@ -454,7 +454,7 @@ async def test_embeddings(
is_async: bool,
):
manager, workers = manager_2_embedding_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = {"model": model_name, "input": ["hello", "world"]}
@ -472,7 +472,7 @@ async def test_parameter_descriptions(
]
):
manager, workers = manager_with_2_workers
for _, worker_params in workers:
for _, worker_params, _ in workers:
model_name = worker_params.model_name
worker_type = worker_params.worker_type
params = await manager.parameter_descriptions(worker_type, model_name)

View File

@ -467,7 +467,8 @@ register_conv_template(
sep="\n",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
@ -482,7 +483,8 @@ register_conv_template(
sep="###",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
),
override=True,
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
@ -495,5 +497,6 @@ register_conv_template(
sep="",
sep2="</s>",
stop_str=["</s>", "<|endoftext|>"],
)
),
override=True,
)

View File

@ -1,9 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, Union, Tuple
from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters
@ -19,6 +20,35 @@ class WorkerType(str, Enum):
def values():
return [item.value for item in WorkerType]
@staticmethod
def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
"""Generate worker key from worker name and worker type
Args:
worker_name (str): Worker name (e.g., chatglm2-6b)
worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm' or [`WorkerType.LLM`])
Returns:
str: Generated worker key
"""
if "@" in worker_name:
raise ValueError(f"Invaild symbol '@' in your worker name {worker_name}")
if isinstance(worker_type, WorkerType):
worker_type = worker_type.value
return f"{worker_name}@{worker_type}"
@staticmethod
def parse_worker_key(worker_key: str) -> Tuple[str, str]:
"""Parse worker name and worker type from worker key
Args:
worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]
Returns:
Tuple[str, str]: Worker name and worker type
"""
return tuple(worker_key.split("@"))
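The two helpers are inverses of each other; for example:

key = WorkerType.to_worker_key("vicuna-13b-v1.5", WorkerType.LLM)
assert key == "vicuna-13b-v1.5@llm"
assert WorkerType.parse_worker_key(key) == ("vicuna-13b-v1.5", "llm")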
@dataclass
class ModelControllerParameters(BaseParameters):
@ -60,6 +90,56 @@ class ModelControllerParameters(BaseParameters):
)
@dataclass
class ModelAPIServerParameters(BaseParameters):
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model API server deploy host"}
)
port: Optional[int] = field(
default=8100, metadata={"help": "Model API server deploy port"}
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model API server in background"}
)
controller_addr: Optional[str] = field(
default="http://127.0.0.1:8000",
metadata={"help": "The Model controller address to connect"},
)
api_keys: Optional[str] = field(
default=None,
metadata={"help": "Optional list of comma separated API keys"},
)
log_level: Optional[str] = field(
default=None,
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
log_file: Optional[str] = field(
default="dbgpt_model_apiserver.log",
metadata={
"help": "The filename to store log",
},
)
tracer_file: Optional[str] = field(
default="dbgpt_model_apiserver_tracer.jsonl",
metadata={
"help": "The filename to store tracer span records",
},
)
@dataclass
class BaseModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Tuple, Optional, Union
from pydantic import BaseModel, Field, root_validator
@ -70,14 +70,6 @@ class SystemMessage(BaseMessage):
return "system"
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Similar to openai's message format"""
role: str
content: str
class ModelMessageRoleType:
""" "Type of ModelMessage role"""
@ -87,6 +79,45 @@ class ModelMessageRoleType:
VIEW = "view"
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Similar to openai's message format"""
role: str
content: str
@staticmethod
def from_openai_messages(
messages: Union[str, List[Dict[str, str]]]
) -> List["ModelMessage"]:
"""Openai message format to current ModelMessage format"""
if isinstance(messages, str):
return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
result = []
for message in messages:
msg_role = message["role"]
content = message["content"]
if msg_role == "system":
result.append(
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
)
elif msg_role == "user":
result.append(
ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
)
elif msg_role == "assistant":
result.append(
ModelMessage(role=ModelMessageRoleType.AI, content=content)
)
else:
raise ValueError(f"Unknown role: {msg_role}")
return result
@staticmethod
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
return list(map(lambda m: m.dict(), messages))
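Together these mirror what create_chat_completion does in the API server above; for example:

openai_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
params = {
    "model": "vicuna-13b-v1.5",
    "messages": ModelMessage.to_dict_list(
        ModelMessage.from_openai_messages(openai_messages)
    ),
}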
class Generation(BaseModel):
"""Output of a single generation."""

View File

@ -0,0 +1,99 @@
from typing import Dict, Any, Awaitable, Callable, Optional, AsyncIterator
import httpx
import asyncio
import logging
import json
logger = logging.getLogger(__name__)
MessageCaller = Callable[[str], Awaitable[None]]
async def _do_chat_completion(
url: str,
chat_data: Dict[str, Any],
client: httpx.AsyncClient,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> AsyncIterator[str]:
async with client.stream(
"POST",
url,
headers=headers,
json=chat_data,
timeout=timeout,
) as res:
if res.status_code != 200:
error_message = await res.aread()
if error_message:
error_message = error_message.decode("utf-8")
logger.error(
f"Request failed with status {res.status_code}. Error: {error_message}"
)
raise httpx.RequestError(
f"Request failed with status {res.status_code}",
request=res.request,
)
async for line in res.aiter_lines():
if line:
if not line.startswith("data: "):
if caller:
await caller(line)
yield line
else:
decoded_line = line.split("data: ", 1)[1]
if decoded_line.lower().strip() != "[done]":
obj = json.loads(decoded_line)
text = obj["choices"][0]["delta"].get("content")
if text is not None:
if caller:
await caller(text)
yield text
await asyncio.sleep(0.02)
async def chat_completion_stream(
url: str,
chat_data: Dict[str, Any],
client: Optional[httpx.AsyncClient] = None,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> AsyncIterator[str]:
if client:
async for text in _do_chat_completion(
url,
chat_data,
client=client,
headers=headers,
timeout=timeout,
caller=caller,
):
yield text
else:
async with httpx.AsyncClient() as client:
async for text in _do_chat_completion(
url,
chat_data,
client=client,
headers=headers,
timeout=timeout,
caller=caller,
):
yield text
async def chat_completion(
url: str,
chat_data: Dict[str, Any],
client: Optional[httpx.AsyncClient] = None,
headers: Dict[str, Any] = {},
timeout: int = 60,
caller: Optional[MessageCaller] = None,
) -> str:
full_text = ""
async for text in chat_completion_stream(
url, chat_data, client, headers=headers, timeout=timeout, caller=caller
):
full_text += text
return full_text
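A usage sketch for these helpers against a running API server (the URL and model name are assumptions):

import asyncio

async def main():
    url = "http://127.0.0.1:8100/api/v1/chat/completions"
    data = {
        "model": "vicuna-13b-v1.5",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }
    # Print tokens as they arrive.
    async for text in chat_completion_stream(url, data):
        print(text, end="", flush=True)
    # Or collect everything in one call.
    print(await chat_completion(url, data))

asyncio.run(main())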

View File

@ -8,6 +8,7 @@ pytest-integration
pytest-mock
pytest-recording
pytesseract==0.3.10
aioresponses
# python code format
black
# for git hooks