Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-01 09:06:55 +00:00
feat(model): support the deployment of multiple models through the API and add a corresponding command-line interface (#570)
- Split `LLM_MODEL_CONFIG` into `LLM_MODEL_CONFIG` and `EMBEDDING_MODEL_CONFIG`.
- New HTTP API to obtain the list of models and configuration parameters supported by the current cluster.
- New HTTP API to launch models on a specified machine.
- The command line supports the above HTTP APIs.
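As a hedged illustration of the new command line (the `dbgpt` entry point, the exact option spellings, and the host/port/model values here are assumptions inferred from the diff below, not documented output):

    # List model instances registered with the controller (default address http://127.0.0.1:8000)
    dbgpt model list
    # Deploy a model on a specific worker manager instance in the cluster
    dbgpt model start --model_name vicuna-13b-v1.5 --host 192.168.1.10 --port 8001
    # Stop that instance again
    dbgpt model stop --model_name vicuna-13b-v1.5 --host 192.168.1.10 --port 8001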
@@ -41,25 +41,12 @@ LLM_MODEL_CONFIG = {
    # (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
    "vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
    "vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
    "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    # https://huggingface.co/moka-ai/m3e-large
    "m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
    # https://huggingface.co/moka-ai/m3e-base
    "m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
    # https://huggingface.co/BAAI/bge-large-en
    "bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
    "bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
    # https://huggingface.co/BAAI/bge-large-zh
    "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
    "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
    "codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
    "codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
    "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
    "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
    "chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
    "chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
    "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),

@@ -84,6 +71,22 @@ LLM_MODEL_CONFIG = {
    "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
}

EMBEDDING_MODEL_CONFIG = {
    "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
    # https://huggingface.co/moka-ai/m3e-large
    "m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
    # https://huggingface.co/moka-ai/m3e-base
    "m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
    # https://huggingface.co/BAAI/bge-large-en
    "bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
    "bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
    # https://huggingface.co/BAAI/bge-large-zh
    "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
    "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
}

# Load model config
ISDEBUG = False

@@ -108,6 +108,17 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    return [param_class]


def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
    try:
        llm_adapter = get_llm_model_adapter(model_name, model_path)
        return llm_adapter.model_param_class()
    except Exception as e:
        logger.warn(
            f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
        )
        return ModelParameters


# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?

@@ -2,9 +2,10 @@
# -*- coding: utf-8 -*-

from enum import Enum
from typing import TypedDict, Optional, Dict
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription


class Message(TypedDict):

@@ -46,5 +47,40 @@ class ModelOutput:
@dataclass
class WorkerApplyOutput:
    message: str
    success: Optional[bool] = True
    # The seconds cost to apply some action to worker instances
    timecost: Optional[int] = -1


@dataclass
class SupportedModel:
    model: str
    path: str
    worker_type: str
    path_exist: bool
    proxy: bool
    enabled: bool
    params: List[ParameterDescription]

    @classmethod
    def from_dict(cls, model_data: Dict) -> "SupportedModel":
        params = model_data.get("params", [])
        if params:
            params = [ParameterDescription(**param) for param in params]
        model_data["params"] = params
        return cls(**model_data)


@dataclass
class WorkerSupportedModel:
    host: str
    port: int
    models: List[SupportedModel]

    @classmethod
    def from_dict(cls, worker_data: Dict) -> "WorkerSupportedModel":
        models = [
            SupportedModel.from_dict(model_data) for model_data in worker_data["models"]
        ]
        worker_data["models"] = models
        return cls(**worker_data)

@@ -2,19 +2,27 @@ import click
import functools
import logging
import os
from typing import Callable, List, Type
from typing import Callable, List, Type, Optional

from pilot.model.controller.controller import ModelRegistryClient
from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
    ModelControllerParameters,
    ModelWorkerParameters,
    ModelParameters,
    BaseParameters,
)
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service
from pilot.utils.parameter_utils import (
    EnvArgumentParser,
    _build_parameter_class,
    build_lazy_click_command,
)
from pilot.utils.command_utils import (
    _run_current_with_daemon,
    _stop_service,
    _detect_controller_address,
)


MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"

@@ -22,6 +30,14 @@ MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
logger = logging.getLogger("dbgpt_cli")


def _get_worker_manager(address: str):
    from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient

    registry = ModelRegistryClient(address)
    worker_manager = RemoteWorkerManager(registry)
    return worker_manager


@click.group("model")
@click.option(
    "--address",

@@ -38,8 +54,6 @@ def model_cli_group(address: str):
    """Clients that manage model serving"""
    global MODEL_CONTROLLER_ADDRESS
    if not address:
        from pilot.utils.command_utils import _detect_controller_address

        MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
    else:
        MODEL_CONTROLLER_ADDRESS = address

@@ -55,6 +69,7 @@ def model_cli_group(address: str):
def list(model_name: str, model_type: str):
    """List model instances"""
    from prettytable import PrettyTable
    from pilot.model.cluster import ModelRegistryClient

    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)

@@ -90,7 +105,7 @@ def list(model_name: str, model_type: str):
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template,
                instance.prompt_template if instance.prompt_template else "",
                instance.last_heartbeat,
            ]
        )

@@ -122,18 +137,156 @@ def add_model_options(func):

@model_cli_group.command()
@add_model_options
def stop(model_name: str, model_type: str):
@click.option(
    "--host",
    type=str,
    required=True,
    help=("The remote host to stop model"),
)
@click.option(
    "--port",
    type=int,
    required=True,
    help=("The remote port to stop model"),
)
def stop(model_name: str, model_type: str, host: str, port: int):
    """Stop model instances"""
    worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
    from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager


@model_cli_group.command()
@add_model_options
def start(model_name: str, model_type: str):
    """Start model instances"""
    worker_apply(
        MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
    worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
    req = WorkerStartupRequest(
        host=host,
        port=port,
        worker_type=model_type,
        model=model_name,
        params={},
    )
    loop = get_or_create_event_loop()
    res = loop.run_until_complete(worker_manager.model_shutdown(req))
    print(res)


def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
    from pilot.model.adapter import _dynamic_model_parser
    from pilot.utils.parameter_utils import _SimpleArgParser
    from pilot.model.cluster import RemoteWorkerManager
    from pilot.model.parameter import WorkerType
    from dataclasses import dataclass, field, fields

    pre_args = _SimpleArgParser("model_name", "address", "host", "port")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    address = pre_args.get("address")
    host = pre_args.get("host")
    port = pre_args.get("port")
    if port:
        port = int(port)

    if not address:
        address = _detect_controller_address()

    worker_manager: RemoteWorkerManager = _get_worker_manager(address)
    loop = get_or_create_event_loop()
    models = loop.run_until_complete(worker_manager.supported_models())

    fields_dict = {}
    fields_dict["model_name"] = (
        str,
        field(default=None, metadata={"help": "The model name to deploy"}),
    )
    fields_dict["host"] = (
        str,
        field(default=None, metadata={"help": "The remote host to deploy model"}),
    )
    fields_dict["port"] = (
        int,
        field(default=None, metadata={"help": "The remote port to deploy model"}),
    )
    result_class = dataclass(
        type("RemoteModelWorkerParameters", (object,), fields_dict)
    )

    if not models:
        return [result_class]

    valid_models = []
    valid_model_cls = []
    for model in models:
        if host and host != model.host:
            continue
        if port and port != model.port:
            continue
        valid_models += [m.model for m in model.models]
        valid_model_cls += [
            (m, _build_parameter_class(m.params)) for m in model.models if m.params
        ]
    real_model, real_params_cls = valid_model_cls[0]
    real_path = None
    real_worker_type = "llm"
    if model_name:
        params_cls_list = [m for m in valid_model_cls if m[0].model == model_name]
        if not params_cls_list:
            raise ValueError(f"Not supported model with model name: {model_name}")
        real_model, real_params_cls = params_cls_list[0]
        real_path = real_model.path
        real_worker_type = real_model.worker_type

    @dataclass
    class RemoteModelWorkerParameters(BaseParameters):
        model_name: str = field(
            metadata={"valid_values": valid_models, "help": "The model name to deploy"}
        )
        model_path: Optional[str] = field(
            default=real_path, metadata={"help": "The model path to deploy"}
        )
        host: Optional[str] = field(
            default=models[0].host,
            metadata={
                "valid_values": [model.host for model in models],
                "help": "The remote host to deploy model",
            },
        )

        port: Optional[int] = field(
            default=models[0].port,
            metadata={
                "valid_values": [model.port for model in models],
                "help": "The remote port to deploy model",
            },
        )
        worker_type: Optional[str] = field(
            default=real_worker_type,
            metadata={
                "valid_values": WorkerType.values(),
                "help": "Worker type",
            },
        )

    return [RemoteModelWorkerParameters, real_params_cls]


@model_cli_group.command(
    cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory)
)
def start(**kwargs):
    """Start model instances"""
    from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager

    worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
    req = WorkerStartupRequest(
        host=kwargs["host"],
        port=kwargs["port"],
        worker_type=kwargs["worker_type"],
        model=kwargs["model_name"],
        params={},
    )
    del kwargs["host"]
    del kwargs["port"]
    del kwargs["worker_type"]
    req.params = kwargs
    loop = get_or_create_event_loop()
    res = loop.run_until_complete(worker_manager.model_startup(req))
    print(res)


@model_cli_group.command()

@@ -165,25 +318,10 @@ def chat(model_name: str, system: str):
    _cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)


# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
#     """Restart model instances"""
#     worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)


def _get_worker_manager(address: str):
    from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest

    registry = ModelRegistryClient(address)
    worker_manager = RemoteWorkerManager(registry)
    return worker_manager


def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    from pilot.model.worker.manager import WorkerApplyRequest
    from pilot.model.cluster import WorkerApplyRequest

    loop = get_or_create_event_loop()
    worker_manager = _get_worker_manager(address)

@@ -201,7 +339,7 @@ def _cli_chat(address: str, model_name: str, system_prompt: str = None):


async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
    from pilot.model.worker.manager import PromptRequest
    from pilot.model.cluster import PromptRequest
    from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

    print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")

@@ -249,13 +387,11 @@ def add_stop_server_options(func):
def start_model_controller(**kwargs):
    """Start model controller"""

    from pilot.model.controller.controller import run_model_controller

    if kwargs["daemon"]:
        log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
        _run_current_with_daemon("ModelController", log_file)
    else:
        from pilot.model.controller.controller import run_model_controller
        from pilot.model.cluster import run_model_controller

        run_model_controller()

@@ -279,9 +415,8 @@ def _model_dynamic_factory() -> Callable[[None], List[Type]]:
    return fix_class


@click.command(name="worker")
@EnvArgumentParser.create_click_option(
    ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory
@click.command(
    name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory)
)
def start_model_worker(**kwargs):
    """Start model worker"""

@@ -291,7 +426,7 @@ def start_model_worker(**kwargs):
        log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
        _run_current_with_daemon("ModelWorker", log_file)
    else:
        from pilot.model.worker.manager import run_worker_manager
        from pilot.model.cluster import run_worker_manager

        run_worker_manager()

pilot/model/cluster/__init__.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from pilot.model.cluster.base import (
    EmbeddingsRequest,
    PromptRequest,
    WorkerApplyRequest,
    WorkerParameterRequest,
    WorkerStartupRequest,
)
from pilot.model.cluster.worker.manager import (
    initialize_worker_manager_in_client,
    run_worker_manager,
    worker_manager,
)

from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.controller.controller import (
    ModelRegistryClient,
    run_model_controller,
)

from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager

__all__ = [
    "EmbeddingsRequest",
    "PromptRequest",
    "WorkerApplyRequest",
    "WorkerParameterRequest"
    "WorkerStartupRequest"
    "worker_manager"
    "run_worker_manager",
    "initialize_worker_manager_in_client",
    "ModelRegistry",
    "ModelRegistryClient" "RemoteWorkerManager" "run_model_controller",
]

pilot/model/cluster/base.py (new file, 46 lines)
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Dict, List

from pilot.model.base import WorkerApplyType
from pilot.model.parameter import WorkerType
from pilot.scene.base_message import ModelMessage
from pydantic import BaseModel

WORKER_MANAGER_SERVICE_TYPE = "service"
WORKER_MANAGER_SERVICE_NAME = "WorkerManager"


class PromptRequest(BaseModel):
    messages: List[ModelMessage]
    model: str
    prompt: str = None
    temperature: float = None
    max_new_tokens: int = None
    stop: str = None
    echo: bool = True


class EmbeddingsRequest(BaseModel):
    model: str
    input: List[str]


class WorkerApplyRequest(BaseModel):
    model: str
    apply_type: WorkerApplyType
    worker_type: WorkerType = WorkerType.LLM
    params: Dict = None
    apply_user: str = None


class WorkerParameterRequest(BaseModel):
    model: str
    worker_type: WorkerType = WorkerType.LLM


class WorkerStartupRequest(BaseModel):
    host: str
    port: int
    model: str
    worker_type: WorkerType
    params: Dict

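To make the schema above concrete, here is a hedged sketch of building the startup request body (the host, port, and model values are placeholders, not taken from the commit):

    from pilot.model.cluster.base import WorkerStartupRequest
    from pilot.model.parameter import WorkerType

    # Placeholder values; params carries the deploy parameters for the chosen model.
    req = WorkerStartupRequest(
        host="192.168.1.10",
        port=8001,
        model="vicuna-13b-v1.5",
        worker_type=WorkerType.LLM,
        params={"model_name": "vicuna-13b-v1.5"},
    )
    print(req.dict())  # the JSON body POSTed to the /worker/models/startup route
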
@@ -6,7 +6,7 @@ from typing import List
from fastapi import APIRouter, FastAPI
from pilot.model.base import ModelInstance
from pilot.model.parameter import ModelControllerParameters
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.api_utils import _api_remote as api_remote

@@ -2,9 +2,8 @@ import pytest
from datetime import datetime, timedelta

import asyncio
from unittest.mock import patch
from pilot.model.base import ModelInstance
from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry
from pilot.model.cluster.registry import EmbeddedModelRegistry


@pytest.fixture

@@ -16,7 +15,7 @@ def model_registry():
def model_instance():
    return ModelInstance(
        model_name="test_model",
        ip="192.168.1.1",
        host="192.168.1.1",
        port=5000,
    )

@@ -89,12 +88,7 @@ async def test_send_heartbeat(model_registry, model_instance):
    await model_registry.register_instance(model_instance)
    last_heartbeat = datetime.now() - timedelta(seconds=10)
    model_instance.last_heartbeat = last_heartbeat
    assert (
        await model_registry.send_heartbeat(
            model_instance.model_name, model_instance.ip, model_instance.port
        )
        == True
    )
    assert await model_registry.send_heartbeat(model_instance) == True
    assert (
        model_registry.registry[model_instance.model_name][0].last_heartbeat
        > last_heartbeat

@@ -125,7 +119,7 @@ async def test_multiple_instances(model_registry, model_instance):
    """
    model_instance2 = ModelInstance(
        model_name="test_model",
        ip="192.168.1.2",
        host="192.168.1.2",
        port=5000,
    )
    await model_registry.register_instance(model_instance)

@@ -138,11 +132,11 @@ async def test_same_model_name_different_ip_port(model_registry):
    """
    Test if instances with the same model name but different IP and port are handled correctly
    """
    instance1 = ModelInstance(model_name="test_model", ip="192.168.1.1", port=5000)
    instance2 = ModelInstance(model_name="test_model", ip="192.168.1.2", port=6000)
    instance1 = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
    instance2 = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)
    await model_registry.register_instance(instance1)
    await model_registry.register_instance(instance2)
    instances = await model_registry.get_all_instances("test_model")
    assert len(instances) == 2
    assert instances[0].ip != instances[1].ip
    assert instances[0].host != instances[1].host
    assert instances[0].port != instances[1].port

pilot/model/cluster/manager_base.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import asyncio
from dataclasses import dataclass
from typing import List, Optional, Dict, Iterator
from abc import ABC, abstractmethod
from datetime import datetime
from concurrent.futures import Future
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
from pilot.model.parameter import ModelWorkerParameters, ModelParameters
from pilot.utils.parameter_utils import ParameterDescription


@dataclass
class WorkerRunData:
    host: str
    port: int
    worker_key: str
    worker: ModelWorker
    worker_params: ModelWorkerParameters
    model_params: ModelParameters
    stop_event: asyncio.Event
    semaphore: asyncio.Semaphore = None
    command_args: List[str] = None
    _heartbeat_future: Optional[Future] = None
    _last_heartbeat: Optional[datetime] = None


class WorkerManager(ABC):
    @abstractmethod
    async def start(self):
        """Start worker manager"""

    @abstractmethod
    async def stop(self):
        """Stop worker manager"""

    @abstractmethod
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Get model instances by worker type and model name"""

    @abstractmethod
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instance"""

    @abstractmethod
    async def supported_models(self) -> List[WorkerSupportedModel]:
        """List supported models"""

    @abstractmethod
    async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
        """Create and start a model instance"""

    @abstractmethod
    async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
        """Shutdown model instance"""

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        """Generate stream result, chat scene"""

    @abstractmethod
    async def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result"""

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input"""

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Worker apply"""

    @abstractmethod
    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Get parameter descriptions of model"""
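
As a minimal, hedged usage sketch of this interface (it mirrors what the CLI in this commit does; the controller address is a placeholder):

    import asyncio

    from pilot.model.cluster import ModelRegistryClient, RemoteWorkerManager

    async def show_supported_models():
        registry = ModelRegistryClient("http://127.0.0.1:8000")  # placeholder address
        worker_manager = RemoteWorkerManager(registry)
        # supported_models() aggregates one WorkerSupportedModel per worker manager instance
        for worker in await worker_manager.supported_models():
            for m in worker.models:
                print(worker.host, worker.port, m.model, m.worker_type, m.enabled)

    asyncio.run(show_supported_models())
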
@@ -7,7 +7,7 @@ from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser

@@ -7,7 +7,7 @@ from pilot.model.parameter import (
    EmbeddingModelParameters,
    WorkerType,
)
from pilot.model.worker.base import ModelWorker
from pilot.model.cluster.worker_base import ModelWorker
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser

@@ -4,13 +4,11 @@ import json
import os
import random
import time
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List

from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (

@@ -18,103 +16,34 @@ from pilot.model.base import (
    ModelOutput,
    WorkerApplyOutput,
    WorkerApplyType,
    WorkerSupportedModel,
)
from pilot.model.controller.registry import ModelRegistry
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.llm_utils import list_supported_models
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.worker.base import ModelWorker
from pilot.scene.base_message import ModelMessage
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData
from pilot.model.cluster.base import *
from pilot.utils import build_logger
from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription
from pydantic import BaseModel
from pilot.utils.parameter_utils import (
    EnvArgumentParser,
    ParameterDescription,
    _dict_to_command_args,
)

logger = build_logger("model_worker", LOGDIR + "/model_worker.log")


class PromptRequest(BaseModel):
    messages: List[ModelMessage]
    model: str
    prompt: str = None
    temperature: float = None
    max_new_tokens: int = None
    stop: str = None
    echo: bool = True


class EmbeddingsRequest(BaseModel):
    model: str
    input: List[str]


class WorkerApplyRequest(BaseModel):
    model: str
    apply_type: WorkerApplyType
    worker_type: WorkerType = WorkerType.LLM
    params: Dict = None
    apply_user: str = None


class WorkerParameterRequest(BaseModel):
    model: str
    worker_type: WorkerType = WorkerType.LLM


@dataclass
class WorkerRunData:
    worker_key: str
    worker: ModelWorker
    worker_params: ModelWorkerParameters
    model_params: ModelParameters
    stop_event: asyncio.Event
    semaphore: asyncio.Semaphore = None
    command_args: List[str] = None
    _heartbeat_future: Optional[Future] = None
    _last_heartbeat: Optional[datetime] = None


RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]


class WorkerManager(ABC):
    @abstractmethod
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Get model instances by worker type and model name"""

    @abstractmethod
    async def select_one_instanes(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instances"""

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        """Generate stream result, chat scene"""

    @abstractmethod
    async def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result"""

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input"""

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Worker apply"""

    @abstractmethod
    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Get parameter descriptions of model"""


async def _async_heartbeat_sender(
    worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc
    worker_run_data: WorkerRunData,
    heartbeat_interval,
    send_heartbeat_func: SendHeartbeatFunc,
):
    while not worker_run_data.stop_event.is_set():
        try:

@@ -122,7 +51,7 @@ async def _async_heartbeat_sender(
        except Exception as e:
            logger.warn(f"Send heartbeat func error: {str(e)}")
        finally:
            await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval)
            await asyncio.sleep(heartbeat_interval)


class LocalWorkerManager(WorkerManager):

@@ -132,6 +61,8 @@ class LocalWorkerManager(WorkerManager):
        deregister_func: DeregisterFunc = None,
        send_heartbeat_func: SendHeartbeatFunc = None,
        model_registry: ModelRegistry = None,
        host: str = None,
        port: int = None,
    ) -> None:
        self.workers: Dict[str, List[WorkerRunData]] = dict()
        self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)

@@ -139,19 +70,58 @@ class LocalWorkerManager(WorkerManager):
        self.deregister_func = deregister_func
        self.send_heartbeat_func = send_heartbeat_func
        self.model_registry = model_registry
        self.host = host
        self.port = port

        self.run_data = WorkerRunData(
            host=self.host,
            port=self.port,
            worker_key=self._worker_key(
                WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
            ),
            worker=None,
            worker_params=None,
            model_params=None,
            stop_event=asyncio.Event(),
            semaphore=None,
            command_args=None,
        )

    def _worker_key(self, worker_type: str, model_name: str) -> str:
        if isinstance(worker_type, WorkerType):
            worker_type = worker_type.value
        return f"{model_name}@{worker_type}"

    async def run_blocking_func(self, func, *args):
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    async def start(self):
        if len(self.workers) > 0:
            await self._start_all_worker(apply_req=None)
        if self.register_func:
            await self.register_func(self.run_data)
        if self.send_heartbeat_func:
            asyncio.create_task(
                _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
            )

    async def stop(self):
        if not self.run_data.stop_event.is_set():
            logger.info("Stop all workers")
            self.run_data.stop_event.clear()
            stop_tasks = []
            stop_tasks.append(self._stop_all_worker(apply_req=None))
            if self.deregister_func:
                stop_tasks.append(self.deregister_func(self.run_data))
            await asyncio.gather(*stop_tasks)

    def add_worker(
        self,
        worker: ModelWorker,
        worker_params: ModelWorkerParameters,
        embedded_mod: bool = True,
        command_args: List[str] = None,
    ):
    ) -> bool:
        if not command_args:
            import sys

@@ -179,6 +149,8 @@ class LocalWorkerManager(WorkerManager):
        model_params = worker.parse_parameters(command_args=command_args)

        worker_run_data = WorkerRunData(
            host=self.host,
            port=self.port,
            worker_key=worker_key,
            worker=worker,
            worker_params=worker_params,

@@ -187,14 +159,66 @@ class LocalWorkerManager(WorkerManager):
            semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
            command_args=command_args,
        )
        if not embedded_mod:
            exist_instances = [
                (w, p) for w, p in instances if p.host == host and p.port == port
            ]
            if not exist_instances:
                instances.append(worker_run_data)
        else:
            exist_instances = [
                ins for ins in instances if ins.host == host and ins.port == port
            ]
            if not exist_instances:
                instances.append(worker_run_data)
                return True
            else:
                # TODO Update worker
                return False

    async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
        """Start model"""
        model_name = startup_req.model
        worker_type = startup_req.worker_type
        params = startup_req.params
        logger.debug(
            f"start model, model name {model_name}, worker type {worker_type}, params: {params}"
        )
        worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict(
            params, ignore_extra_fields=True
        )
        if not worker_params.model_name:
            worker_params.model_name = model_name
        assert model_name == worker_params.model_name
        worker = _build_worker(worker_params)
        command_args = _dict_to_command_args(params)
        success = await self.run_blocking_func(
            self.add_worker, worker, worker_params, command_args
        )
        if not success:
            logger.warn(
                f"Add worker failed, worker instances is exist, worker_params: {worker_params}"
            )
            return False
        supported_types = WorkerType.values()
        if worker_type not in supported_types:
            raise ValueError(
                f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
            )
        start_apply_req = WorkerApplyRequest(
            model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type
        )
        await self.worker_apply(start_apply_req)
        return True

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
        logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
        apply_req = WorkerApplyRequest(
            model=shutdown_req.model,
            apply_type=WorkerApplyType.STOP,
            worker_type=shutdown_req.worker_type,
        )
        out = await self._stop_all_worker(apply_req)
        if out.success:
            return True
        raise Exception(out.message)

    async def supported_models(self) -> List[WorkerSupportedModel]:
        models = await self.run_blocking_func(list_supported_models)
        return [WorkerSupportedModel(host=self.host, port=self.port, models=models)]

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True

@@ -202,7 +226,7 @@ class LocalWorkerManager(WorkerManager):
        worker_key = self._worker_key(worker_type, model_name)
        return self.workers.get(worker_key)

    async def select_one_instanes(
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        worker_instances = await self.get_model_instances(

@@ -219,7 +243,7 @@ class LocalWorkerManager(WorkerManager):
        model = params.get("model")
        if not model:
            raise Exception("Model name count not be empty")
        return await self.select_one_instanes(worker_type, model, healthy_only=True)
        return await self.select_one_instance(worker_type, model, healthy_only=True)

    async def generate_stream(
        self, params: Dict, async_wrapper=None, **kwargs

@@ -262,9 +286,8 @@ class LocalWorkerManager(WorkerManager):
        if worker_run_data.worker.support_async():
            return await worker_run_data.worker.async_generate(params)
        else:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor, worker_run_data.worker.generate, params
            return await self.run_blocking_func(
                worker_run_data.worker.generate, params
            )

    async def embeddings(self, params: Dict) -> List[List[float]]:

@@ -277,9 +300,8 @@ class LocalWorkerManager(WorkerManager):
        if worker_run_data.worker.support_async():
            return await worker_run_data.worker.async_embeddings(params)
        else:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor, worker_run_data.worker.embeddings, params
            return await self.run_blocking_func(
                worker_run_data.worker.embeddings, params
            )

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:

@@ -342,8 +364,10 @@ class LocalWorkerManager(WorkerManager):
        logger.info(f"Begin start all worker, apply_req: {apply_req}")

        async def _start_worker(worker_run_data: WorkerRunData):
            worker_run_data.worker.start(
                worker_run_data.model_params, worker_run_data.command_args
            await self.run_blocking_func(
                worker_run_data.worker.start,
                worker_run_data.model_params,
                worker_run_data.command_args,
            )
            worker_run_data.stop_event.clear()
            if worker_run_data.worker_params.register and self.register_func:

@@ -355,7 +379,9 @@ class LocalWorkerManager(WorkerManager):
                ):
                    asyncio.create_task(
                        _async_heartbeat_sender(
                            worker_run_data, self.send_heartbeat_func
                            worker_run_data,
                            worker_run_data.worker_params.heartbeat_interval,
                            self.send_heartbeat_func,
                        )
                    )

@@ -371,7 +397,7 @@ class LocalWorkerManager(WorkerManager):
        start_time = time.time()

        async def _stop_worker(worker_run_data: WorkerRunData):
            worker_run_data.worker.stop()
            await self.run_blocking_func(worker_run_data.worker.stop)
            # Set stop event
            worker_run_data.stop_event.set()
            if worker_run_data._heartbeat_future:

@@ -422,63 +448,25 @@ class LocalWorkerManager(WorkerManager):
        return WorkerApplyOutput(message=message, timecost=timecost)


class RemoteWorkerManager(LocalWorkerManager):
    def __init__(self, model_registry: ModelRegistry = None) -> None:
        super().__init__(model_registry=model_registry)

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        from pilot.model.worker.remote_worker import RemoteModelWorker

        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )
        worker_instances = []
        for ins in instances:
            worker = RemoteModelWorker()
            worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
            wr = WorkerRunData(
                worker_key=ins.model_name,
                worker=worker,
                worker_params=None,
                model_params=None,
                stop_event=asyncio.Event(),
                semaphore=asyncio.Semaphore(100),  # Not limit in client
            )
            worker_instances.append(wr)
        return worker_instances

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        import httpx

        async def _remote_apply_func(worker_run_data: WorkerRunData):
            worker_addr = worker_run_data.worker.worker_addr
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    worker_addr + "/apply",
                    headers=worker_run_data.worker.headers,
                    json=apply_req.dict(),
                    timeout=worker_run_data.worker.timeout,
                )
                if response.status_code == 200:
                    output = WorkerApplyOutput(**response.json())
                    logger.info(f"worker_apply success: {output}")
                else:
                    output = WorkerApplyOutput(message=response.text)
                    logger.warn(f"worker_apply failed: {output}")
                return output

        results = await self._apply_worker(apply_req, _remote_apply_func)
        if results:
            return results[0]


class WorkerManagerAdapter(WorkerManager):
    def __init__(self, worker_manager: WorkerManager = None) -> None:
        self.worker_manager = worker_manager

    async def start(self):
        return await self.worker_manager.start()

    async def stop(self):
        return await self.worker_manager.stop()

    async def supported_models(self) -> List[WorkerSupportedModel]:
        return await self.worker_manager.supported_models()

    async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
        return await self.worker_manager.model_startup(startup_req)

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
        return await self.worker_manager.model_shutdown(shutdown_req)

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:

@@ -486,10 +474,10 @@ class WorkerManagerAdapter(WorkerManager):
            worker_type, model_name, healthy_only
        )

    async def select_one_instanes(
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        return await self.worker_manager.select_one_instanes(
        return await self.worker_manager.select_one_instance(
            worker_type, model_name, healthy_only
        )

@@ -535,37 +523,58 @@ async def api_generate_stream(request: PromptRequest):
@router.post("/worker/generate")
async def api_generate(request: PromptRequest):
    params = request.dict(exclude_none=True)
    output = await worker_manager.generate(params)
    return output
    return await worker_manager.generate(params)


@router.post("/worker/embeddings")
async def api_embeddings(request: EmbeddingsRequest):
    params = request.dict(exclude_none=True)
    output = await worker_manager.embeddings(params)
    return output
    return await worker_manager.embeddings(params)


@router.post("/worker/apply")
async def api_worker_apply(request: WorkerApplyRequest):
    output = await worker_manager.worker_apply(request)
    return output
    return await worker_manager.worker_apply(request)


@router.get("/worker/parameter/descriptions")
async def api_worker_parameter_descs(
    model: str, worker_type: str = WorkerType.LLM.value
):
    output = await worker_manager.parameter_descriptions(worker_type, model)
    return output
    return await worker_manager.parameter_descriptions(worker_type, model)


@router.get("/worker/models/supports")
async def api_supported_models():
    """Get all supported models.

    This method reads all models from the configuration file and tries to perform some basic checks on the model (like if the path exists).

    If it's a RemoteWorkerManager, this method returns the list of models supported by the entire cluster.
    """
    return await worker_manager.supported_models()


@router.post("/worker/models/startup")
async def api_model_startup(request: WorkerStartupRequest):
    """Start up a specific model."""
    return await worker_manager.model_startup(request)


@router.post("/worker/models/shutdown")
async def api_model_shutdown(request: WorkerStartupRequest):
    """Shut down a specific model."""
    return await worker_manager.model_shutdown(request)


def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
    if not app:
        app = FastAPI()
    if worker_params.standalone:
        from pilot.model.controller.controller import router as controller_router
        from pilot.model.controller.controller import initialize_controller
        from pilot.model.cluster.controller.controller import initialize_controller
        from pilot.model.cluster.controller.controller import (
            router as controller_router,
        )

        if not worker_params.controller_addr:
            worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"

@@ -577,9 +586,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):

    @app.on_event("startup")
    async def startup_event():
        asyncio.create_task(
            worker_manager.worker_manager._start_all_worker(apply_req=None)
        )
        asyncio.create_task(worker_manager.worker_manager.start())

    @app.on_event("shutdown")
    async def startup_event():
        await worker_manager.worker_manager.stop()

    return app

@@ -609,22 +620,23 @@ def _parse_worker_params(
def _create_local_model_manager(
    worker_params: ModelWorkerParameters,
) -> LocalWorkerManager:
    from pilot.utils.net_utils import _get_ip_address

    host = (
        worker_params.worker_register_host
        if worker_params.worker_register_host
        else _get_ip_address()
    )
    port = worker_params.port
    if not worker_params.register or not worker_params.controller_addr:
        logger.info(
            f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
        )
        return LocalWorkerManager()
        return LocalWorkerManager(host=host, port=port)
    else:
        from pilot.model.controller.controller import ModelRegistryClient
        from pilot.utils.net_utils import _get_ip_address
        from pilot.model.cluster.controller.controller import ModelRegistryClient

        client = ModelRegistryClient(worker_params.controller_addr)
        host = (
            worker_params.worker_register_host
            if worker_params.worker_register_host
            else _get_ip_address()
        )
        port = worker_params.port

        async def register_func(worker_run_data: WorkerRunData):
            instance = ModelInstance(

@@ -648,31 +660,33 @@ def _create_local_model_manager(
            register_func=register_func,
            deregister_func=deregister_func,
            send_heartbeat_func=send_heartbeat_func,
            host=host,
            port=port,
        )


def _start_local_worker(
    worker_manager: WorkerManagerAdapter,
    worker_params: ModelWorkerParameters,
    embedded_mod=True,
):
    from pilot.utils.module_utils import import_from_checked_string

def _build_worker(worker_params: ModelWorkerParameters):
    if worker_params.worker_class:
        from pilot.utils.module_utils import import_from_checked_string

        worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
        logger.info(
            f"Import worker class from {worker_params.worker_class} successfully"
        )
        worker: ModelWorker = worker_cls()
    else:
        from pilot.model.worker.default_worker import DefaultModelWorker
        from pilot.model.cluster.worker.default_worker import DefaultModelWorker

        worker = DefaultModelWorker()
    return worker


def _start_local_worker(
    worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
):
    worker = _build_worker(worker_params)
    worker_manager.worker_manager = _create_local_model_manager(worker_params)
    worker_manager.worker_manager.add_worker(
        worker, worker_params, embedded_mod=embedded_mod
    )
    worker_manager.worker_manager.add_worker(worker, worker_params)


def initialize_worker_manager_in_client(

@@ -713,16 +727,13 @@ def initialize_worker_manager_in_client(
        worker_params.port = local_port
        logger.info(f"Worker params: {worker_params}")
        _setup_fastapi(worker_params, app)
        _start_local_worker(worker_manager, worker_params, True)
        # loop = asyncio.get_event_loop()
        # loop.run_until_complete(
        #     worker_manager.worker_manager._start_all_worker(apply_req=None)
        # )
        _start_local_worker(worker_manager, worker_params)
    else:
        from pilot.model.controller.controller import (
            initialize_controller,
        from pilot.model.cluster.controller.controller import (
            ModelRegistryClient,
            initialize_controller,
        )
        from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager

        if not worker_params.controller_addr:
            raise ValueError("Controller can`t be None")

@@ -758,13 +769,11 @@ def run_worker_manager(
        # Run worker manager independently
        embedded_mod = False
        app = _setup_fastapi(worker_params)
        _start_local_worker(worker_manager, worker_params, embedded_mod=False)
        _start_local_worker(worker_manager, worker_params)
    else:
        _start_local_worker(worker_manager, worker_params, embedded_mod=False)
        _start_local_worker(worker_manager, worker_params)
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            worker_manager.worker_manager._start_all_worker(apply_req=None)
        )
        loop.run_until_complete(worker_manager.worker_manager.start())

    if include_router:
        app.include_router(router, prefix="/api")
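
For completeness, a hedged example of calling the new routes over HTTP (the base address is a placeholder; as shown above, the router is mounted under the /api prefix):

    curl http://127.0.0.1:8000/api/worker/models/supports
    curl -X POST http://127.0.0.1:8000/api/worker/models/startup \
        -H "Content-Type: application/json" \
        -d '{"host": "192.168.1.10", "port": 8001, "model": "vicuna-13b-v1.5", "worker_type": "llm", "params": {}}'
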
pilot/model/cluster/worker/ray_worker.py (new file, 0 lines)
pilot/model/cluster/worker/remote_manager.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from typing import Callable, Any
import httpx
import asyncio
from pilot.model.cluster.registry import ModelRegistry
from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
from pilot.model.cluster.base import *
from pilot.model.base import (
    ModelInstance,
    WorkerApplyOutput,
    WorkerSupportedModel,
)


class RemoteWorkerManager(LocalWorkerManager):
    def __init__(self, model_registry: ModelRegistry = None) -> None:
        super().__init__(model_registry=model_registry)

    async def start(self):
        pass

    async def stop(self):
        pass

    async def _fetch_from_worker(
        self,
        worker_run_data: WorkerRunData,
        endpoint: str,
        method: str = "GET",
        json: dict = None,
        params: dict = None,
        additional_headers: dict = None,
        success_handler: Callable = None,
        error_handler: Callable = None,
    ) -> Any:
        url = worker_run_data.worker.worker_addr + endpoint
        headers = {**worker_run_data.worker.headers, **(additional_headers or {})}
        timeout = worker_run_data.worker.timeout

        async with httpx.AsyncClient() as client:
            request = client.build_request(
                method,
                url,
                json=json,  # using json for data to ensure it sends as application/json
                params=params,
                headers=headers,
                timeout=timeout,
            )

            response = await client.send(request)
            if response.status_code != 200:
                if error_handler:
                    return error_handler(response)
                else:
                    error_msg = f"Request to {url} failed, error: {response.text}"
                    raise Exception(error_msg)
            if success_handler:
                return success_handler(response)
            return response.json()

    async def _apply_to_worker_manager_instances(self):
        pass

    async def supported_models(self) -> List[WorkerSupportedModel]:
        worker_instances = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )

        async def get_supported_models(worker_run_data) -> List[WorkerSupportedModel]:
            def handler(response):
                return list(WorkerSupportedModel.from_dict(m) for m in response.json())

            return await self._fetch_from_worker(
                worker_run_data, "/models/supports", success_handler=handler
            )

        models = []
        results = await asyncio.gather(
            *(get_supported_models(worker) for worker in worker_instances)
        )
        for res in results:
            models += res
        return models
async def _get_worker_service_instance(
|
||||
self, host: str = None, port: int = None
|
||||
) -> List[WorkerRunData]:
|
||||
worker_instances = await self.get_model_instances(
|
||||
WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
|
||||
)
|
||||
error_msg = f"Cound not found worker instances"
|
||||
if host and port:
|
||||
worker_instances = [
|
||||
ins for ins in worker_instances if ins.host == host and ins.port == port
|
||||
]
|
||||
error_msg = f"Cound not found worker instances for host {host} port {port}"
|
||||
if not worker_instances:
|
||||
raise Exception(error_msg)
|
||||
return worker_instances
|
||||
|
||||
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
|
||||
worker_instances = await self._get_worker_service_instance(
|
||||
startup_req.host, startup_req.port
|
||||
)
|
||||
worker_run_data = worker_instances[0]
|
||||
logger.info(f"Start model remote, startup_req: {startup_req}")
|
||||
return await self._fetch_from_worker(
|
||||
worker_run_data,
|
||||
"/models/startup",
|
||||
method="POST",
|
||||
json=startup_req.dict(),
|
||||
success_handler=lambda x: True,
|
||||
)
|
||||
|
||||
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
|
||||
worker_instances = await self._get_worker_service_instance(
|
||||
shutdown_req.host, shutdown_req.port
|
||||
)
|
||||
worker_run_data = worker_instances[0]
|
||||
logger.info(f"Shutdown model remote, shutdown_req: {shutdown_req}")
|
||||
return await self._fetch_from_worker(
|
||||
worker_run_data,
|
||||
"/models/shutdown",
|
||||
method="POST",
|
||||
json=shutdown_req.dict(),
|
||||
success_handler=lambda x: True,
|
||||
)
|
||||
|
||||
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        from pilot.model.cluster.worker.remote_worker import RemoteModelWorker

        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )
        worker_instances = []
        for ins in instances:
            worker = RemoteModelWorker()
            worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
            wr = WorkerRunData(
                host=ins.host,
                port=ins.port,
                worker_key=ins.model_name,
                worker=worker,
                worker_params=None,
                model_params=None,
                stop_event=asyncio.Event(),
                semaphore=asyncio.Semaphore(100),  # No limit on the client side
            )
            worker_instances.append(wr)
        return worker_instances

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        async def _remote_apply_func(worker_run_data: WorkerRunData):
            return await self._fetch_from_worker(
                worker_run_data,
                "/apply",
                method="POST",
                json=apply_req.dict(),
                success_handler=lambda res: WorkerApplyOutput(**res.json()),
                error_handler=lambda res: WorkerApplyOutput(
                    message=res.text, success=False
                ),
            )

        results = await self._apply_worker(apply_req, _remote_apply_func)
        if results:
            return results[0]
@@ -3,7 +3,7 @@ from typing import Dict, Iterator, List
import logging
from pilot.model.base import ModelOutput
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.model.cluster.worker_base import ModelWorker


class RemoteModelWorker(ModelWorker):
@@ -2,14 +2,18 @@
# -*- coding:utf-8 -*-

import traceback
from pathlib import Path
from queue import Queue
from threading import Thread
import transformers

from typing import List, Optional
from typing import List, Optional, Dict
import cachetools

from pilot.configs.config import Config
from pilot.model.base import Message
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
from pilot.model.base import Message, SupportedModel
from pilot.utils.parameter_utils import _get_parameter_descriptions


def create_chat_completion(
@@ -128,3 +132,49 @@ def is_partial_stop(output: str, stop_str: str):
        if stop_str.startswith(output[-i:]):
            return True
    return False


@cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=60 * 5))
def list_supported_models():
    from pilot.model.parameter import WorkerType

    models = _list_supported_models(WorkerType.LLM.value, LLM_MODEL_CONFIG)
    models += _list_supported_models(WorkerType.TEXT2VEC.value, EMBEDDING_MODEL_CONFIG)
    return models


def _list_supported_models(
    worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
    from pilot.model.adapter import get_llm_model_adapter
    from pilot.model.parameter import ModelParameters
    from pilot.model.loader import _get_model_real_path

    ret = []
    for model_name, model_path in model_config.items():
        model_path = _get_model_real_path(model_name, model_path)
        model = SupportedModel(
            model=model_name,
            path=model_path,
            worker_type=worker_type,
            path_exist=False,
            proxy=False,
            enabled=False,
            params=None,
        )
        if "proxyllm" in model_name:
            model.proxy = True
        else:
            path = Path(model_path)
            model.path_exist = path.exists()
        param_cls = None
        try:
            llm_adapter = get_llm_model_adapter(model_name, model_path)
            param_cls = llm_adapter.model_param_class()
            model.enabled = True
            params = _get_parameter_descriptions(param_cls)
            model.params = params
        except Exception:
            pass
        ret.append(model)
    return ret
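`list_supported_models` is memoized with a five-minute TTL cache, so repeated calls from the HTTP API do not re-scan model directories or re-resolve adapters. The same caching pattern in isolation:

    import time
    import cachetools

    @cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=60 * 5))
    def expensive_listing():
        # Recomputed at most once per five minutes for the same arguments.
        return time.time()

    first = expensive_listing()
    assert expensive_listing() == first  # second call is served from the cache
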
@@ -353,5 +353,6 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
    from pilot.model.proxy.llms.proxy_model import ProxyModel

    logger.info("Load proxyllm")
    model = ProxyModel(model_params)
    return model, model
@@ -33,9 +33,13 @@ class ModelControllerParameters(BaseParameters):


@dataclass
class ModelWorkerParameters(BaseParameters):
class BaseModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
    model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})


@dataclass
class ModelWorkerParameters(BaseModelParameters):
    worker_type: Optional[str] = field(
        default=None,
        metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
@@ -84,9 +88,7 @@ class ModelWorkerParameters(BaseParameters):


@dataclass
class EmbeddingModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
    model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
class EmbeddingModelParameters(BaseModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
@@ -114,12 +116,6 @@ class EmbeddingModelParameters(BaseParameters):
        return kwargs


@dataclass
class BaseModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
    model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})


@dataclass
class ModelParameters(BaseModelParameters):
    device: Optional[str] = field(
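The refactor above hoists the duplicated `model_name`/`model_path` fields into `BaseModelParameters`, which `ModelWorkerParameters`, `EmbeddingModelParameters`, and `ModelParameters` now all inherit. A minimal illustration of the pattern with a hypothetical subclass:

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class BaseModelParameters:
        model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
        model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})

    @dataclass
    class DemoWorkerParameters(BaseModelParameters):  # hypothetical subclass
        worker_type: Optional[str] = field(default=None, metadata={"help": "Worker type"})

    p = DemoWorkerParameters(model_name="vicuna-13b-v1.5", model_path="/models/vicuna-13b-v1.5")
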
@@ -139,7 +139,7 @@ class BaseChat(ABC):
        logger.info(f"Request: \n{payload}")
        ai_response_text = ""
        try:
            from pilot.model.worker.manager import worker_manager
            from pilot.model.cluster import worker_manager

            async for output in worker_manager.generate_stream(payload):
                yield output
@@ -157,7 +157,7 @@ class BaseChat(ABC):
        logger.info(f"Request: \n{payload}")
        ai_response_text = ""
        try:
            from pilot.model.worker.manager import worker_manager
            from pilot.model.cluster import worker_manager

            model_output = await worker_manager.generate(payload)
@@ -6,7 +6,7 @@ from pilot.configs.config import Config

from pilot.configs.model_config import (
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    EMBEDDING_MODEL_CONFIG,
)

from pilot.scene.chat_knowledge.v1.prompt import prompt
@@ -48,7 +48,7 @@ class ChatKnowledge(BaseChat):
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        self.knowledge_embedding_client = EmbeddingEngine(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
        )
@@ -23,7 +23,7 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level

static_file_path = os.path.join(os.getcwd(), "server/static")
@@ -7,7 +7,10 @@ from fastapi import APIRouter, File, UploadFile, Form
from langchain.embeddings import HuggingFaceEmbeddings

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
)

from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -29,7 +32,9 @@ CFG = Config()
router = APIRouter()


embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL])
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)

knowledge_space_service = KnowledgeService()
@@ -5,7 +5,10 @@ from datetime import datetime
from pilot.vector_store.connector import VectorStoreConnector

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
    DocumentChunkEntity,
@@ -204,7 +207,7 @@ class KnowledgeService:
        client = EmbeddingEngine(
            knowledge_source=doc.content,
            knowledge_type=doc.doc_type.upper(),
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config={
                "vector_store_name": space_name,
                "vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -341,7 +344,7 @@ class KnowledgeService:
                "topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
                "recall_score": 0.0,
                "recall_type": "TopK",
                "model": LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
                "model": EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
                "chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
                "chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
            },
@@ -9,7 +9,7 @@ sys.path.append(ROOT_PATH)

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.model.worker.manager import run_worker_manager
from pilot.model.cluster import run_worker_manager

CFG = Config()
@@ -5,7 +5,7 @@ from pilot.common.schema import DBType
from pilot.configs.config import Config
from pilot.configs.model_config import (
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    EMBEDDING_MODEL_CONFIG,
    LOGDIR,
)
from pilot.scene.base import ChatScene
@@ -36,7 +36,7 @@ class DBSummaryClient:

        db_summary_client = RdbmsSummary(dbname, db_type)
        embeddings = HuggingFaceEmbeddings(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        vector_store_config = {
            "vector_store_name": dbname + "_summary",
@@ -90,7 +90,7 @@ class DBSummaryClient:
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        knowledge_embedding_client = EmbeddingEngine(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
        )
        table_docs = knowledge_embedding_client.similar_search(query, topk)
@@ -108,7 +108,7 @@ class DBSummaryClient:
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        knowledge_embedding_client = EmbeddingEngine(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
        )
        if CFG.SUMMARY_CONFIG == "FAST":
@@ -134,7 +134,7 @@ class DBSummaryClient:
            "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        }
        knowledge_embedding_client = EmbeddingEngine(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
        )
        table_summery = knowledge_embedding_client.similar_search(query, 1)
@@ -4,6 +4,7 @@ import subprocess
from typing import List, Dict
import psutil
import platform
from functools import lru_cache


def _get_abspath_of_current_command(command_path: str):
@@ -137,6 +138,7 @@ def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
    return ports


@lru_cache()
def _detect_controller_address() -> str:
    controller_addr = os.getenv("CONTROLLER_ADDRESS")
    if controller_addr:
@@ -2,7 +2,7 @@ from typing import Type
from importlib import import_module


def import_from_string(module_path: str):
def import_from_string(module_path: str, ignore_import_error: bool = False):
    try:
        module_path, class_name = module_path.rsplit(".", 1)
    except ValueError:
@@ -12,6 +12,8 @@ def import_from_string(module_path: str):
    try:
        return getattr(module, class_name)
    except AttributeError:
        if ignore_import_error:
            return None
        raise ImportError(
            f'Module "{module_path}" does not define a "{class_name}" attribute/class'
        )
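With `ignore_import_error=True`, a dotted path whose attribute is missing now yields `None` instead of raising, which `_build_parameter_class` in `parameter_utils` relies on to fall back to a dynamically built class. A quick sketch:

    # Resolves "package.module.Name" to the attribute itself.
    cls = import_from_string("collections.OrderedDict")

    # Missing attribute: returns None instead of raising ImportError.
    missing = import_from_string("collections.NoSuchThing", ignore_import_error=True)
    assert cls is not None and missing is None
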
@@ -1,21 +1,50 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING
from dataclasses import dataclass, fields, MISSING, asdict, field
from typing import Any, List, Optional, Type, Union, Callable, Dict
from collections import OrderedDict


@dataclass
class ParameterDescription:
    param_class: str
    param_name: str
    param_type: str
    default_value: Optional[Any]
    description: str
    valid_values: Optional[List[Any]]
    ext_metadata: Dict

@dataclass
class BaseParameters:
    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = False
    ) -> "BaseParameters":
        """Create an instance of the dataclass from a dictionary.

        Args:
            data: A dictionary containing values for the dataclass fields.
            ignore_extra_fields: If True, any extra fields in the data dictionary that are
                not part of the dataclass will be ignored.
                If False, extra fields will raise an error. Defaults to False.
        Returns:
            An instance of the dataclass with values populated from the given dictionary.

        Raises:
            TypeError: If `ignore_extra_fields` is False and there are fields in the
                dictionary that aren't present in the dataclass.
        """
        all_field_names = {f.name for f in fields(cls)}
        if ignore_extra_fields:
            data = {key: value for key, value in data.items() if key in all_field_names}
        else:
            extra_fields = set(data.keys()) - all_field_names
            if extra_fields:
                raise TypeError(f"Unexpected fields: {', '.join(extra_fields)}")
        return cls(**data)

    def update_from(self, source: Union["BaseParameters", dict]) -> bool:
        """
        Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
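`from_dict` gives every parameters dataclass a tolerant constructor, which is what lets worker configuration arrive as loose JSON over the new HTTP APIs. A sketch with a hypothetical subclass:

    from dataclasses import dataclass

    @dataclass
    class DemoParams(BaseParameters):  # hypothetical dataclass for illustration
        model_name: str = "vicuna-13b-v1.5"
        port: int = 8000

    # Extra keys are silently dropped with ignore_extra_fields=True ...
    p = DemoParams.from_dict({"model_name": "chatglm2-6b", "unknown": 1}, ignore_extra_fields=True)
    # ... and raise TypeError with the default ignore_extra_fields=False.
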
@@ -68,6 +97,35 @@ class BaseParameters:
            )
        return "\n".join(parameters)

    def to_command_args(self, args_prefix: str = "--") -> List[str]:
        """Convert the fields of the dataclass to a list of command line arguments.

        Args:
            args_prefix: args prefix
        Returns:
            A list of strings where each field is represented by two items:
            one for the field name prefixed by args_prefix, and one for its value.
        """
        return _dict_to_command_args(asdict(self), args_prefix=args_prefix)


def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
    """Convert dict to a list of command line arguments

    Args:
        obj: dict
    Returns:
        A list of strings where each field is represented by two items:
        one for the field name prefixed by args_prefix, and one for its value.
    """
    args = []
    for key, value in obj.items():
        if value is None:
            continue
        args.append(f"{args_prefix}{key}")
        args.append(str(value))
    return args

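`to_command_args` lets one process re-spawn another with an equivalent configuration by flattening a parameters object into `--key value` pairs; `None` values are skipped entirely rather than serialized. For example:

    args = _dict_to_command_args({"model_name": "chatglm2-6b", "device": None, "port": 8000})
    # -> ["--model_name", "chatglm2-6b", "--port", "8000"]
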
def _get_simple_privacy_field_value(obj, field_info):
    """Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
@@ -81,9 +139,9 @@ def _get_simple_privacy_field_value(obj, field_info):
    - str: if length > 5, masks the middle part and returns first and last char;
      otherwise, returns "******"

    Parameters:
    - obj: The dataclass instance.
    - field_info: A Field object that contains information about the dataclass field.
    Args:
        obj: The dataclass instance.
        field_info: A Field object that contains information about the dataclass field.

    Returns:
        The original or modified value of the field based on the privacy rules.
@@ -202,11 +260,42 @@ class EnvArgumentParser:
            parser.add_argument(f"--{field.name}", **argument_kwargs)
        return parser

    @staticmethod
    def _create_click_option_from_field(field_name: str, field: Type, is_func=True):
        import click

        help_text = field.metadata.get("help", "")
        valid_values = field.metadata.get("valid_values", None)
        cli_params = {
            "default": None if field.default is MISSING else field.default,
            "help": help_text,
            "show_default": True,
            "required": field.default is MISSING,
        }
        if valid_values:
            cli_params["type"] = click.Choice(valid_values)
        real_type = EnvArgumentParser._get_argparse_type(field.type)
        if real_type is int:
            cli_params["type"] = click.INT
        elif real_type is float:
            cli_params["type"] = click.FLOAT
        elif real_type is str:
            cli_params["type"] = click.STRING
        elif real_type is bool:
            cli_params["is_flag"] = True
        name = f"--{field_name}"
        if is_func:
            return click.option(
                name,
                **cli_params,
            )
        else:
            return click.Option([name], **cli_params)

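`_create_click_option_from_field` is now the single translation point from dataclass field metadata to a click option: `is_func=True` returns the usual decorator, while `is_func=False` returns a raw `click.Option` for lazy registration (used below). A hedged sketch of the decorator path, reusing the hypothetical `DemoParams` idea from earlier:

    import click
    from dataclasses import fields

    port_field = next(f for f in fields(DemoParams) if f.name == "port")

    @click.command()
    @EnvArgumentParser._create_click_option_from_field("port", port_field)
    def serve(port):
        click.echo(f"port={port}")
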
    @staticmethod
    def create_click_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
    ):
        import click
        import functools
        from collections import OrderedDict

@@ -222,30 +311,8 @@ class EnvArgumentParser:

        def decorator(func):
            for field_name, field in reversed(combined_fields.items()):
                help_text = field.metadata.get("help", "")
                valid_values = field.metadata.get("valid_values", None)
                cli_params = {
                    "default": None if field.default is MISSING else field.default,
                    "help": help_text,
                    "show_default": True,
                    "required": field.default is MISSING,
                }
                if valid_values:
                    cli_params["type"] = click.Choice(valid_values)
                real_type = EnvArgumentParser._get_argparse_type(field.type)
                if real_type is int:
                    cli_params["type"] = click.INT
                elif real_type is float:
                    cli_params["type"] = click.FLOAT
                elif real_type is str:
                    cli_params["type"] = click.STRING
                elif real_type is bool:
                    cli_params["is_flag"] = True

                option_decorator = click.option(
                    # f"--{field_name.replace('_', '-')}", **cli_params
                    f"--{field_name}",
                    **cli_params,
                option_decorator = EnvArgumentParser._create_click_option_from_field(
                    field_name, field
                )
                func = option_decorator(func)

@@ -257,6 +324,23 @@ class EnvArgumentParser:

        return decorator

    @staticmethod
    def _create_raw_click_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
    ):
        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
        )
        options = []

        for field_name, field in reversed(combined_fields.items()):
            options.append(
                EnvArgumentParser._create_click_option_from_field(
                    field_name, field, is_func=False
                )
            )
        return options

    @staticmethod
    def create_argparse_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
@@ -366,21 +450,70 @@ def _merge_dataclass_types(
    return combined_fields


def _type_str_to_python_type(type_str: str) -> Type:
    type_mapping: Dict[str, Type] = {
        "int": int,
        "float": float,
        "bool": bool,
        "str": str,
    }
    return type_mapping.get(type_str, str)

def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
    descriptions = []
    for field in fields(dataclass_type):
        ext_metadata = {
            k: v for k, v in field.metadata.items() if k not in ["help", "valid_values"]
        }

        descriptions.append(
            ParameterDescription(
                param_class=f"{dataclass_type.__module__}.{dataclass_type.__name__}",
                param_name=field.name,
                param_type=EnvArgumentParser._get_argparse_type_str(field.type),
                description=field.metadata.get("help", None),
                default_value=field.default,  # TODO handle dataclasses._MISSING_TYPE
                default_value=field.default if field.default != MISSING else None,
                valid_values=field.metadata.get("valid_values", None),
                ext_metadata=ext_metadata,
            )
        )
    return descriptions

def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
    from pilot.utils.module_utils import import_from_string

    if not desc:
        raise ValueError("Parameter descriptions can't be empty")
    param_class_str = desc[0].param_class
    param_class = import_from_string(param_class_str, ignore_import_error=True)
    if param_class:
        return param_class
    module_name, _, class_name = param_class_str.rpartition(".")

    fields_dict = {}  # This will store field names and their default values or field()
    annotations = {}  # This will store the type annotations for the fields

    for d in desc:
        metadata = d.ext_metadata if d.ext_metadata else {}
        metadata["help"] = d.description
        metadata["valid_values"] = d.valid_values

        annotations[d.param_name] = _type_str_to_python_type(
            d.param_type
        )  # Set type annotation
        fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)

    # Create the new class. Note the setting of __annotations__ for type hints
    new_class = type(
        class_name, (object,), {**fields_dict, "__annotations__": annotations}
    )
    result_class = dataclass(new_class)  # Make it a dataclass

    return result_class

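Together, `_get_parameter_descriptions` and `_build_parameter_class` form a round trip: a worker serializes its parameter schema into `ParameterDescription` records (for example for the `/models/supports` API), and a client rebuilds an equivalent dataclass, importing the real class when it is available locally and synthesizing one with `type()` when it is not. A sketch:

    descs = _get_parameter_descriptions(ModelWorkerParameters)
    # On the client side, possibly in another process:
    param_cls = _build_parameter_class(descs)  # real class if importable, else synthesized
    params = param_cls(model_name="vicuna-13b-v1.5", model_path="/models/vicuna-13b-v1.5")
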
class _SimpleArgParser:
    def __init__(self, *args):
        self.params = {arg.replace("_", "-"): None for arg in args}
@@ -422,3 +555,24 @@ class _SimpleArgParser:
        return "\n".join(
            [f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
        )

def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
    import click

    class LazyCommand(click.Command):
        def __init__(self, *args, **kwargs):
            super(LazyCommand, self).__init__(*args, **kwargs)
            self.dynamic_params_added = False

        def get_params(self, ctx):
            if ctx and not self.dynamic_params_added:
                dynamic_params = EnvArgumentParser._create_raw_click_option(
                    *dataclass_types, _dynamic_factory=_dynamic_factory
                )
                for param in reversed(dynamic_params):
                    self.params.append(param)
                self.dynamic_params_added = True
            return super(LazyCommand, self).get_params(ctx)

    return LazyCommand
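`build_lazy_click_command` defers option generation until click first asks for the command's params, so importing the CLI stays cheap even when the dynamic factory must load model adapters to discover parameters. A hedged wiring sketch (the group and command names are illustrative):

    import click

    @click.group()
    def model_cli():
        pass

    @model_cli.command(cls=build_lazy_click_command(ModelWorkerParameters))
    def start(**kwargs):
        """Start a model worker; its options are discovered on first use."""
        click.echo(kwargs)
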
@@ -29,10 +29,10 @@ def knownledge_tovec_st(filename):
    """Use sentence transformers to embed the document.
    https://github.com/UKPLab/sentence-transformers
    """
    from pilot.configs.model_config import LLM_MODEL_CONFIG
    from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG

    embeddings = HuggingFaceEmbeddings(
        model_name=LLM_MODEL_CONFIG["sentence-transforms"]
        model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
    )

    with open(filename, "r") as f:
@@ -57,10 +57,10 @@ def load_knownledge_from_doc():
        "Not Exists Local DataSets, We will answers the Question use model default."
    )

    from pilot.configs.model_config import LLM_MODEL_CONFIG
    from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG

    embeddings = HuggingFaceEmbeddings(
        model_name=LLM_MODEL_CONFIG["sentence-transforms"]
        model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
    )

    files = os.listdir(DATASETS_DIR)
@@ -16,7 +16,7 @@ from langchain.vectorstores import Chroma

from pilot.configs.model_config import (
    DATASETS_DIR,
    LLM_MODEL_CONFIG,
    EMBEDDING_MODEL_CONFIG,
    VECTORE_PATH,
)

@@ -39,7 +39,7 @@ class KnownLedge2Vector:
    """

    embeddings: object = None
    model_name = LLM_MODEL_CONFIG["sentence-transforms"]
    model_name = EMBEDDING_MODEL_CONFIG["sentence-transforms"]

    def __init__(self, model_name=None) -> None:
        if not model_name:
@@ -79,4 +79,5 @@ duckdb
duckdb-engine

# cli
prettytable
cachetools