feat(model): support deploying multiple models through the API and add the corresponding command line interface (#570)

- Split `LLM_MODEL_CONFIG` into `LLM_MODEL_CONFIG` and
`EMBEDDING_MODEL_CONFIG`.
- New HTTP API to obtain the list of models and configuration parameters supported by the current cluster.
- New HTTP API to launch models on a specified machine.
- The command line supports the above HTTP APIs.
Committed by Aries-ckt on 2023-09-11 17:55:44 +08:00 (via GitHub).
38 changed files with 1081 additions and 349 deletions
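
For orientation, a minimal sketch of how the new endpoints might be called. The base address and all request values below are illustrative assumptions, not part of this commit:

import httpx

# Assumption: a worker manager serving at 127.0.0.1:8000 with the cluster
# router mounted under the /api prefix (as run_worker_manager does).
BASE = "http://127.0.0.1:8000/api"

# List the models (and their parameter descriptions) supported by the cluster.
models = httpx.get(f"{BASE}/worker/models/supports").json()

# Launch a model on a specific machine (host, port, and model are hypothetical).
startup_req = {
    "host": "192.168.1.10",
    "port": 8001,
    "model": "vicuna-13b-v1.5",
    "worker_type": "llm",
    "params": {},
}
ok = httpx.post(f"{BASE}/worker/models/startup", json=startup_req).json()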

View File

@@ -41,25 +41,12 @@ LLM_MODEL_CONFIG = {
# (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
"vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
"vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
# https://huggingface.co/moka-ai/m3e-large
"m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
# https://huggingface.co/moka-ai/m3e-base
"m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
# https://huggingface.co/BAAI/bge-large-en
"bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
"bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
"chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
"chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
@@ -84,6 +71,22 @@ LLM_MODEL_CONFIG = {
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
}
EMBEDDING_MODEL_CONFIG = {
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
# https://huggingface.co/moka-ai/m3e-large
"m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
# https://huggingface.co/moka-ai/m3e-base
"m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
# https://huggingface.co/BAAI/bge-large-en
"bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
"bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
}
# Load model config
ISDEBUG = False
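With the split in place, embedding consumers resolve paths from the new dictionary; a small sketch using names from the configs above:

# Sketch: look up model paths from the split configs.
llm_path = LLM_MODEL_CONFIG["vicuna-13b-v1.5"]
embedding_path = EMBEDDING_MODEL_CONFIG["text2vec"]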

View File

@@ -108,6 +108,17 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
return [param_class]
def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
try:
llm_adapter = get_llm_model_adapter(model_name, model_path)
return llm_adapter.model_param_class()
except Exception as e:
logger.warn(
f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
)
return ModelParameters
# TODO: support CPU? In practice we support gpt4all or chatglm-6b-int4?

View File

@@ -2,9 +2,10 @@
# -*- coding: utf-8 -*-
from enum import Enum
from typing import TypedDict, Optional, Dict
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription
class Message(TypedDict):
@@ -46,5 +47,40 @@ class ModelOutput:
@dataclass
class WorkerApplyOutput:
message: str
success: Optional[bool] = True
# The seconds cost to apply some action to worker instances
timecost: Optional[int] = -1
@dataclass
class SupportedModel:
model: str
path: str
worker_type: str
path_exist: bool
proxy: bool
enabled: bool
params: List[ParameterDescription]
@classmethod
def from_dict(cls, model_data: Dict) -> "SupportedModel":
params = model_data.get("params", [])
if params:
params = [ParameterDescription(**param) for param in params]
model_data["params"] = params
return cls(**model_data)
@dataclass
class WorkerSupportedModel:
host: str
port: int
models: List[SupportedModel]
@classmethod
def from_dict(cls, worker_data: Dict) -> "WorkerSupportedModel":
models = [
SupportedModel.from_dict(model_data) for model_data in worker_data["models"]
]
worker_data["models"] = models
return cls(**worker_data)
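A short sketch of deserializing a supported-models payload with the new helpers; all field values below are illustrative:

worker_data = {
    "host": "192.168.1.10",
    "port": 8001,
    "models": [
        {
            "model": "vicuna-13b-v1.5",
            "path": "/data/models/vicuna-13b-v1.5",
            "worker_type": "llm",
            "path_exist": True,
            "proxy": False,
            "enabled": True,
            "params": [],
        }
    ],
}
supported = WorkerSupportedModel.from_dict(worker_data)
assert supported.models[0].model == "vicuna-13b-v1.5"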

View File

@@ -2,19 +2,27 @@ import click
import functools
import logging
import os
from typing import Callable, List, Type
from typing import Callable, List, Type, Optional
from pilot.model.controller.controller import ModelRegistryClient
from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
ModelControllerParameters,
ModelWorkerParameters,
ModelParameters,
BaseParameters,
)
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service
from pilot.utils.parameter_utils import (
EnvArgumentParser,
_build_parameter_class,
build_lazy_click_command,
)
from pilot.utils.command_utils import (
_run_current_with_daemon,
_stop_service,
_detect_controller_address,
)
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
@@ -22,6 +30,14 @@ MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
logger = logging.getLogger("dbgpt_cli")
def _get_worker_manager(address: str):
from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
return worker_manager
@click.group("model")
@click.option(
"--address",
@@ -38,8 +54,6 @@ def model_cli_group(address: str):
"""Clients that manage model serving"""
global MODEL_CONTROLLER_ADDRESS
if not address:
from pilot.utils.command_utils import _detect_controller_address
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
else:
MODEL_CONTROLLER_ADDRESS = address
@@ -55,6 +69,7 @@ def model_cli_group(address: str):
def list(model_name: str, model_type: str):
"""List model instances"""
from prettytable import PrettyTable
from pilot.model.cluster import ModelRegistryClient
loop = get_or_create_event_loop()
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
@@ -90,7 +105,7 @@ def list(model_name: str, model_type: str):
instance.port,
instance.healthy,
instance.enabled,
instance.prompt_template,
instance.prompt_template if instance.prompt_template else "",
instance.last_heartbeat,
]
)
@@ -122,18 +137,156 @@ def add_model_options(func):
@model_cli_group.command()
@add_model_options
def stop(model_name: str, model_type: str):
@click.option(
"--host",
type=str,
required=True,
help=("The remote host to stop model"),
)
@click.option(
"--port",
type=int,
required=True,
help=("The remote port to stop model"),
)
def stop(model_name: str, model_type: str, host: str, port: int):
"""Stop model instances"""
worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager
@model_cli_group.command()
@add_model_options
def start(model_name: str, model_type: str):
"""Start model instances"""
worker_apply(
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
req = WorkerStartupRequest(
host=host,
port=port,
worker_type=model_type,
model=model_name,
params={},
)
loop = get_or_create_event_loop()
res = loop.run_until_complete(worker_manager.model_shutdown(req))
print(res)
def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
from pilot.model.adapter import _dynamic_model_parser
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.cluster import RemoteWorkerManager
from pilot.model.parameter import WorkerType
from dataclasses import dataclass, field, fields
pre_args = _SimpleArgParser("model_name", "address", "host", "port")
pre_args.parse()
model_name = pre_args.get("model_name")
address = pre_args.get("address")
host = pre_args.get("host")
port = pre_args.get("port")
if port:
port = int(port)
if not address:
address = _detect_controller_address()
worker_manager: RemoteWorkerManager = _get_worker_manager(address)
loop = get_or_create_event_loop()
models = loop.run_until_complete(worker_manager.supported_models())
fields_dict = {}
fields_dict["model_name"] = (
str,
field(default=None, metadata={"help": "The model name to deploy"}),
)
fields_dict["host"] = (
str,
field(default=None, metadata={"help": "The remote host to deploy model"}),
)
fields_dict["port"] = (
int,
field(default=None, metadata={"help": "The remote port to deploy model"}),
)
result_class = dataclass(
type("RemoteModelWorkerParameters", (object,), fields_dict)
)
if not models:
return [result_class]
valid_models = []
valid_model_cls = []
for model in models:
if host and host != model.host:
continue
if port and port != model.port:
continue
valid_models += [m.model for m in model.models]
valid_model_cls += [
(m, _build_parameter_class(m.params)) for m in model.models if m.params
]
real_model, real_params_cls = valid_model_cls[0]
real_path = None
real_worker_type = "llm"
if model_name:
params_cls_list = [m for m in valid_model_cls if m[0].model == model_name]
if not params_cls_list:
raise ValueError(f"Not supported model with model name: {model_name}")
real_model, real_params_cls = params_cls_list[0]
real_path = real_model.path
real_worker_type = real_model.worker_type
@dataclass
class RemoteModelWorkerParameters(BaseParameters):
model_name: str = field(
metadata={"valid_values": valid_models, "help": "The model name to deploy"}
)
model_path: Optional[str] = field(
default=real_path, metadata={"help": "The model path to deploy"}
)
host: Optional[str] = field(
default=models[0].host,
metadata={
"valid_values": [model.host for model in models],
"help": "The remote host to deploy model",
},
)
port: Optional[int] = field(
default=models[0].port,
metadata={
"valid_values": [model.port for model in models],
"help": "The remote port to deploy model",
},
)
worker_type: Optional[str] = field(
default=real_worker_type,
metadata={
"valid_values": WorkerType.values(),
"help": "Worker type",
},
)
return [RemoteModelWorkerParameters, real_params_cls]
@model_cli_group.command(
cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory)
)
def start(**kwargs):
"""Start model instances"""
from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
req = WorkerStartupRequest(
host=kwargs["host"],
port=kwargs["port"],
worker_type=kwargs["worker_type"],
model=kwargs["model_name"],
params={},
)
del kwargs["host"]
del kwargs["port"]
del kwargs["worker_type"]
req.params = kwargs
loop = get_or_create_event_loop()
res = loop.run_until_complete(worker_manager.model_startup(req))
print(res)
@model_cli_group.command()
@@ -165,25 +318,10 @@ def chat(model_name: str, system: str):
_cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)
# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
# """Restart model instances"""
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
def _get_worker_manager(address: str):
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
return worker_manager
def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
from pilot.model.worker.manager import WorkerApplyRequest
from pilot.model.cluster import WorkerApplyRequest
loop = get_or_create_event_loop()
worker_manager = _get_worker_manager(address)
@@ -201,7 +339,7 @@ def _cli_chat(address: str, model_name: str, system_prompt: str = None):
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
from pilot.model.worker.manager import PromptRequest
from pilot.model.cluster import PromptRequest
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
@@ -249,13 +387,11 @@ def add_stop_server_options(func):
def start_model_controller(**kwargs):
"""Start model controller"""
from pilot.model.controller.controller import run_model_controller
if kwargs["daemon"]:
log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
_run_current_with_daemon("ModelController", log_file)
else:
from pilot.model.controller.controller import run_model_controller
from pilot.model.cluster import run_model_controller
run_model_controller()
@@ -279,9 +415,8 @@ def _model_dynamic_factory() -> Callable[[None], List[Type]]:
return fix_class
@click.command(name="worker")
@EnvArgumentParser.create_click_option(
ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory
@click.command(
name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory)
)
def start_model_worker(**kwargs):
"""Start model worker"""
@@ -291,7 +426,7 @@ def start_model_worker(**kwargs):
log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
_run_current_with_daemon("ModelWorker", log_file)
else:
from pilot.model.worker.manager import run_worker_manager
from pilot.model.cluster import run_worker_manager
run_worker_manager()

View File

@@ -0,0 +1,33 @@
from pilot.model.cluster.base import (
EmbeddingsRequest,
PromptRequest,
WorkerApplyRequest,
WorkerParameterRequest,
WorkerStartupRequest,
)
from pilot.model.cluster.worker.manager import (
initialize_worker_manager_in_client,
run_worker_manager,
worker_manager,
)
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.controller.controller import (
ModelRegistryClient,
run_model_controller,
)
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
__all__ = [
"EmbeddingsRequest",
"PromptRequest",
"WorkerApplyRequest",
"WorkerParameterRequest",
"WorkerStartupRequest",
"worker_manager",
"run_worker_manager",
"initialize_worker_manager_in_client",
"ModelRegistry",
"ModelRegistryClient",
"RemoteWorkerManager",
"run_model_controller",
]

View File

@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Dict, List
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import WorkerType
from pilot.scene.base_message import ModelMessage
from pydantic import BaseModel
WORKER_MANAGER_SERVICE_TYPE = "service"
WORKER_MANAGER_SERVICE_NAME = "WorkerManager"
class PromptRequest(BaseModel):
messages: List[ModelMessage]
model: str
prompt: str = None
temperature: float = None
max_new_tokens: int = None
stop: str = None
echo: bool = True
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
class WorkerApplyRequest(BaseModel):
model: str
apply_type: WorkerApplyType
worker_type: WorkerType = WorkerType.LLM
params: Dict = None
apply_user: str = None
class WorkerParameterRequest(BaseModel):
model: str
worker_type: WorkerType = WorkerType.LLM
class WorkerStartupRequest(BaseModel):
host: str
port: int
model: str
worker_type: WorkerType
params: Dict
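
# Sketch: constructing a startup request directly; host, port, and model
# below are illustrative values, not defaults defined by this commit.
req = WorkerStartupRequest(
    host="192.168.1.10",
    port=8001,
    model="vicuna-13b-v1.5",
    worker_type=WorkerType.LLM,
    params={},
)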

View File

@@ -6,7 +6,7 @@ from typing import List
from fastapi import APIRouter, FastAPI
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.api_utils import _api_remote as api_remote

View File

@@ -2,9 +2,8 @@ import pytest
from datetime import datetime, timedelta
import asyncio
from unittest.mock import patch
from pilot.model.base import ModelInstance
from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry
from pilot.model.cluster.registry import EmbeddedModelRegistry
@pytest.fixture
@@ -16,7 +15,7 @@ def model_registry():
def model_instance():
return ModelInstance(
model_name="test_model",
ip="192.168.1.1",
host="192.168.1.1",
port=5000,
)
@@ -89,12 +88,7 @@ async def test_send_heartbeat(model_registry, model_instance):
await model_registry.register_instance(model_instance)
last_heartbeat = datetime.now() - timedelta(seconds=10)
model_instance.last_heartbeat = last_heartbeat
assert (
await model_registry.send_heartbeat(
model_instance.model_name, model_instance.ip, model_instance.port
)
== True
)
assert await model_registry.send_heartbeat(model_instance) == True
assert (
model_registry.registry[model_instance.model_name][0].last_heartbeat
> last_heartbeat
@@ -125,7 +119,7 @@ async def test_multiple_instances(model_registry, model_instance):
"""
model_instance2 = ModelInstance(
model_name="test_model",
ip="192.168.1.2",
host="192.168.1.2",
port=5000,
)
await model_registry.register_instance(model_instance)
@@ -138,11 +132,11 @@ async def test_same_model_name_different_ip_port(model_registry):
"""
Test if instances with the same model name but different IP and port are handled correctly
"""
instance1 = ModelInstance(model_name="test_model", ip="192.168.1.1", port=5000)
instance2 = ModelInstance(model_name="test_model", ip="192.168.1.2", port=6000)
instance1 = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
instance2 = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)
await model_registry.register_instance(instance1)
await model_registry.register_instance(instance2)
instances = await model_registry.get_all_instances("test_model")
assert len(instances) == 2
assert instances[0].ip != instances[1].ip
assert instances[0].host != instances[1].host
assert instances[0].port != instances[1].port

View File

@@ -0,0 +1,82 @@
import asyncio
from dataclasses import dataclass
from typing import List, Optional, Dict, Iterator
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
from pilot.model.parameter import ModelWorkerParameters, ModelParameters
from pilot.utils.parameter_utils import ParameterDescription
@dataclass
class WorkerRunData:
host: str
port: int
worker_key: str
worker: ModelWorker
worker_params: ModelWorkerParameters
model_params: ModelParameters
stop_event: asyncio.Event
semaphore: asyncio.Semaphore = None
command_args: List[str] = None
_heartbeat_future: Optional[Future] = None
_last_heartbeat: Optional[datetime] = None
class WorkerManager(ABC):
@abstractmethod
async def start(self):
"""Start worker manager"""
@abstractmethod
async def stop(self):
"""Stop worker manager"""
@abstractmethod
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Get model instances by worker type and model name"""
@abstractmethod
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
"""Select one instance"""
@abstractmethod
async def supported_models(self) -> List[WorkerSupportedModel]:
"""List supported models"""
@abstractmethod
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
"""Create and start a model instance"""
@abstractmethod
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
"""Shutdown model instance"""
@abstractmethod
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
"""Generate stream result, chat scene"""
@abstractmethod
async def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
@abstractmethod
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
"""Worker apply"""
@abstractmethod
async def parameter_descriptions(
self, worker_type: str, model_name: str
) -> List[ParameterDescription]:
"""Get parameter descriptions of model"""

View File

@@ -7,7 +7,7 @@ from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser

View File

@@ -7,7 +7,7 @@ from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
)
from pilot.model.worker.base import ModelWorker
from pilot.model.cluster.worker_base import ModelWorker
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser

View File

@@ -4,13 +4,11 @@ import json
import os
import random
import time
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List
from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
@@ -18,103 +16,34 @@ from pilot.model.base import (
ModelOutput,
WorkerApplyOutput,
WorkerApplyType,
WorkerSupportedModel,
)
from pilot.model.controller.registry import ModelRegistry
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.worker.base import ModelWorker
from pilot.scene.base_message import ModelMessage
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData
from pilot.model.cluster.base import *
from pilot.utils import build_logger
from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription
from pydantic import BaseModel
from pilot.utils.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
_dict_to_command_args,
)
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
class PromptRequest(BaseModel):
messages: List[ModelMessage]
model: str
prompt: str = None
temperature: float = None
max_new_tokens: int = None
stop: str = None
echo: bool = True
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
class WorkerApplyRequest(BaseModel):
model: str
apply_type: WorkerApplyType
worker_type: WorkerType = WorkerType.LLM
params: Dict = None
apply_user: str = None
class WorkerParameterRequest(BaseModel):
model: str
worker_type: WorkerType = WorkerType.LLM
@dataclass
class WorkerRunData:
worker_key: str
worker: ModelWorker
worker_params: ModelWorkerParameters
model_params: ModelParameters
stop_event: asyncio.Event
semaphore: asyncio.Semaphore = None
command_args: List[str] = None
_heartbeat_future: Optional[Future] = None
_last_heartbeat: Optional[datetime] = None
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]
class WorkerManager(ABC):
@abstractmethod
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
"""Get model instances by worker type and model name"""
@abstractmethod
async def select_one_instanes(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
"""Select one instances"""
@abstractmethod
async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
"""Generate stream result, chat scene"""
@abstractmethod
async def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
@abstractmethod
async def embeddings(self, params: Dict) -> List[List[float]]:
"""Embed input"""
@abstractmethod
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
"""Worker apply"""
@abstractmethod
async def parameter_descriptions(
self, worker_type: str, model_name: str
) -> List[ParameterDescription]:
"""Get parameter descriptions of model"""
async def _async_heartbeat_sender(
worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc
worker_run_data: WorkerRunData,
heartbeat_interval,
send_heartbeat_func: SendHeartbeatFunc,
):
while not worker_run_data.stop_event.is_set():
try:
@@ -122,7 +51,7 @@ async def _async_heartbeat_sender(
except Exception as e:
logger.warn(f"Send heartbeat func error: {str(e)}")
finally:
await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval)
await asyncio.sleep(heartbeat_interval)
class LocalWorkerManager(WorkerManager):
@@ -132,6 +61,8 @@ class LocalWorkerManager(WorkerManager):
deregister_func: DeregisterFunc = None,
send_heartbeat_func: SendHeartbeatFunc = None,
model_registry: ModelRegistry = None,
host: str = None,
port: int = None,
) -> None:
self.workers: Dict[str, List[WorkerRunData]] = dict()
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
@@ -139,19 +70,58 @@ class LocalWorkerManager(WorkerManager):
self.deregister_func = deregister_func
self.send_heartbeat_func = send_heartbeat_func
self.model_registry = model_registry
self.host = host
self.port = port
self.run_data = WorkerRunData(
host=self.host,
port=self.port,
worker_key=self._worker_key(
WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
),
worker=None,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=None,
command_args=None,
)
def _worker_key(self, worker_type: str, model_name: str) -> str:
if isinstance(worker_type, WorkerType):
worker_type = worker_type.value
return f"{model_name}@{worker_type}"
async def run_blocking_func(self, func, *args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.executor, func, *args)
async def start(self):
if len(self.workers) > 0:
await self._start_all_worker(apply_req=None)
if self.register_func:
await self.register_func(self.run_data)
if self.send_heartbeat_func:
asyncio.create_task(
_async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
)
async def stop(self):
if not self.run_data.stop_event.is_set():
logger.info("Stop all workers")
self.run_data.stop_event.clear()
stop_tasks = []
stop_tasks.append(self._stop_all_worker(apply_req=None))
if self.deregister_func:
stop_tasks.append(self.deregister_func(self.run_data))
await asyncio.gather(*stop_tasks)
def add_worker(
self,
worker: ModelWorker,
worker_params: ModelWorkerParameters,
embedded_mod: bool = True,
command_args: List[str] = None,
):
) -> bool:
if not command_args:
import sys
@@ -179,6 +149,8 @@ class LocalWorkerManager(WorkerManager):
model_params = worker.parse_parameters(command_args=command_args)
worker_run_data = WorkerRunData(
host=self.host,
port=self.port,
worker_key=worker_key,
worker=worker,
worker_params=worker_params,
@@ -187,14 +159,66 @@ class LocalWorkerManager(WorkerManager):
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
command_args=command_args,
)
if not embedded_mod:
exist_instances = [
(w, p) for w, p in instances if p.host == host and p.port == port
]
if not exist_instances:
instances.append(worker_run_data)
else:
exist_instances = [
ins for ins in instances if ins.host == host and ins.port == port
]
if not exist_instances:
instances.append(worker_run_data)
return True
else:
# TODO Update worker
return False
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
"""Start model"""
model_name = startup_req.model
worker_type = startup_req.worker_type
params = startup_req.params
logger.debug(
f"start model, model name {model_name}, worker type {worker_type}, params: {params}"
)
worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict(
params, ignore_extra_fields=True
)
if not worker_params.model_name:
worker_params.model_name = model_name
assert model_name == worker_params.model_name
worker = _build_worker(worker_params)
command_args = _dict_to_command_args(params)
success = await self.run_blocking_func(
self.add_worker, worker, worker_params, command_args
)
if not success:
logger.warn(
f"Add worker failed, worker instances is exist, worker_params: {worker_params}"
)
return False
supported_types = WorkerType.values()
if worker_type not in supported_types:
raise ValueError(
f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
)
start_apply_req = WorkerApplyRequest(
model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type
)
await self.worker_apply(start_apply_req)
return True
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
apply_req = WorkerApplyRequest(
model=shutdown_req.model,
apply_type=WorkerApplyType.STOP,
worker_type=shutdown_req.worker_type,
)
out = await self._stop_all_worker(apply_req)
if out.success:
return True
raise Exception(out.message)
async def supported_models(self) -> List[WorkerSupportedModel]:
models = await self.run_blocking_func(list_supported_models)
return [WorkerSupportedModel(host=self.host, port=self.port, models=models)]
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
@@ -202,7 +226,7 @@ class LocalWorkerManager(WorkerManager):
worker_key = self._worker_key(worker_type, model_name)
return self.workers.get(worker_key)
async def select_one_instanes(
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
worker_instances = await self.get_model_instances(
@@ -219,7 +243,7 @@ class LocalWorkerManager(WorkerManager):
model = params.get("model")
if not model:
raise Exception("Model name count not be empty")
return await self.select_one_instanes(worker_type, model, healthy_only=True)
return await self.select_one_instance(worker_type, model, healthy_only=True)
async def generate_stream(
self, params: Dict, async_wrapper=None, **kwargs
@@ -262,9 +286,8 @@ class LocalWorkerManager(WorkerManager):
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_generate(params)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, worker_run_data.worker.generate, params
return await self.run_blocking_func(
worker_run_data.worker.generate, params
)
async def embeddings(self, params: Dict) -> List[List[float]]:
@@ -277,9 +300,8 @@ class LocalWorkerManager(WorkerManager):
if worker_run_data.worker.support_async():
return await worker_run_data.worker.async_embeddings(params)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, worker_run_data.worker.embeddings, params
return await self.run_blocking_func(
worker_run_data.worker.embeddings, params
)
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
@@ -342,8 +364,10 @@ class LocalWorkerManager(WorkerManager):
logger.info(f"Begin start all worker, apply_req: {apply_req}")
async def _start_worker(worker_run_data: WorkerRunData):
worker_run_data.worker.start(
worker_run_data.model_params, worker_run_data.command_args
await self.run_blocking_func(
worker_run_data.worker.start,
worker_run_data.model_params,
worker_run_data.command_args,
)
worker_run_data.stop_event.clear()
if worker_run_data.worker_params.register and self.register_func:
@@ -355,7 +379,9 @@ class LocalWorkerManager(WorkerManager):
):
asyncio.create_task(
_async_heartbeat_sender(
worker_run_data, self.send_heartbeat_func
worker_run_data,
worker_run_data.worker_params.heartbeat_interval,
self.send_heartbeat_func,
)
)
@@ -371,7 +397,7 @@ class LocalWorkerManager(WorkerManager):
start_time = time.time()
async def _stop_worker(worker_run_data: WorkerRunData):
worker_run_data.worker.stop()
await self.run_blocking_func(worker_run_data.worker.stop)
# Set stop event
worker_run_data.stop_event.set()
if worker_run_data._heartbeat_future:
@@ -422,63 +448,25 @@ class LocalWorkerManager(WorkerManager):
return WorkerApplyOutput(message=message, timecost=timecost)
class RemoteWorkerManager(LocalWorkerManager):
def __init__(self, model_registry: ModelRegistry = None) -> None:
super().__init__(model_registry=model_registry)
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
from pilot.model.worker.remote_worker import RemoteModelWorker
worker_key = self._worker_key(worker_type, model_name)
instances: List[ModelInstance] = await self.model_registry.get_all_instances(
worker_key, healthy_only
)
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
wr = WorkerRunData(
worker_key=ins.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
)
worker_instances.append(wr)
return worker_instances
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
import httpx
async def _remote_apply_func(worker_run_data: WorkerRunData):
worker_addr = worker_run_data.worker.worker_addr
async with httpx.AsyncClient() as client:
response = await client.post(
worker_addr + "/apply",
headers=worker_run_data.worker.headers,
json=apply_req.dict(),
timeout=worker_run_data.worker.timeout,
)
if response.status_code == 200:
output = WorkerApplyOutput(**response.json())
logger.info(f"worker_apply success: {output}")
else:
output = WorkerApplyOutput(message=response.text)
logger.warn(f"worker_apply failed: {output}")
return output
results = await self._apply_worker(apply_req, _remote_apply_func)
if results:
return results[0]
class WorkerManagerAdapter(WorkerManager):
def __init__(self, worker_manager: WorkerManager = None) -> None:
self.worker_manager = worker_manager
async def start(self):
return await self.worker_manager.start()
async def stop(self):
return await self.worker_manager.stop()
async def supported_models(self) -> List[WorkerSupportedModel]:
return await self.worker_manager.supported_models()
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
return await self.worker_manager.model_startup(startup_req)
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
return await self.worker_manager.model_shutdown(shutdown_req)
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
@@ -486,10 +474,10 @@ class WorkerManagerAdapter(WorkerManager):
worker_type, model_name, healthy_only
)
async def select_one_instanes(
async def select_one_instance(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> WorkerRunData:
return await self.worker_manager.select_one_instanes(
return await self.worker_manager.select_one_instance(
worker_type, model_name, healthy_only
)
@@ -535,37 +523,58 @@ async def api_generate_stream(request: PromptRequest):
@router.post("/worker/generate")
async def api_generate(request: PromptRequest):
params = request.dict(exclude_none=True)
output = await worker_manager.generate(params)
return output
return await worker_manager.generate(params)
@router.post("/worker/embeddings")
async def api_embeddings(request: EmbeddingsRequest):
params = request.dict(exclude_none=True)
output = await worker_manager.embeddings(params)
return output
return await worker_manager.embeddings(params)
@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
output = await worker_manager.worker_apply(request)
return output
return await worker_manager.worker_apply(request)
@router.get("/worker/parameter/descriptions")
async def api_worker_parameter_descs(
model: str, worker_type: str = WorkerType.LLM.value
):
output = await worker_manager.parameter_descriptions(worker_type, model)
return output
return await worker_manager.parameter_descriptions(worker_type, model)
@router.get("/worker/models/supports")
async def api_supported_models():
"""Get all supported models.
This method reads all models from the configuration file and performs some basic checks on each model (such as whether the path exists).
If it's a RemoteWorkerManager, this method returns the list of models supported by the entire cluster.
"""
return await worker_manager.supported_models()
@router.post("/worker/models/startup")
async def api_model_startup(request: WorkerStartupRequest):
"""Start up a specific model."""
return await worker_manager.model_startup(request)
@router.post("/worker/models/shutdown")
async def api_model_shutdown(request: WorkerStartupRequest):
"""Shut down a specific model."""
return await worker_manager.model_shutdown(request)
def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
if not app:
app = FastAPI()
if worker_params.standalone:
from pilot.model.controller.controller import router as controller_router
from pilot.model.controller.controller import initialize_controller
from pilot.model.cluster.controller.controller import initialize_controller
from pilot.model.cluster.controller.controller import (
router as controller_router,
)
if not worker_params.controller_addr:
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
@@ -577,9 +586,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
@app.on_event("startup")
async def startup_event():
asyncio.create_task(
worker_manager.worker_manager._start_all_worker(apply_req=None)
)
asyncio.create_task(worker_manager.worker_manager.start())
@app.on_event("shutdown")
async def shutdown_event():
await worker_manager.worker_manager.stop()
return app
@@ -609,22 +620,23 @@ def _parse_worker_params(
def _create_local_model_manager(
worker_params: ModelWorkerParameters,
) -> LocalWorkerManager:
from pilot.utils.net_utils import _get_ip_address
host = (
worker_params.worker_register_host
if worker_params.worker_register_host
else _get_ip_address()
)
port = worker_params.port
if not worker_params.register or not worker_params.controller_addr:
logger.info(
f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
)
return LocalWorkerManager()
return LocalWorkerManager(host=host, port=port)
else:
from pilot.model.controller.controller import ModelRegistryClient
from pilot.utils.net_utils import _get_ip_address
from pilot.model.cluster.controller.controller import ModelRegistryClient
client = ModelRegistryClient(worker_params.controller_addr)
host = (
worker_params.worker_register_host
if worker_params.worker_register_host
else _get_ip_address()
)
port = worker_params.port
async def register_func(worker_run_data: WorkerRunData):
instance = ModelInstance(
@@ -648,31 +660,33 @@ def _create_local_model_manager(
register_func=register_func,
deregister_func=deregister_func,
send_heartbeat_func=send_heartbeat_func,
host=host,
port=port,
)
def _start_local_worker(
worker_manager: WorkerManagerAdapter,
worker_params: ModelWorkerParameters,
embedded_mod=True,
):
from pilot.utils.module_utils import import_from_checked_string
def _build_worker(worker_params: ModelWorkerParameters):
if worker_params.worker_class:
from pilot.utils.module_utils import import_from_checked_string
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
logger.info(
f"Import worker class from {worker_params.worker_class} successfully"
)
worker: ModelWorker = worker_cls()
else:
from pilot.model.worker.default_worker import DefaultModelWorker
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
worker = DefaultModelWorker()
return worker
def _start_local_worker(
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
):
worker = _build_worker(worker_params)
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(
worker, worker_params, embedded_mod=embedded_mod
)
worker_manager.worker_manager.add_worker(worker, worker_params)
def initialize_worker_manager_in_client(
@@ -713,16 +727,13 @@ def initialize_worker_manager_in_client(
worker_params.port = local_port
logger.info(f"Worker params: {worker_params}")
_setup_fastapi(worker_params, app)
_start_local_worker(worker_manager, worker_params, True)
# loop = asyncio.get_event_loop()
# loop.run_until_complete(
# worker_manager.worker_manager._start_all_worker(apply_req=None)
# )
_start_local_worker(worker_manager, worker_params)
else:
from pilot.model.controller.controller import (
initialize_controller,
from pilot.model.cluster.controller.controller import (
ModelRegistryClient,
initialize_controller,
)
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
if not worker_params.controller_addr:
raise ValueError("Controller can`t be None")
@@ -758,13 +769,11 @@ def run_worker_manager(
# Run worker manager independently
embedded_mod = False
app = _setup_fastapi(worker_params)
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
_start_local_worker(worker_manager, worker_params)
else:
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
_start_local_worker(worker_manager, worker_params)
loop = asyncio.get_event_loop()
loop.run_until_complete(
worker_manager.worker_manager._start_all_worker(apply_req=None)
)
loop.run_until_complete(worker_manager.worker_manager.start())
if include_router:
app.include_router(router, prefix="/api")

View File

@@ -0,0 +1,169 @@
from typing import Callable, Any
import httpx
import asyncio
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
from pilot.model.cluster.base import *
from pilot.model.base import (
ModelInstance,
WorkerApplyOutput,
WorkerSupportedModel,
)
class RemoteWorkerManager(LocalWorkerManager):
def __init__(self, model_registry: ModelRegistry = None) -> None:
super().__init__(model_registry=model_registry)
async def start(self):
pass
async def stop(self):
pass
async def _fetch_from_worker(
self,
worker_run_data: WorkerRunData,
endpoint: str,
method: str = "GET",
json: dict = None,
params: dict = None,
additional_headers: dict = None,
success_handler: Callable = None,
error_handler: Callable = None,
) -> Any:
url = worker_run_data.worker.worker_addr + endpoint
headers = {**worker_run_data.worker.headers, **(additional_headers or {})}
timeout = worker_run_data.worker.timeout
async with httpx.AsyncClient() as client:
request = client.build_request(
method,
url,
json=json, # using json for data to ensure it sends as application/json
params=params,
headers=headers,
timeout=timeout,
)
response = await client.send(request)
if response.status_code != 200:
if error_handler:
return error_handler(response)
else:
error_msg = f"Request to {url} failed, error: {response.text}"
raise Exception(error_msg)
if success_handler:
return success_handler(response)
return response.json()
async def _apply_to_worker_manager_instances(self):
pass
async def supported_models(self) -> List[WorkerSupportedModel]:
worker_instances = await self.get_model_instances(
WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
)
async def get_supported_models(worker_run_data) -> List[WorkerSupportedModel]:
def handler(response):
return list(WorkerSupportedModel.from_dict(m) for m in response.json())
return await self._fetch_from_worker(
worker_run_data, "/models/supports", success_handler=handler
)
models = []
results = await asyncio.gather(
*(get_supported_models(worker) for worker in worker_instances)
)
for res in results:
models += res
return models
async def _get_worker_service_instance(
self, host: str = None, port: int = None
) -> List[WorkerRunData]:
worker_instances = await self.get_model_instances(
WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
)
error_msg = f"Cound not found worker instances"
if host and port:
worker_instances = [
ins for ins in worker_instances if ins.host == host and ins.port == port
]
error_msg = f"Cound not found worker instances for host {host} port {port}"
if not worker_instances:
raise Exception(error_msg)
return worker_instances
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
worker_instances = await self._get_worker_service_instance(
startup_req.host, startup_req.port
)
worker_run_data = worker_instances[0]
logger.info(f"Start model remote, startup_req: {startup_req}")
return await self._fetch_from_worker(
worker_run_data,
"/models/startup",
method="POST",
json=startup_req.dict(),
success_handler=lambda x: True,
)
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
worker_instances = await self._get_worker_service_instance(
shutdown_req.host, shutdown_req.port
)
worker_run_data = worker_instances[0]
logger.info(f"Shutdown model remote, shutdown_req: {shutdown_req}")
return await self._fetch_from_worker(
worker_run_data,
"/models/shutdown",
method="POST",
json=shutdown_req.dict(),
success_handler=lambda x: True,
)
async def get_model_instances(
self, worker_type: str, model_name: str, healthy_only: bool = True
) -> List[WorkerRunData]:
from pilot.model.cluster.worker.remote_worker import RemoteModelWorker
worker_key = self._worker_key(worker_type, model_name)
instances: List[ModelInstance] = await self.model_registry.get_all_instances(
worker_key, healthy_only
)
worker_instances = []
for ins in instances:
worker = RemoteModelWorker()
worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
wr = WorkerRunData(
host=ins.host,
port=ins.port,
worker_key=ins.model_name,
worker=worker,
worker_params=None,
model_params=None,
stop_event=asyncio.Event(),
semaphore=asyncio.Semaphore(100), # Not limit in client
)
worker_instances.append(wr)
return worker_instances
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
async def _remote_apply_func(worker_run_data: WorkerRunData):
return await self._fetch_from_worker(
worker_run_data,
"/apply",
method="POST",
json=apply_req.dict(),
success_handler=lambda res: WorkerApplyOutput(**res.json()),
error_handler=lambda res: WorkerApplyOutput(
message=res.text, success=False
),
)
results = await self._apply_worker(apply_req, _remote_apply_func)
if results:
return results[0]

View File

@@ -3,7 +3,7 @@ from typing import Dict, Iterator, List
import logging
from pilot.model.base import ModelOutput
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.model.cluster.worker_base import ModelWorker
class RemoteModelWorker(ModelWorker):

View File

@@ -2,14 +2,18 @@
# -*- coding:utf-8 -*-
import traceback
from pathlib import Path
from queue import Queue
from threading import Thread
import transformers
from typing import List, Optional
from typing import List, Optional, Dict
import cachetools
from pilot.configs.config import Config
from pilot.model.base import Message
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.model.base import Message, SupportedModel
from pilot.utils.parameter_utils import _get_parameter_descriptions
def create_chat_completion(
@@ -128,3 +132,49 @@ def is_partial_stop(output: str, stop_str: str):
if stop_str.startswith(output[-i:]):
return True
return False
@cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=60 * 5))
def list_supported_models():
from pilot.model.parameter import WorkerType
models = _list_supported_models(WorkerType.LLM.value, LLM_MODEL_CONFIG)
models += _list_supported_models(WorkerType.TEXT2VEC.value, EMBEDDING_MODEL_CONFIG)
return models
def _list_supported_models(
worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
from pilot.model.adapter import get_llm_model_adapter
from pilot.model.parameter import ModelParameters
from pilot.model.loader import _get_model_real_path
ret = []
for model_name, model_path in model_config.items():
model_path = _get_model_real_path(model_name, model_path)
model = SupportedModel(
model=model_name,
path=model_path,
worker_type=worker_type,
path_exist=False,
proxy=False,
enabled=False,
params=None,
)
if "proxyllm" in model_name:
model.proxy = True
else:
path = Path(model_path)
model.path_exist = path.exists()
param_cls = None
try:
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_cls = llm_adapter.model_param_class()
model.enabled = True
params = _get_parameter_descriptions(param_cls)
model.params = params
except Exception:
pass
ret.append(model)
return ret

View File

@@ -353,5 +353,6 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
from pilot.model.proxy.llms.proxy_model import ProxyModel
logger.info("Load proxyllm")
model = ProxyModel(model_params)
return model, model

View File

@@ -33,9 +33,13 @@ class ModelControllerParameters(BaseParameters):
@dataclass
class ModelWorkerParameters(BaseParameters):
class BaseModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
@dataclass
class ModelWorkerParameters(BaseModelParameters):
worker_type: Optional[str] = field(
default=None,
metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
@@ -84,9 +88,7 @@ class ModelWorkerParameters(BaseParameters):
@dataclass
class EmbeddingModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
class EmbeddingModelParameters(BaseModelParameters):
device: Optional[str] = field(
default=None,
metadata={
@@ -114,12 +116,6 @@ class EmbeddingModelParameters(BaseParameters):
return kwargs
@dataclass
class BaseModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
@dataclass
class ModelParameters(BaseModelParameters):
device: Optional[str] = field(

View File

@@ -139,7 +139,7 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
from pilot.model.worker.manager import worker_manager
from pilot.model.cluster import worker_manager
async for output in worker_manager.generate_stream(payload):
yield output
@@ -157,7 +157,7 @@ class BaseChat(ABC):
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
from pilot.model.worker.manager import worker_manager
from pilot.model.cluster import worker_manager
model_output = await worker_manager.generate(payload)

View File

@@ -6,7 +6,7 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
EMBEDDING_MODEL_CONFIG,
)
from pilot.scene.chat_knowledge.v1.prompt import prompt
@@ -48,7 +48,7 @@ class ChatKnowledge(BaseChat):
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
self.knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)

View File

@@ -23,7 +23,7 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
static_file_path = os.path.join(os.getcwd(), "server/static")

View File

@@ -7,7 +7,10 @@ from fastapi import APIRouter, File, UploadFile, Form
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -29,7 +32,9 @@ CFG = Config()
router = APIRouter()
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL])
embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
knowledge_space_service = KnowledgeService()

View File

@@ -5,7 +5,10 @@ from datetime import datetime
from pilot.vector_store.connector import VectorStoreConnector
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
@@ -204,7 +207,7 @@ class KnowledgeService:
client = EmbeddingEngine(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -341,7 +344,7 @@ class KnowledgeService:
"topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
"recall_score": 0.0,
"recall_type": "TopK",
"model": LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
"model": EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
"chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
"chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
},

View File

@@ -9,7 +9,7 @@ sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.model.worker.manager import run_worker_manager
from pilot.model.cluster import run_worker_manager
CFG = Config()

View File

@@ -5,7 +5,7 @@ from pilot.common.schema import DBType
from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
EMBEDDING_MODEL_CONFIG,
LOGDIR,
)
from pilot.scene.base import ChatScene
@@ -36,7 +36,7 @@ class DBSummaryClient:
db_summary_client = RdbmsSummary(dbname, db_type)
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vector_store_config = {
"vector_store_name": dbname + "_summary",
@@ -90,7 +90,7 @@ class DBSummaryClient:
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)
table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -108,7 +108,7 @@ class DBSummaryClient:
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)
if CFG.SUMMARY_CONFIG == "FAST":
@@ -134,7 +134,7 @@ class DBSummaryClient:
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = EmbeddingEngine(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config=vector_store_config,
)
table_summery = knowledge_embedding_client.similar_search(query, 1)

View File

@@ -4,6 +4,7 @@ import subprocess
from typing import List, Dict
import psutil
import platform
from functools import lru_cache
def _get_abspath_of_current_command(command_path: str):
@@ -137,6 +138,7 @@ def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
return ports
@lru_cache()
def _detect_controller_address() -> str:
controller_addr = os.getenv("CONTROLLER_ADDRESS")
if controller_addr:

View File

@@ -2,7 +2,7 @@ from typing import Type
from importlib import import_module
def import_from_string(module_path: str):
def import_from_string(module_path: str, ignore_import_error: bool = False):
try:
module_path, class_name = module_path.rsplit(".", 1)
except ValueError:
@@ -12,6 +12,8 @@ def import_from_string(module_path: str):
try:
return getattr(module, class_name)
except AttributeError:
if ignore_import_error:
return None
raise ImportError(
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
)
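
Hedged usage of the new flag: a missing attribute now yields None instead of raising, which the dynamic parameter-class builder later in this commit relies on (the module path here is illustrative):

param_class = import_from_string(
    "pilot.model.parameter.ModelParameters", ignore_import_error=True
)
if param_class is None:
    # fall back to synthesizing the class from ParameterDescription records
    ...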

View File

@@ -1,21 +1,50 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING
from dataclasses import dataclass, fields, MISSING, asdict, field
from typing import Any, List, Optional, Type, Union, Callable, Dict
from collections import OrderedDict
@dataclass
class ParameterDescription:
param_class: str
param_name: str
param_type: str
description: str
default_value: Optional[Any]
description: str
valid_values: Optional[List[Any]]
ext_metadata: Dict
@dataclass
class BaseParameters:
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "BaseParameters":
"""Create an instance of the dataclass from a dictionary.
Args:
data: A dictionary containing values for the dataclass fields.
ignore_extra_fields: If True, any extra fields in the data dictionary that are
not part of the dataclass will be ignored.
If False, extra fields will raise an error. Defaults to False.
Returns:
An instance of the dataclass with values populated from the given dictionary.
Raises:
TypeError: If `ignore_extra_fields` is False and there are fields in the
dictionary that aren't present in the dataclass.
"""
all_field_names = {f.name for f in fields(cls)}
if ignore_extra_fields:
data = {key: value for key, value in data.items() if key in all_field_names}
else:
extra_fields = set(data.keys()) - all_field_names
if extra_fields:
raise TypeError(f"Unexpected fields: {', '.join(extra_fields)}")
return cls(**data)
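
A quick sketch of from_dict in use (MyParams is a hypothetical subclass):

from dataclasses import dataclass

@dataclass
class MyParams(BaseParameters):
    host: str = "127.0.0.1"
    port: int = 8000

p = MyParams.from_dict(
    {"host": "0.0.0.0", "port": 8001, "verbose": True}, ignore_extra_fields=True
)
# "verbose" is silently dropped; with ignore_extra_fields=False it raises TypeError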
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
"""
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
@@ -68,6 +97,35 @@ class BaseParameters:
)
return "\n".join(parameters)
def to_command_args(self, args_prefix: str = "--") -> List[str]:
"""Convert the fields of the dataclass to a list of command line arguments.
Args:
args_prefix: The prefix prepended to each field name. Defaults to "--".
Returns:
A list of strings where each field is represented by two items:
one for the field name prefixed by args_prefix, and one for its value.
"""
return _dict_to_command_args(asdict(self), args_prefix=args_prefix)
def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
"""Convert dict to a list of command line arguments
Args:
obj: dict
Returns:
A list of strings where each field is represented by two items:
one for the field name prefixed by args_prefix, and one for its value.
"""
args = []
for key, value in obj.items():
if value is None:
continue
args.append(f"{args_prefix}{key}")
args.append(str(value))
return args
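
For instance (illustrative values; None values are skipped entirely, and everything else is stringified):

_dict_to_command_args({"host": "0.0.0.0", "port": 8000, "device": None})
# -> ["--host", "0.0.0.0", "--port", "8000"]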
def _get_simple_privacy_field_value(obj, field_info):
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
@@ -81,9 +139,9 @@ def _get_simple_privacy_field_value(obj, field_info):
- str: if length > 5, masks the middle part and returns first and last char;
otherwise, returns "******"
Parameters:
- obj: The dataclass instance.
- field_info: A Field object that contains information about the dataclass field.
Args:
obj: The dataclass instance.
field_info: A Field object that contains information about the dataclass field.
Returns:
The original or modified value of the field based on the privacy rules.
@@ -202,11 +260,42 @@ class EnvArgumentParser:
parser.add_argument(f"--{field.name}", **argument_kwargs)
return parser
@staticmethod
def _create_click_option_from_field(field_name: str, field: Type, is_func=True):
import click
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
cli_params = {
"default": None if field.default is MISSING else field.default,
"help": help_text,
"show_default": True,
"required": field.default is MISSING,
}
if valid_values:
cli_params["type"] = click.Choice(valid_values)
real_type = EnvArgumentParser._get_argparse_type(field.type)
if real_type is int:
cli_params["type"] = click.INT
elif real_type is float:
cli_params["type"] = click.FLOAT
elif real_type is str:
cli_params["type"] = click.STRING
elif real_type is bool:
cli_params["is_flag"] = True
name = f"--{field_name}"
if is_func:
return click.option(
name,
**cli_params,
)
else:
return click.Option([name], **cli_params)
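
The is_func switch yields two renderings of the same field (a sketch; ModelParameters is assumed to be a parameter dataclass with a model_name field):

from dataclasses import fields

fs = {f.name: f for f in fields(ModelParameters)}
decorator = EnvArgumentParser._create_click_option_from_field("model_name", fs["model_name"])
raw_option = EnvArgumentParser._create_click_option_from_field(
    "model_name", fs["model_name"], is_func=False
)  # a click.Option instance, as appended to LazyCommand.params below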
@staticmethod
def create_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
):
import click
import functools
from collections import OrderedDict
@@ -222,30 +311,8 @@ class EnvArgumentParser:
def decorator(func):
for field_name, field in reversed(combined_fields.items()):
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
cli_params = {
"default": None if field.default is MISSING else field.default,
"help": help_text,
"show_default": True,
"required": field.default is MISSING,
}
if valid_values:
cli_params["type"] = click.Choice(valid_values)
real_type = EnvArgumentParser._get_argparse_type(field.type)
if real_type is int:
cli_params["type"] = click.INT
elif real_type is float:
cli_params["type"] = click.FLOAT
elif real_type is str:
cli_params["type"] = click.STRING
elif real_type is bool:
cli_params["is_flag"] = True
option_decorator = click.option(
# f"--{field_name.replace('_', '-')}", **cli_params
f"--{field_name}",
**cli_params,
option_decorator = EnvArgumentParser._create_click_option_from_field(
field_name, field
)
func = option_decorator(func)
@@ -257,6 +324,23 @@ class EnvArgumentParser:
return decorator
@staticmethod
def _create_raw_click_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
):
combined_fields = _merge_dataclass_types(
*dataclass_types, _dynamic_factory=_dynamic_factory
)
options = []
for field_name, field in reversed(combined_fields.items()):
options.append(
EnvArgumentParser._create_click_option_from_field(
field_name, field, is_func=False
)
)
return options
@staticmethod
def create_argparse_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
@@ -366,21 +450,70 @@ def _merge_dataclass_types(
return combined_fields
def _type_str_to_python_type(type_str: str) -> Type:
type_mapping: Dict[str, Type] = {
"int": int,
"float": float,
"bool": bool,
"str": str,
}
return type_mapping.get(type_str, str)
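
Note the fallback: unrecognized type names degrade to str, so deserializing ParameterDescription records from other processes stays safe:

assert _type_str_to_python_type("int") is int
assert _type_str_to_python_type("List[str]") is str  # unknown names fall back to str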
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
descriptions = []
for field in fields(dataclass_type):
ext_metadata = {
k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
}
descriptions.append(
ParameterDescription(
param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
param_name=field.name,
param_type=EnvArgumentParser._get_argparse_type_str(field.type),
description=field.metadata.get("help", None),
default_value=field.default, # TODO handle dataclasses._MISSING_TYPE
default_value=field.default if field.default != MISSING else None,
valid_values=field.metadata.get("valid_values", None),
ext_metadata=ext_metadata,
)
)
return descriptions
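
A sketch of the server side of the new parameters API (ModelParameters is assumed; asdict makes the records JSON-friendly):

from dataclasses import asdict

descs = _get_parameter_descriptions(ModelParameters)
payload = [asdict(d) for d in descs]  # param_class, param_name, param_type, default_value, ...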
def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
from pilot.utils.module_utils import import_from_string
if not desc:
raise ValueError("Parameter descriptions cant be empty")
param_class_str = desc[0].param_class
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
module_name, _, class_name = param_class_str.rpartition(".")
fields_dict = {} # This will store field names and their default values or field()
annotations = {} # This will store the type annotations for the fields
for d in desc:
metadata = d.ext_metadata if d.ext_metadata else {}
metadata["help"] = d.description
metadata["valid_values"] = d.valid_values
annotations[d.param_name] = _type_str_to_python_type(
d.param_type
) # Set type annotation
fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)
# Create the new class. Note the setting of __annotations__ for type hints
new_class = type(
class_name, (object,), {**fields_dict, "__annotations__": annotations}
)
result_class = dataclass(new_class) # Make it a dataclass
return result_class
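
And the client-side round trip: when the server's parameter class is not importable locally, an equivalent dataclass is synthesized from the descriptions (continuing the sketch above; defaults of MISSING arrive as None):

param_class = _build_parameter_class(descs)
instance = param_class()  # fields carry the server-reported defaults, help text and valid_values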
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}
@@ -422,3 +555,24 @@ class _SimpleArgParser:
return "\n".join(
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
)
def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
import click
class LazyCommand(click.Command):
def __init__(self, *args, **kwargs):
super(LazyCommand, self).__init__(*args, **kwargs)
self.dynamic_params_added = False
def get_params(self, ctx):
if ctx and not self.dynamic_params_added:
dynamic_params = EnvArgumentParser._create_raw_click_option(
*dataclass_types, _dynamic_factory=_dynamic_factory
)
for param in reversed(dynamic_params):
self.params.append(param)
self.dynamic_params_added = True
return super(LazyCommand, self).get_params(ctx)
return LazyCommand
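
Hypothetical wiring for the CLI: the dynamic options (whose discovery may import heavy model adapters) are resolved only when the command itself is invoked, so importing the CLI module and rendering group-level help stay fast (_dynamic_model_parser is assumed to be the project's dynamic factory):

import click

@click.group()
def model_cli():
    pass

@model_cli.command(cls=build_lazy_click_command(_dynamic_factory=_dynamic_model_parser))
def start(**kwargs):
    print(kwargs)  # all dynamically discovered parameters arrive here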

View File

@@ -29,10 +29,10 @@ def knownledge_tovec_st(filename):
"""Use sentence transformers to embedding the document.
https://github.com/UKPLab/sentence-transformers
"""
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
)
with open(filename, "r") as f:
@@ -57,10 +57,10 @@ def load_knownledge_from_doc():
"Not Exists Local DataSets, We will answers the Question use model default."
)
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
)
files = os.listdir(DATASETS_DIR)

View File

@@ -16,7 +16,7 @@ from langchain.vectorstores import Chroma
from pilot.configs.model_config import (
DATASETS_DIR,
LLM_MODEL_CONFIG,
EMBEDDING_MODEL_CONFIG,
VECTORE_PATH,
)
@@ -39,7 +39,7 @@ class KnownLedge2Vector:
"""
embeddings: object = None
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
model_name = EMBEDDING_MODEL_CONFIG["sentence-transforms"]
def __init__(self, model_name=None) -> None:
if not model_name:

View File

@@ -79,4 +79,5 @@ duckdb
duckdb-engine
# cli
prettytable
prettytable
cachetools