mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 19:40:13 +00:00
feat(model): supports the deployment of multiple models through the API and add the corresponding command line interface (#570)
- Split `LLM_MODEL_CONFIG` into `LLM_MODEL_CONFIG` and `EMBEDDING_MODEL_CONFIG`. - New HTTP API to obtain the list of models and configuration parameters supported by the current cluster. - New HTTP API to launch models on a specified machine. - The command line supports the above HTTP APIs.
This commit is contained in:
@@ -41,25 +41,12 @@ LLM_MODEL_CONFIG = {
|
|||||||
# (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
|
# (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
|
||||||
"vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
|
"vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
|
||||||
"vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
|
"vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
|
||||||
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
|
||||||
# https://huggingface.co/moka-ai/m3e-large
|
|
||||||
"m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
|
|
||||||
# https://huggingface.co/moka-ai/m3e-base
|
|
||||||
"m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
|
|
||||||
# https://huggingface.co/BAAI/bge-large-en
|
|
||||||
"bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
|
|
||||||
"bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
|
|
||||||
# https://huggingface.co/BAAI/bge-large-zh
|
|
||||||
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
|
|
||||||
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
|
|
||||||
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
|
|
||||||
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
|
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
|
||||||
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
|
"codet5p-2b": os.path.join(MODEL_PATH, "codet5p-2b"),
|
||||||
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
|
"chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
|
||||||
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
"chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
|
||||||
"chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
|
"chatglm2-6b": os.path.join(MODEL_PATH, "chatglm2-6b"),
|
||||||
"chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
|
"chatglm2-6b-int4": os.path.join(MODEL_PATH, "chatglm2-6b-int4"),
|
||||||
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
|
|
||||||
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
|
||||||
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
|
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
|
||||||
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
|
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
|
||||||
@@ -84,6 +71,22 @@ LLM_MODEL_CONFIG = {
|
|||||||
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EMBEDDING_MODEL_CONFIG = {
|
||||||
|
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
|
"text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
|
||||||
|
# https://huggingface.co/moka-ai/m3e-base
|
||||||
|
"m3e-base": os.path.join(MODEL_PATH, "m3e-base"),
|
||||||
|
# https://huggingface.co/moka-ai/m3e-large
|
||||||
|
"m3e-large": os.path.join(MODEL_PATH, "m3e-large"),
|
||||||
|
# https://huggingface.co/BAAI/bge-large-en
|
||||||
|
"bge-large-en": os.path.join(MODEL_PATH, "bge-large-en"),
|
||||||
|
"bge-base-en": os.path.join(MODEL_PATH, "bge-base-en"),
|
||||||
|
# https://huggingface.co/BAAI/bge-large-zh
|
||||||
|
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
|
||||||
|
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
|
||||||
|
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
|
||||||
|
}
|
||||||
|
|
||||||
# Load model config
|
# Load model config
|
||||||
ISDEBUG = False
|
ISDEBUG = False
|
||||||
|
|
||||||
|
@@ -108,6 +108,17 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
|
|||||||
return [param_class]
|
return [param_class]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_model_param_class(
    model_name: str, model_path: str
) -> "Type[ModelParameters]":
    """Resolve the parameter class used to configure a model.

    Looks up the LLM adapter for the given model and asks it for its parameter
    class. Any failure during lookup is logged as a warning and the generic
    ``ModelParameters`` class is returned instead, so callers always receive a
    usable class.

    Args:
        model_name: Name of the model passed to the adapter lookup.
        model_path: Path of the model passed to the adapter lookup.

    Returns:
        The class returned by the adapter's ``model_param_class()``, or
        ``ModelParameters`` when the lookup fails.
    """
    try:
        llm_adapter = get_llm_model_adapter(model_name, model_path)
        return llm_adapter.model_param_class()
    except Exception as e:
        # Deliberate best-effort: fall back to the generic class instead of
        # propagating. `logger.warning` replaces the deprecated `logger.warn`.
        logger.warning(
            f"Parse model parameters with model name {model_name} and model {model_path} failed {str(e)}, return `ModelParameters`"
        )
    return ModelParameters
|
||||||
|
|
||||||
|
|
||||||
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
|
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2,9 +2,10 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypedDict, Optional, Dict
|
from typing import TypedDict, Optional, Dict, List
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pilot.utils.parameter_utils import ParameterDescription
|
||||||
|
|
||||||
|
|
||||||
class Message(TypedDict):
|
class Message(TypedDict):
|
||||||
@@ -46,5 +47,40 @@ class ModelOutput:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class WorkerApplyOutput:
|
class WorkerApplyOutput:
|
||||||
message: str
|
message: str
|
||||||
|
success: Optional[bool] = True
|
||||||
# The seconds cost to apply some action to worker instances
|
# The seconds cost to apply some action to worker instances
|
||||||
timecost: Optional[int] = -1
|
timecost: Optional[int] = -1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SupportedModel:
    """Description of one model that a worker reports as supported."""

    # Model name.
    model: str
    # Model path.
    path: str
    # Worker type, as a plain string.
    worker_type: str
    # NOTE(review): presumably whether `path` exists on the worker host - confirm.
    path_exist: bool
    # NOTE(review): presumably marks proxy-backed models - confirm.
    proxy: bool
    enabled: bool
    # Parameter descriptions for this model.
    params: List[ParameterDescription]

    @classmethod
    def from_dict(cls, model_data: Dict) -> "SupportedModel":
        """Build a ``SupportedModel`` from a plain dictionary.

        Each entry under ``"params"`` is expanded into a
        ``ParameterDescription``. A missing or ``None`` ``"params"`` value
        becomes an empty list.

        Args:
            model_data: Mapping whose keys match the dataclass fields. It is
                not modified; the conversion works on a shallow copy.

        Returns:
            The populated ``SupportedModel`` instance.
        """
        raw_params = model_data.get("params") or []
        # Copy first so the caller's dictionary keeps its original "params".
        data = dict(model_data)
        data["params"] = [ParameterDescription(**param) for param in raw_params]
        return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WorkerSupportedModel:
    """The models supported by the worker at one host/port."""

    host: str
    port: int
    models: List[SupportedModel]

    @classmethod
    def from_dict(cls, worker_data: Dict) -> "WorkerSupportedModel":
        """Build a ``WorkerSupportedModel`` from a plain dictionary.

        Args:
            worker_data: Mapping whose keys match the dataclass fields. The
                ``"models"`` entry must be present and hold an iterable of
                dictionaries accepted by ``SupportedModel.from_dict``. The
                ``"models"`` entry of ``worker_data`` is not replaced; the
                conversion works on a shallow copy.

        Returns:
            The populated ``WorkerSupportedModel`` instance.

        Raises:
            KeyError: If ``worker_data`` has no ``"models"`` key.
        """
        # Copy first so the caller's dictionary keeps its raw "models" list.
        data = dict(worker_data)
        data["models"] = [
            SupportedModel.from_dict(model_data)
            for model_data in worker_data["models"]
        ]
        return cls(**data)
|
||||||
|
@@ -2,19 +2,27 @@ import click
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Callable, List, Type
|
from typing import Callable, List, Type, Optional
|
||||||
|
|
||||||
from pilot.model.controller.controller import ModelRegistryClient
|
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.model.base import WorkerApplyType
|
from pilot.model.base import WorkerApplyType
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
ModelControllerParameters,
|
ModelControllerParameters,
|
||||||
ModelWorkerParameters,
|
ModelWorkerParameters,
|
||||||
ModelParameters,
|
ModelParameters,
|
||||||
|
BaseParameters,
|
||||||
)
|
)
|
||||||
from pilot.utils import get_or_create_event_loop
|
from pilot.utils import get_or_create_event_loop
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import (
|
||||||
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service
|
EnvArgumentParser,
|
||||||
|
_build_parameter_class,
|
||||||
|
build_lazy_click_command,
|
||||||
|
)
|
||||||
|
from pilot.utils.command_utils import (
|
||||||
|
_run_current_with_daemon,
|
||||||
|
_stop_service,
|
||||||
|
_detect_controller_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
||||||
@@ -22,6 +30,14 @@ MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
|||||||
logger = logging.getLogger("dbgpt_cli")
|
logger = logging.getLogger("dbgpt_cli")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_worker_manager(address: str):
    """Return a ``RemoteWorkerManager`` bound to the controller at ``address``.

    The cluster classes are imported inside the function, so they are only
    loaded when a command actually needs a worker manager.
    """
    from pilot.model.cluster import RemoteWorkerManager, ModelRegistryClient

    return RemoteWorkerManager(ModelRegistryClient(address))
|
||||||
|
|
||||||
|
|
||||||
@click.group("model")
|
@click.group("model")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--address",
|
"--address",
|
||||||
@@ -38,8 +54,6 @@ def model_cli_group(address: str):
|
|||||||
"""Clients that manage model serving"""
|
"""Clients that manage model serving"""
|
||||||
global MODEL_CONTROLLER_ADDRESS
|
global MODEL_CONTROLLER_ADDRESS
|
||||||
if not address:
|
if not address:
|
||||||
from pilot.utils.command_utils import _detect_controller_address
|
|
||||||
|
|
||||||
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
|
MODEL_CONTROLLER_ADDRESS = _detect_controller_address()
|
||||||
else:
|
else:
|
||||||
MODEL_CONTROLLER_ADDRESS = address
|
MODEL_CONTROLLER_ADDRESS = address
|
||||||
@@ -55,6 +69,7 @@ def model_cli_group(address: str):
|
|||||||
def list(model_name: str, model_type: str):
|
def list(model_name: str, model_type: str):
|
||||||
"""List model instances"""
|
"""List model instances"""
|
||||||
from prettytable import PrettyTable
|
from prettytable import PrettyTable
|
||||||
|
from pilot.model.cluster import ModelRegistryClient
|
||||||
|
|
||||||
loop = get_or_create_event_loop()
|
loop = get_or_create_event_loop()
|
||||||
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
|
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
|
||||||
@@ -90,7 +105,7 @@ def list(model_name: str, model_type: str):
|
|||||||
instance.port,
|
instance.port,
|
||||||
instance.healthy,
|
instance.healthy,
|
||||||
instance.enabled,
|
instance.enabled,
|
||||||
instance.prompt_template,
|
instance.prompt_template if instance.prompt_template else "",
|
||||||
instance.last_heartbeat,
|
instance.last_heartbeat,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -122,18 +137,156 @@ def add_model_options(func):
|
|||||||
|
|
||||||
@model_cli_group.command()
|
@model_cli_group.command()
|
||||||
@add_model_options
|
@add_model_options
|
||||||
def stop(model_name: str, model_type: str):
|
@click.option(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help=("The remote host to stop model"),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help=("The remote port to stop model"),
|
||||||
|
)
|
||||||
|
def stop(model_name: str, model_type: str, host: str, port: int):
|
||||||
"""Stop model instances"""
|
"""Stop model instances"""
|
||||||
worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
|
from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager
|
||||||
|
|
||||||
|
worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)
|
||||||
@model_cli_group.command()
|
req = WorkerStartupRequest(
|
||||||
@add_model_options
|
host=host,
|
||||||
def start(model_name: str, model_type: str):
|
port=port,
|
||||||
"""Start model instances"""
|
worker_type=model_type,
|
||||||
worker_apply(
|
model=model_name,
|
||||||
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
|
params={},
|
||||||
)
|
)
|
||||||
|
loop = get_or_create_event_loop()
|
||||||
|
res = loop.run_until_complete(worker_manager.model_shutdown(req))
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
|
def _remote_model_dynamic_factory() -> Callable[[None], List[Type]]:
    """Build the parameter dataclasses for the ``start`` command at runtime.

    The command line is pre-parsed for ``model_name``, ``address``, ``host``
    and ``port``. The controller is then asked which models its workers
    support, and a ``RemoteModelWorkerParameters`` dataclass is built whose
    defaults and valid values reflect that answer.

    Returns:
        A list of dataclass types. It always starts with the remote worker
        parameter class. The model-specific parameter class is appended when
        one could be determined.

    Raises:
        ValueError: If ``model_name`` was given but no matching model with
            parameters is reported by the selected workers.
    """
    from pilot.utils.parameter_utils import _SimpleArgParser
    from pilot.model.cluster import RemoteWorkerManager
    from pilot.model.parameter import WorkerType
    from dataclasses import dataclass, field

    pre_args = _SimpleArgParser("model_name", "address", "host", "port")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    address = pre_args.get("address")
    host = pre_args.get("host")
    port = pre_args.get("port")
    if port:
        port = int(port)

    if not address:
        address = _detect_controller_address()

    worker_manager: RemoteWorkerManager = _get_worker_manager(address)
    loop = get_or_create_event_loop()
    models = loop.run_until_complete(worker_manager.supported_models())

    if not models:
        # Nothing is reported by the cluster: fall back to a bare class.
        fields_dict = {}
        fields_dict["model_name"] = (
            str,
            field(default=None, metadata={"help": "The model name to deploy"}),
        )
        fields_dict["host"] = (
            str,
            field(default=None, metadata={"help": "The remote host to deploy model"}),
        )
        fields_dict["port"] = (
            int,
            field(default=None, metadata={"help": "The remote port to deploy model"}),
        )
        # NOTE(review): the attributes above carry no annotations, so
        # `dataclass()` registers no fields for this class - confirm whether
        # `dataclasses.make_dataclass` was intended here.
        result_class = dataclass(
            type("RemoteModelWorkerParameters", (object,), fields_dict)
        )
        return [result_class]

    valid_models = []
    valid_model_cls = []
    for model in models:
        # Honour the optional host/port filter from the command line.
        if host and host != model.host:
            continue
        if port and port != model.port:
            continue
        valid_models += [m.model for m in model.models]
        valid_model_cls += [
            (m, _build_parameter_class(m.params)) for m in model.models if m.params
        ]

    # Default to the first model that reports parameters. There may be none
    # (filter matched nothing, or no model has params), so do not index blindly.
    real_params_cls = valid_model_cls[0][1] if valid_model_cls else None
    real_path = None
    real_worker_type = "llm"
    if model_name:
        params_cls_list = [m for m in valid_model_cls if m[0].model == model_name]
        if not params_cls_list:
            raise ValueError(f"Not supported model with model name: {model_name}")
        real_model, real_params_cls = params_cls_list[0]
        real_path = real_model.path
        real_worker_type = real_model.worker_type

    @dataclass
    class RemoteModelWorkerParameters(BaseParameters):
        model_name: str = field(
            metadata={"valid_values": valid_models, "help": "The model name to deploy"}
        )
        model_path: Optional[str] = field(
            default=real_path, metadata={"help": "The model path to deploy"}
        )
        host: Optional[str] = field(
            default=models[0].host,
            metadata={
                "valid_values": [model.host for model in models],
                "help": "The remote host to deploy model",
            },
        )

        port: Optional[int] = field(
            default=models[0].port,
            metadata={
                "valid_values": [model.port for model in models],
                "help": "The remote port to deploy model",
            },
        )
        worker_type: Optional[str] = field(
            default=real_worker_type,
            metadata={
                "valid_values": WorkerType.values(),
                "help": "Worker type",
            },
        )

    result = [RemoteModelWorkerParameters]
    if real_params_cls is not None:
        result.append(real_params_cls)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
@model_cli_group.command(
    cls=build_lazy_click_command(_dynamic_factory=_remote_model_dynamic_factory)
)
def start(**kwargs):
    """Start model instances"""
    from pilot.model.cluster import WorkerStartupRequest, RemoteWorkerManager

    worker_manager: RemoteWorkerManager = _get_worker_manager(MODEL_CONTROLLER_ADDRESS)

    # Routing options leave kwargs; whatever remains (model_name included)
    # is forwarded as the model parameters.
    target_host = kwargs.pop("host")
    target_port = kwargs.pop("port")
    target_worker_type = kwargs.pop("worker_type")

    req = WorkerStartupRequest(
        host=target_host,
        port=target_port,
        worker_type=target_worker_type,
        model=kwargs["model_name"],
        params={},
    )
    req.params = kwargs

    event_loop = get_or_create_event_loop()
    print(event_loop.run_until_complete(worker_manager.model_startup(req)))
|
||||||
|
|
||||||
|
|
||||||
@model_cli_group.command()
|
@model_cli_group.command()
|
||||||
@@ -165,25 +318,10 @@ def chat(model_name: str, system: str):
|
|||||||
_cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)
|
_cli_chat(MODEL_CONTROLLER_ADDRESS, model_name, system)
|
||||||
|
|
||||||
|
|
||||||
# @model_cli_group.command()
|
|
||||||
# @add_model_options
|
|
||||||
# def modify(address: str, model_name: str, model_type: str):
|
|
||||||
# """Restart model instances"""
|
|
||||||
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_worker_manager(address: str):
|
|
||||||
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
|
|
||||||
|
|
||||||
registry = ModelRegistryClient(address)
|
|
||||||
worker_manager = RemoteWorkerManager(registry)
|
|
||||||
return worker_manager
|
|
||||||
|
|
||||||
|
|
||||||
def worker_apply(
|
def worker_apply(
|
||||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||||
):
|
):
|
||||||
from pilot.model.worker.manager import WorkerApplyRequest
|
from pilot.model.cluster import WorkerApplyRequest
|
||||||
|
|
||||||
loop = get_or_create_event_loop()
|
loop = get_or_create_event_loop()
|
||||||
worker_manager = _get_worker_manager(address)
|
worker_manager = _get_worker_manager(address)
|
||||||
@@ -201,7 +339,7 @@ def _cli_chat(address: str, model_name: str, system_prompt: str = None):
|
|||||||
|
|
||||||
|
|
||||||
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
|
async def _chat_stream(worker_manager, model_name: str, system_prompt: str = None):
|
||||||
from pilot.model.worker.manager import PromptRequest
|
from pilot.model.cluster import PromptRequest
|
||||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
|
|
||||||
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
|
print(f"Chatbot started with model {model_name}. Type 'exit' to leave the chat.")
|
||||||
@@ -249,13 +387,11 @@ def add_stop_server_options(func):
|
|||||||
def start_model_controller(**kwargs):
|
def start_model_controller(**kwargs):
|
||||||
"""Start model controller"""
|
"""Start model controller"""
|
||||||
|
|
||||||
from pilot.model.controller.controller import run_model_controller
|
|
||||||
|
|
||||||
if kwargs["daemon"]:
|
if kwargs["daemon"]:
|
||||||
log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
|
log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
|
||||||
_run_current_with_daemon("ModelController", log_file)
|
_run_current_with_daemon("ModelController", log_file)
|
||||||
else:
|
else:
|
||||||
from pilot.model.controller.controller import run_model_controller
|
from pilot.model.cluster import run_model_controller
|
||||||
|
|
||||||
run_model_controller()
|
run_model_controller()
|
||||||
|
|
||||||
@@ -279,9 +415,8 @@ def _model_dynamic_factory() -> Callable[[None], List[Type]]:
|
|||||||
return fix_class
|
return fix_class
|
||||||
|
|
||||||
|
|
||||||
@click.command(name="worker")
|
@click.command(
|
||||||
@EnvArgumentParser.create_click_option(
|
name="worker", cls=build_lazy_click_command(_dynamic_factory=_model_dynamic_factory)
|
||||||
ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory
|
|
||||||
)
|
)
|
||||||
def start_model_worker(**kwargs):
|
def start_model_worker(**kwargs):
|
||||||
"""Start model worker"""
|
"""Start model worker"""
|
||||||
@@ -291,7 +426,7 @@ def start_model_worker(**kwargs):
|
|||||||
log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
|
log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
|
||||||
_run_current_with_daemon("ModelWorker", log_file)
|
_run_current_with_daemon("ModelWorker", log_file)
|
||||||
else:
|
else:
|
||||||
from pilot.model.worker.manager import run_worker_manager
|
from pilot.model.cluster import run_worker_manager
|
||||||
|
|
||||||
run_worker_manager()
|
run_worker_manager()
|
||||||
|
|
||||||
|
33
pilot/model/cluster/__init__.py
Normal file
33
pilot/model/cluster/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from pilot.model.cluster.base import (
|
||||||
|
EmbeddingsRequest,
|
||||||
|
PromptRequest,
|
||||||
|
WorkerApplyRequest,
|
||||||
|
WorkerParameterRequest,
|
||||||
|
WorkerStartupRequest,
|
||||||
|
)
|
||||||
|
from pilot.model.cluster.worker.manager import (
|
||||||
|
initialize_worker_manager_in_client,
|
||||||
|
run_worker_manager,
|
||||||
|
worker_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pilot.model.cluster.registry import ModelRegistry
|
||||||
|
from pilot.model.cluster.controller.controller import (
|
||||||
|
ModelRegistryClient,
|
||||||
|
run_model_controller,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
|
||||||
|
|
||||||
|
# Public API of the package. Every entry needs its own trailing comma:
# adjacent string literals are concatenated into a single bogus name, which
# makes `from pilot.model.cluster import *` fail.
__all__ = [
    "EmbeddingsRequest",
    "PromptRequest",
    "WorkerApplyRequest",
    "WorkerParameterRequest",
    "WorkerStartupRequest",
    "worker_manager",
    "run_worker_manager",
    "initialize_worker_manager_in_client",
    "ModelRegistry",
    "ModelRegistryClient",
    "RemoteWorkerManager",
    "run_model_controller",
]
|
46
pilot/model/cluster/base.py
Normal file
46
pilot/model/cluster/base.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pilot.model.base import WorkerApplyType
|
||||||
|
from pilot.model.parameter import WorkerType
|
||||||
|
from pilot.scene.base_message import ModelMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
WORKER_MANAGER_SERVICE_TYPE = "service"
|
||||||
|
WORKER_MANAGER_SERVICE_NAME = "WorkerManager"
|
||||||
|
|
||||||
|
|
||||||
|
class PromptRequest(BaseModel):
    """Request model carrying messages and generation options for a model.

    ``messages`` and ``model`` have no default. The other fields default to
    ``None``, except ``echo`` which defaults to ``True``.
    """

    messages: List[ModelMessage]
    model: str
    prompt: str = None
    temperature: float = None
    max_new_tokens: int = None
    stop: str = None
    echo: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequest(BaseModel):
    """Request model naming a model and the list of input strings for it."""

    model: str
    input: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerApplyRequest(BaseModel):
    """Request model for applying an action (``apply_type``) to a model's workers.

    ``worker_type`` defaults to ``WorkerType.LLM``; ``params`` and
    ``apply_user`` default to ``None``.
    """

    model: str
    apply_type: WorkerApplyType
    worker_type: WorkerType = WorkerType.LLM
    params: Dict = None
    apply_user: str = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerParameterRequest(BaseModel):
    """Request model identifying a model by name and worker type.

    ``worker_type`` defaults to ``WorkerType.LLM``.
    """

    model: str
    worker_type: WorkerType = WorkerType.LLM
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerStartupRequest(BaseModel):
    """Request model naming a model on a specific host/port worker.

    All fields are required: the target ``host`` and ``port``, the ``model``
    name, its ``worker_type``, and a ``params`` dictionary.
    """

    host: str
    port: int
    model: str
    worker_type: WorkerType
    params: Dict
|
@@ -6,7 +6,7 @@ from typing import List
|
|||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, FastAPI
|
||||||
from pilot.model.base import ModelInstance
|
from pilot.model.base import ModelInstance
|
||||||
from pilot.model.parameter import ModelControllerParameters
|
from pilot.model.parameter import ModelControllerParameters
|
||||||
from pilot.model.controller.registry import EmbeddedModelRegistry, ModelRegistry
|
from pilot.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
from pilot.utils.api_utils import _api_remote as api_remote
|
from pilot.utils.api_utils import _api_remote as api_remote
|
||||||
|
|
@@ -2,9 +2,8 @@ import pytest
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import patch
|
|
||||||
from pilot.model.base import ModelInstance
|
from pilot.model.base import ModelInstance
|
||||||
from pilot.model.controller.registry import ModelRegistry, EmbeddedModelRegistry
|
from pilot.model.cluster.registry import EmbeddedModelRegistry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -16,7 +15,7 @@ def model_registry():
|
|||||||
def model_instance():
|
def model_instance():
|
||||||
return ModelInstance(
|
return ModelInstance(
|
||||||
model_name="test_model",
|
model_name="test_model",
|
||||||
ip="192.168.1.1",
|
host="192.168.1.1",
|
||||||
port=5000,
|
port=5000,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,12 +88,7 @@ async def test_send_heartbeat(model_registry, model_instance):
|
|||||||
await model_registry.register_instance(model_instance)
|
await model_registry.register_instance(model_instance)
|
||||||
last_heartbeat = datetime.now() - timedelta(seconds=10)
|
last_heartbeat = datetime.now() - timedelta(seconds=10)
|
||||||
model_instance.last_heartbeat = last_heartbeat
|
model_instance.last_heartbeat = last_heartbeat
|
||||||
assert (
|
assert await model_registry.send_heartbeat(model_instance) == True
|
||||||
await model_registry.send_heartbeat(
|
|
||||||
model_instance.model_name, model_instance.ip, model_instance.port
|
|
||||||
)
|
|
||||||
== True
|
|
||||||
)
|
|
||||||
assert (
|
assert (
|
||||||
model_registry.registry[model_instance.model_name][0].last_heartbeat
|
model_registry.registry[model_instance.model_name][0].last_heartbeat
|
||||||
> last_heartbeat
|
> last_heartbeat
|
||||||
@@ -125,7 +119,7 @@ async def test_multiple_instances(model_registry, model_instance):
|
|||||||
"""
|
"""
|
||||||
model_instance2 = ModelInstance(
|
model_instance2 = ModelInstance(
|
||||||
model_name="test_model",
|
model_name="test_model",
|
||||||
ip="192.168.1.2",
|
host="192.168.1.2",
|
||||||
port=5000,
|
port=5000,
|
||||||
)
|
)
|
||||||
await model_registry.register_instance(model_instance)
|
await model_registry.register_instance(model_instance)
|
||||||
@@ -138,11 +132,11 @@ async def test_same_model_name_different_ip_port(model_registry):
|
|||||||
"""
|
"""
|
||||||
Test if instances with the same model name but different IP and port are handled correctly
|
Test if instances with the same model name but different IP and port are handled correctly
|
||||||
"""
|
"""
|
||||||
instance1 = ModelInstance(model_name="test_model", ip="192.168.1.1", port=5000)
|
instance1 = ModelInstance(model_name="test_model", host="192.168.1.1", port=5000)
|
||||||
instance2 = ModelInstance(model_name="test_model", ip="192.168.1.2", port=6000)
|
instance2 = ModelInstance(model_name="test_model", host="192.168.1.2", port=6000)
|
||||||
await model_registry.register_instance(instance1)
|
await model_registry.register_instance(instance1)
|
||||||
await model_registry.register_instance(instance2)
|
await model_registry.register_instance(instance2)
|
||||||
instances = await model_registry.get_all_instances("test_model")
|
instances = await model_registry.get_all_instances("test_model")
|
||||||
assert len(instances) == 2
|
assert len(instances) == 2
|
||||||
assert instances[0].ip != instances[1].ip
|
assert instances[0].host != instances[1].host
|
||||||
assert instances[0].port != instances[1].port
|
assert instances[0].port != instances[1].port
|
82
pilot/model/cluster/manager_base.py
Normal file
82
pilot/model/cluster/manager_base.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Dict, Iterator
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from concurrent.futures import Future
|
||||||
|
from pilot.model.base import WorkerSupportedModel, ModelOutput, WorkerApplyOutput
|
||||||
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
|
from pilot.model.cluster.base import WorkerStartupRequest, WorkerApplyRequest
|
||||||
|
from pilot.model.parameter import ModelWorkerParameters, ModelParameters
|
||||||
|
from pilot.utils.parameter_utils import ParameterDescription
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WorkerRunData:
    """Runtime record bundling a worker with its parameters and state."""

    host: str
    port: int
    # Key identifying the worker.
    # NOTE(review): key format is not visible here - confirm.
    worker_key: str
    worker: ModelWorker
    worker_params: ModelWorkerParameters
    model_params: ModelParameters
    stop_event: asyncio.Event
    # NOTE(review): presumably limits concurrent use of the worker - confirm.
    semaphore: asyncio.Semaphore = None
    command_args: List[str] = None
    # Heartbeat bookkeeping; both default to None.
    _heartbeat_future: Optional[Future] = None
    _last_heartbeat: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerManager(ABC):
    """Abstract interface for managing model workers.

    Every method is an abstract coroutine; concrete managers must implement
    all of them.
    """

    @abstractmethod
    async def start(self):
        """Start worker manager"""

    @abstractmethod
    async def stop(self):
        """Stop worker manager"""

    @abstractmethod
    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Get model instances by worker type and model name

        Args:
            worker_type: Worker type to look up.
            model_name: Model name to look up.
            healthy_only: Defaults to ``True``.

        Returns:
            A list of ``WorkerRunData``.
        """

    @abstractmethod
    async def select_one_instance(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> WorkerRunData:
        """Select one instance

        Takes the same arguments as ``get_model_instances`` and returns a
        single ``WorkerRunData``.
        """

    @abstractmethod
    async def supported_models(self) -> List[WorkerSupportedModel]:
        """List supported models"""

    @abstractmethod
    async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
        """Create and start a model instance"""

    @abstractmethod
    async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
        """Shutdown model instance

        The request type is the same ``WorkerStartupRequest`` used by
        ``model_startup``.
        """

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        """Generate stream result, chat scene"""

    @abstractmethod
    async def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result"""

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed input"""

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Worker apply"""

    @abstractmethod
    async def parameter_descriptions(
        self, worker_type: str, model_name: str
    ) -> List[ParameterDescription]:
        """Get parameter descriptions of model"""
|
@@ -7,7 +7,7 @@ from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
|||||||
from pilot.model.base import ModelOutput
|
from pilot.model.base import ModelOutput
|
||||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||||
from pilot.model.parameter import ModelParameters
|
from pilot.model.parameter import ModelParameters
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||||
from pilot.utils.model_utils import _clear_torch_cache
|
from pilot.utils.model_utils import _clear_torch_cache
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
@@ -7,7 +7,7 @@ from pilot.model.parameter import (
|
|||||||
EmbeddingModelParameters,
|
EmbeddingModelParameters,
|
||||||
WorkerType,
|
WorkerType,
|
||||||
)
|
)
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.utils.model_utils import _clear_torch_cache
|
from pilot.utils.model_utils import _clear_torch_cache
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
@@ -4,13 +4,11 @@ import json
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from dataclasses import asdict
|
||||||
from dataclasses import asdict, dataclass
|
from typing import Awaitable, Callable, Dict, Iterator, List
|
||||||
from datetime import datetime
|
|
||||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, FastAPI, Request
|
from fastapi import APIRouter, FastAPI
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.model.base import (
|
from pilot.model.base import (
|
||||||
@@ -18,103 +16,34 @@ from pilot.model.base import (
|
|||||||
ModelOutput,
|
ModelOutput,
|
||||||
WorkerApplyOutput,
|
WorkerApplyOutput,
|
||||||
WorkerApplyType,
|
WorkerApplyType,
|
||||||
|
WorkerSupportedModel,
|
||||||
)
|
)
|
||||||
from pilot.model.controller.registry import ModelRegistry
|
from pilot.model.cluster.registry import ModelRegistry
|
||||||
|
from pilot.model.llm_utils import list_supported_models
|
||||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
from pilot.scene.base_message import ModelMessage
|
from pilot.model.cluster.manager_base import WorkerManager, WorkerRunData
|
||||||
|
from pilot.model.cluster.base import *
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser, ParameterDescription
|
from pilot.utils.parameter_utils import (
|
||||||
from pydantic import BaseModel
|
EnvArgumentParser,
|
||||||
|
ParameterDescription,
|
||||||
|
_dict_to_command_args,
|
||||||
|
)
|
||||||
|
|
||||||
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
||||||
|
|
||||||
|
|
||||||
class PromptRequest(BaseModel):
    """Request body of the generation endpoints.

    Fields that default to ``None`` are declared ``Optional`` explicitly
    instead of relying on an implicit Optional from the ``None`` default.
    """

    # Conversation messages sent with the request.
    messages: List[ModelMessage]
    # Name of the model that should handle the request.
    model: str
    prompt: Optional[str] = None
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None
    stop: Optional[str] = None
    echo: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsRequest(BaseModel):
    """Request body of the embeddings endpoint."""

    # Name of the embedding model to use.
    model: str
    # Texts to embed.
    input: List[str]
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerApplyRequest(BaseModel):
    """Request asking a worker manager to apply an operation to a model.

    Fields that default to ``None`` are declared ``Optional`` explicitly
    instead of relying on an implicit Optional from the ``None`` default.
    """

    # Name of the model the operation targets.
    model: str
    # Operation to apply (a ``WorkerApplyType`` member).
    apply_type: WorkerApplyType
    # Kind of worker the model runs on; LLM unless stated otherwise.
    worker_type: WorkerType = WorkerType.LLM
    params: Optional[Dict] = None
    apply_user: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerParameterRequest(BaseModel):
    """Request identifying a model by name and worker type."""

    # Name of the model.
    model: str
    # Kind of worker the model runs on; LLM unless stated otherwise.
    worker_type: WorkerType = WorkerType.LLM
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class WorkerRunData:
    """Runtime record kept by a worker manager for one worker instance.

    Fields that default to ``None`` are annotated ``Optional`` explicitly.
    """

    # Lookup key of the worker.
    worker_key: str
    # The worker object itself.
    worker: ModelWorker
    # Parameters of the worker process.
    worker_params: ModelWorkerParameters
    # Parameters of the model loaded by the worker.
    model_params: ModelParameters
    # Checked by the heartbeat loop, which runs until this event is set.
    stop_event: asyncio.Event
    # Concurrency limiter for requests to this worker, if any.
    semaphore: Optional[asyncio.Semaphore] = None
    # Command-line arguments used to parse/start the worker, if any.
    command_args: Optional[List[str]] = None
    # Handle of the heartbeat job, when one has been started.
    _heartbeat_future: Optional[Future] = None
    # Presumably the time of the last heartbeat -- TODO confirm against users.
    _last_heartbeat: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||||
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||||
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
|
SendHeartbeatFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||||
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]
|
ApplyFunction = Callable[[WorkerRunData], Awaitable[None]]
|
||||||
|
|
||||||
|
|
||||||
class WorkerManager(ABC):
    """Abstract interface of a worker manager.

    Every method is an abstract coroutine; concrete managers supply the
    behaviour.
    """

    @abstractmethod
    async def get_model_instances(
        self,
        worker_type: str,
        model_name: str,
        healthy_only: bool = True,
    ) -> List[WorkerRunData]:
        """Look up instances by worker type and model name."""

    @abstractmethod
    async def select_one_instanes(
        self,
        worker_type: str,
        model_name: str,
        healthy_only: bool = True,
    ) -> WorkerRunData:
        """Pick a single instance (method name spelling kept as declared)."""

    @abstractmethod
    async def generate_stream(self, params: Dict, **kwargs) -> Iterator[ModelOutput]:
        """Produce a streamed generation result (chat scene)."""

    @abstractmethod
    async def generate(self, params: Dict) -> ModelOutput:
        """Produce a complete, non-streamed generation result."""

    @abstractmethod
    async def embeddings(self, params: Dict) -> List[List[float]]:
        """Embed the input described by ``params``."""

    @abstractmethod
    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Apply the operation described by ``apply_req`` to workers."""

    @abstractmethod
    async def parameter_descriptions(
        self,
        worker_type: str,
        model_name: str,
    ) -> List[ParameterDescription]:
        """Return the parameter descriptions of a model."""
|
|
||||||
|
|
||||||
|
|
||||||
async def _async_heartbeat_sender(
|
async def _async_heartbeat_sender(
|
||||||
worker_run_data: WorkerRunData, send_heartbeat_func: SendHeartbeatFunc
|
worker_run_data: WorkerRunData,
|
||||||
|
heartbeat_interval,
|
||||||
|
send_heartbeat_func: SendHeartbeatFunc,
|
||||||
):
|
):
|
||||||
while not worker_run_data.stop_event.is_set():
|
while not worker_run_data.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
@@ -122,7 +51,7 @@ async def _async_heartbeat_sender(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(f"Send heartbeat func error: {str(e)}")
|
logger.warn(f"Send heartbeat func error: {str(e)}")
|
||||||
finally:
|
finally:
|
||||||
await asyncio.sleep(worker_run_data.worker_params.heartbeat_interval)
|
await asyncio.sleep(heartbeat_interval)
|
||||||
|
|
||||||
|
|
||||||
class LocalWorkerManager(WorkerManager):
|
class LocalWorkerManager(WorkerManager):
|
||||||
@@ -132,6 +61,8 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
deregister_func: DeregisterFunc = None,
|
deregister_func: DeregisterFunc = None,
|
||||||
send_heartbeat_func: SendHeartbeatFunc = None,
|
send_heartbeat_func: SendHeartbeatFunc = None,
|
||||||
model_registry: ModelRegistry = None,
|
model_registry: ModelRegistry = None,
|
||||||
|
host: str = None,
|
||||||
|
port: int = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.workers: Dict[str, List[WorkerRunData]] = dict()
|
self.workers: Dict[str, List[WorkerRunData]] = dict()
|
||||||
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
|
self.executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 5)
|
||||||
@@ -139,19 +70,58 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
self.deregister_func = deregister_func
|
self.deregister_func = deregister_func
|
||||||
self.send_heartbeat_func = send_heartbeat_func
|
self.send_heartbeat_func = send_heartbeat_func
|
||||||
self.model_registry = model_registry
|
self.model_registry = model_registry
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
|
||||||
|
self.run_data = WorkerRunData(
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
worker_key=self._worker_key(
|
||||||
|
WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
|
||||||
|
),
|
||||||
|
worker=None,
|
||||||
|
worker_params=None,
|
||||||
|
model_params=None,
|
||||||
|
stop_event=asyncio.Event(),
|
||||||
|
semaphore=None,
|
||||||
|
command_args=None,
|
||||||
|
)
|
||||||
|
|
||||||
def _worker_key(self, worker_type: str, model_name: str) -> str:
|
def _worker_key(self, worker_type: str, model_name: str) -> str:
|
||||||
if isinstance(worker_type, WorkerType):
|
if isinstance(worker_type, WorkerType):
|
||||||
worker_type = worker_type.value
|
worker_type = worker_type.value
|
||||||
return f"{model_name}@{worker_type}"
|
return f"{model_name}@{worker_type}"
|
||||||
|
|
||||||
|
async def run_blocking_func(self, func, *args):
    """Run a blocking callable on the manager's thread pool and await it.

    Args:
        func: Synchronous callable to execute off the event loop.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func(*args)`` returns.
    """
    # This coroutine always executes inside a running loop, so ask for that
    # loop explicitly rather than going through the legacy get_event_loop().
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, func, *args)
|
||||||
|
|
||||||
|
async def start(self):
    """Start the manager.

    Starts all workers when any are registered in ``self.workers``, then
    registers the manager's own run data through ``register_func`` (when
    configured) and launches the manager heartbeat task (when
    ``send_heartbeat_func`` is configured).
    """
    if self.workers:
        await self._start_all_worker(apply_req=None)
    if self.register_func:
        await self.register_func(self.run_data)
    if self.send_heartbeat_func:
        # The event loop keeps only weak references to tasks; hold on to the
        # heartbeat task so it cannot be garbage-collected while running.
        self._heartbeat_task = asyncio.create_task(
            _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func)
        )
|
||||||
|
|
||||||
|
async def stop(self):
    """Stop the manager: stop all workers and deregister the manager.

    Does nothing when the manager's stop event is already set, so calling
    it more than once is harmless.
    """
    if self.run_data.stop_event.is_set():
        return
    logger.info("Stop all workers")
    # Mark the manager as stopped. The heartbeat loop runs while the event
    # is *not* set, so the event must be set here (clearing it would leave
    # the loop running and make every later stop() call repeat the work).
    self.run_data.stop_event.set()
    stop_tasks = [self._stop_all_worker(apply_req=None)]
    if self.deregister_func:
        stop_tasks.append(self.deregister_func(self.run_data))
    await asyncio.gather(*stop_tasks)
|
||||||
|
|
||||||
def add_worker(
|
def add_worker(
|
||||||
self,
|
self,
|
||||||
worker: ModelWorker,
|
worker: ModelWorker,
|
||||||
worker_params: ModelWorkerParameters,
|
worker_params: ModelWorkerParameters,
|
||||||
embedded_mod: bool = True,
|
|
||||||
command_args: List[str] = None,
|
command_args: List[str] = None,
|
||||||
):
|
) -> bool:
|
||||||
if not command_args:
|
if not command_args:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -179,6 +149,8 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
model_params = worker.parse_parameters(command_args=command_args)
|
model_params = worker.parse_parameters(command_args=command_args)
|
||||||
|
|
||||||
worker_run_data = WorkerRunData(
|
worker_run_data = WorkerRunData(
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
worker_key=worker_key,
|
worker_key=worker_key,
|
||||||
worker=worker,
|
worker=worker,
|
||||||
worker_params=worker_params,
|
worker_params=worker_params,
|
||||||
@@ -187,14 +159,66 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
|
semaphore=asyncio.Semaphore(worker_params.limit_model_concurrency),
|
||||||
command_args=command_args,
|
command_args=command_args,
|
||||||
)
|
)
|
||||||
if not embedded_mod:
|
|
||||||
exist_instances = [
|
exist_instances = [
|
||||||
(w, p) for w, p in instances if p.host == host and p.port == port
|
ins for ins in instances if ins.host == host and ins.port == port
|
||||||
]
|
]
|
||||||
if not exist_instances:
|
if not exist_instances:
|
||||||
instances.append(worker_run_data)
|
instances.append(worker_run_data)
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
instances.append(worker_run_data)
|
# TODO Update worker
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
    """Create a worker for the requested model and start it.

    Args:
        startup_req: Supplies the model name, the worker type and the
            parameter dict used to build the worker parameters.

    Returns:
        True once the start request has been applied; False when
        ``add_worker`` reports that the worker could not be added.

    Raises:
        ValueError: If the requested worker type is not supported.
    """
    model_name = startup_req.model
    worker_type = startup_req.worker_type
    params = startup_req.params
    logger.debug(
        f"start model, model name {model_name}, worker type {worker_type}, params: {params}"
    )
    # Reject unsupported worker types before anything is built or added, so
    # a bad request cannot leave behind a worker that is never started.
    supported_types = WorkerType.values()
    if worker_type not in supported_types:
        raise ValueError(
            f"Unsupported worker type: {worker_type}, now supported worker type: {supported_types}"
        )
    worker_params: ModelWorkerParameters = ModelWorkerParameters.from_dict(
        params, ignore_extra_fields=True
    )
    if not worker_params.model_name:
        worker_params.model_name = model_name
    # Consistency check between the request and the parsed parameters.
    # NOTE(review): an assert is skipped under ``python -O``; kept so the
    # exception type seen by callers does not change.
    assert model_name == worker_params.model_name
    worker = _build_worker(worker_params)
    command_args = _dict_to_command_args(params)
    # add_worker is synchronous; run it off the event loop.
    success = await self.run_blocking_func(
        self.add_worker, worker, worker_params, command_args
    )
    if not success:
        logger.warn(
            f"Add worker failed, worker instances is exist, worker_params: {worker_params}"
        )
        return False
    start_apply_req = WorkerApplyRequest(
        model=model_name, apply_type=WorkerApplyType.START, worker_type=worker_type
    )
    await self.worker_apply(start_apply_req)
    return True
|
||||||
|
|
||||||
|
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
    """Stop the workers of the model named in ``shutdown_req``.

    Returns:
        True when the stop operation reports success.

    Raises:
        Exception: Carrying the operation's message when it did not succeed.
    """
    logger.info(f"Begin shutdown model, shutdown_req: {shutdown_req}")
    stop_request = WorkerApplyRequest(
        model=shutdown_req.model,
        apply_type=WorkerApplyType.STOP,
        worker_type=shutdown_req.worker_type,
    )
    result = await self._stop_all_worker(stop_request)
    if not result.success:
        raise Exception(result.message)
    return True
|
||||||
|
|
||||||
|
async def supported_models(self) -> List[WorkerSupportedModel]:
    """Report the models supported by this manager.

    The model list is gathered by ``list_supported_models`` on the thread
    pool and wrapped in a single entry tagged with this manager's host and
    port.
    """
    local_models = await self.run_blocking_func(list_supported_models)
    entry = WorkerSupportedModel(host=self.host, port=self.port, models=local_models)
    return [entry]
|
||||||
|
|
||||||
async def get_model_instances(
|
async def get_model_instances(
|
||||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||||
@@ -202,7 +226,7 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
worker_key = self._worker_key(worker_type, model_name)
|
worker_key = self._worker_key(worker_type, model_name)
|
||||||
return self.workers.get(worker_key)
|
return self.workers.get(worker_key)
|
||||||
|
|
||||||
async def select_one_instanes(
|
async def select_one_instance(
|
||||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||||
) -> WorkerRunData:
|
) -> WorkerRunData:
|
||||||
worker_instances = await self.get_model_instances(
|
worker_instances = await self.get_model_instances(
|
||||||
@@ -219,7 +243,7 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
model = params.get("model")
|
model = params.get("model")
|
||||||
if not model:
|
if not model:
|
||||||
raise Exception("Model name count not be empty")
|
raise Exception("Model name count not be empty")
|
||||||
return await self.select_one_instanes(worker_type, model, healthy_only=True)
|
return await self.select_one_instance(worker_type, model, healthy_only=True)
|
||||||
|
|
||||||
async def generate_stream(
|
async def generate_stream(
|
||||||
self, params: Dict, async_wrapper=None, **kwargs
|
self, params: Dict, async_wrapper=None, **kwargs
|
||||||
@@ -262,9 +286,8 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
if worker_run_data.worker.support_async():
|
if worker_run_data.worker.support_async():
|
||||||
return await worker_run_data.worker.async_generate(params)
|
return await worker_run_data.worker.async_generate(params)
|
||||||
else:
|
else:
|
||||||
loop = asyncio.get_event_loop()
|
return await self.run_blocking_func(
|
||||||
return await loop.run_in_executor(
|
worker_run_data.worker.generate, params
|
||||||
self.executor, worker_run_data.worker.generate, params
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def embeddings(self, params: Dict) -> List[List[float]]:
|
async def embeddings(self, params: Dict) -> List[List[float]]:
|
||||||
@@ -277,9 +300,8 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
if worker_run_data.worker.support_async():
|
if worker_run_data.worker.support_async():
|
||||||
return await worker_run_data.worker.async_embeddings(params)
|
return await worker_run_data.worker.async_embeddings(params)
|
||||||
else:
|
else:
|
||||||
loop = asyncio.get_event_loop()
|
return await self.run_blocking_func(
|
||||||
return await loop.run_in_executor(
|
worker_run_data.worker.embeddings, params
|
||||||
self.executor, worker_run_data.worker.embeddings, params
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||||
@@ -342,8 +364,10 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
logger.info(f"Begin start all worker, apply_req: {apply_req}")
|
logger.info(f"Begin start all worker, apply_req: {apply_req}")
|
||||||
|
|
||||||
async def _start_worker(worker_run_data: WorkerRunData):
|
async def _start_worker(worker_run_data: WorkerRunData):
|
||||||
worker_run_data.worker.start(
|
await self.run_blocking_func(
|
||||||
worker_run_data.model_params, worker_run_data.command_args
|
worker_run_data.worker.start,
|
||||||
|
worker_run_data.model_params,
|
||||||
|
worker_run_data.command_args,
|
||||||
)
|
)
|
||||||
worker_run_data.stop_event.clear()
|
worker_run_data.stop_event.clear()
|
||||||
if worker_run_data.worker_params.register and self.register_func:
|
if worker_run_data.worker_params.register and self.register_func:
|
||||||
@@ -355,7 +379,9 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
):
|
):
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_async_heartbeat_sender(
|
_async_heartbeat_sender(
|
||||||
worker_run_data, self.send_heartbeat_func
|
worker_run_data,
|
||||||
|
worker_run_data.worker_params.heartbeat_interval,
|
||||||
|
self.send_heartbeat_func,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -371,7 +397,7 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _stop_worker(worker_run_data: WorkerRunData):
|
async def _stop_worker(worker_run_data: WorkerRunData):
|
||||||
worker_run_data.worker.stop()
|
await self.run_blocking_func(worker_run_data.worker.stop)
|
||||||
# Set stop event
|
# Set stop event
|
||||||
worker_run_data.stop_event.set()
|
worker_run_data.stop_event.set()
|
||||||
if worker_run_data._heartbeat_future:
|
if worker_run_data._heartbeat_future:
|
||||||
@@ -422,63 +448,25 @@ class LocalWorkerManager(WorkerManager):
|
|||||||
return WorkerApplyOutput(message=message, timecost=timecost)
|
return WorkerApplyOutput(message=message, timecost=timecost)
|
||||||
|
|
||||||
|
|
||||||
class RemoteWorkerManager(LocalWorkerManager):
    """Worker manager whose instances come from the model registry.

    Instances are looked up in ``model_registry`` and wrapped in
    ``RemoteModelWorker`` objects; apply requests are posted to each
    worker's ``/apply`` endpoint.
    """

    def __init__(self, model_registry: ModelRegistry = None) -> None:
        super().__init__(model_registry=model_registry)

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Build run data for every registry instance of the given model."""
        from pilot.model.worker.remote_worker import RemoteModelWorker

        def _to_run_data(instance) -> WorkerRunData:
            # Wrap one registry instance in a remote worker plus run data.
            remote = RemoteModelWorker()
            remote.load_worker(
                model_name, model_name, host=instance.host, port=instance.port
            )
            return WorkerRunData(
                worker_key=instance.model_name,
                worker=remote,
                worker_params=None,
                model_params=None,
                stop_event=asyncio.Event(),
                semaphore=asyncio.Semaphore(100),  # Not limit in client
            )

        registry_key = self._worker_key(worker_type, model_name)
        registered: List[ModelInstance] = await self.model_registry.get_all_instances(
            registry_key, healthy_only
        )
        return [_to_run_data(instance) for instance in registered]

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Post ``apply_req`` to the remote workers; return the first result.

        Returns ``None`` when no result was produced.
        """
        import httpx

        async def _remote_apply_func(worker_run_data: WorkerRunData):
            worker_addr = worker_run_data.worker.worker_addr
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    worker_addr + "/apply",
                    headers=worker_run_data.worker.headers,
                    json=apply_req.dict(),
                    timeout=worker_run_data.worker.timeout,
                )
                if response.status_code != 200:
                    output = WorkerApplyOutput(message=response.text)
                    logger.warn(f"worker_apply failed: {output}")
                    return output
                output = WorkerApplyOutput(**response.json())
                logger.info(f"worker_apply success: {output}")
                return output

        results = await self._apply_worker(apply_req, _remote_apply_func)
        return results[0] if results else None
|
|
||||||
|
|
||||||
|
|
||||||
class WorkerManagerAdapter(WorkerManager):
|
class WorkerManagerAdapter(WorkerManager):
|
||||||
def __init__(self, worker_manager: WorkerManager = None) -> None:
|
def __init__(self, worker_manager: WorkerManager = None) -> None:
|
||||||
self.worker_manager = worker_manager
|
self.worker_manager = worker_manager
|
||||||
|
|
||||||
|
async def start(self):
    """Delegate ``start`` to the wrapped worker manager."""
    delegate = self.worker_manager
    return await delegate.start()
|
||||||
|
|
||||||
|
async def stop(self):
    """Delegate ``stop`` to the wrapped worker manager."""
    delegate = self.worker_manager
    return await delegate.stop()
|
||||||
|
|
||||||
|
async def supported_models(self) -> List[WorkerSupportedModel]:
    """Delegate ``supported_models`` to the wrapped worker manager."""
    delegate = self.worker_manager
    return await delegate.supported_models()
|
||||||
|
|
||||||
|
async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
    """Delegate ``model_startup`` to the wrapped worker manager."""
    delegate = self.worker_manager
    return await delegate.model_startup(startup_req)
|
||||||
|
|
||||||
|
async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
    """Delegate ``model_shutdown`` to the wrapped worker manager."""
    delegate = self.worker_manager
    return await delegate.model_shutdown(shutdown_req)
|
||||||
|
|
||||||
async def get_model_instances(
|
async def get_model_instances(
|
||||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||||
) -> List[WorkerRunData]:
|
) -> List[WorkerRunData]:
|
||||||
@@ -486,10 +474,10 @@ class WorkerManagerAdapter(WorkerManager):
|
|||||||
worker_type, model_name, healthy_only
|
worker_type, model_name, healthy_only
|
||||||
)
|
)
|
||||||
|
|
||||||
async def select_one_instanes(
|
async def select_one_instance(
|
||||||
self, worker_type: str, model_name: str, healthy_only: bool = True
|
self, worker_type: str, model_name: str, healthy_only: bool = True
|
||||||
) -> WorkerRunData:
|
) -> WorkerRunData:
|
||||||
return await self.worker_manager.select_one_instanes(
|
return await self.worker_manager.select_one_instance(
|
||||||
worker_type, model_name, healthy_only
|
worker_type, model_name, healthy_only
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -535,37 +523,58 @@ async def api_generate_stream(request: PromptRequest):
|
|||||||
@router.post("/worker/generate")
|
@router.post("/worker/generate")
|
||||||
async def api_generate(request: PromptRequest):
|
async def api_generate(request: PromptRequest):
|
||||||
params = request.dict(exclude_none=True)
|
params = request.dict(exclude_none=True)
|
||||||
output = await worker_manager.generate(params)
|
return await worker_manager.generate(params)
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/worker/embeddings")
|
@router.post("/worker/embeddings")
|
||||||
async def api_embeddings(request: EmbeddingsRequest):
|
async def api_embeddings(request: EmbeddingsRequest):
|
||||||
params = request.dict(exclude_none=True)
|
params = request.dict(exclude_none=True)
|
||||||
output = await worker_manager.embeddings(params)
|
return await worker_manager.embeddings(params)
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/worker/apply")
|
@router.post("/worker/apply")
|
||||||
async def api_worker_apply(request: WorkerApplyRequest):
|
async def api_worker_apply(request: WorkerApplyRequest):
|
||||||
output = await worker_manager.worker_apply(request)
|
return await worker_manager.worker_apply(request)
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/worker/parameter/descriptions")
|
@router.get("/worker/parameter/descriptions")
|
||||||
async def api_worker_parameter_descs(
|
async def api_worker_parameter_descs(
|
||||||
model: str, worker_type: str = WorkerType.LLM.value
|
model: str, worker_type: str = WorkerType.LLM.value
|
||||||
):
|
):
|
||||||
output = await worker_manager.parameter_descriptions(worker_type, model)
|
return await worker_manager.parameter_descriptions(worker_type, model)
|
||||||
return output
|
|
||||||
|
|
||||||
|
@router.get("/worker/models/supports")
|
||||||
|
async def api_supported_models():
    """Get all supported models.

    This method reads all models from the configuration file and tries to perform some basic checks on the model (like if the path exists).

    If it's a RemoteWorkerManager, this method returns the list of models supported by the entire cluster.
    """
    # Answered by whichever manager the module-level adapter wraps.
    models = await worker_manager.supported_models()
    return models
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/worker/models/startup")
|
||||||
|
async def api_model_startup(request: WorkerStartupRequest):
    """Start up a specific model."""
    # The request body is handed to the manager unchanged.
    started = await worker_manager.model_startup(request)
    return started
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/worker/models/shutdown")
|
||||||
|
async def api_model_shutdown(request: WorkerStartupRequest):
    """Shut down a specific model."""
    # The request body is handed to the manager unchanged.
    stopped = await worker_manager.model_shutdown(request)
    return stopped
|
||||||
|
|
||||||
|
|
||||||
def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
||||||
if not app:
|
if not app:
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
if worker_params.standalone:
|
if worker_params.standalone:
|
||||||
from pilot.model.controller.controller import router as controller_router
|
from pilot.model.cluster.controller.controller import initialize_controller
|
||||||
from pilot.model.controller.controller import initialize_controller
|
from pilot.model.cluster.controller.controller import (
|
||||||
|
router as controller_router,
|
||||||
|
)
|
||||||
|
|
||||||
if not worker_params.controller_addr:
|
if not worker_params.controller_addr:
|
||||||
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
worker_params.controller_addr = f"http://127.0.0.1:{worker_params.port}"
|
||||||
@@ -577,9 +586,11 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
asyncio.create_task(
|
asyncio.create_task(worker_manager.worker_manager.start())
|
||||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
|
||||||
)
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
    """Stop the wrapped worker manager when the application shuts down.

    Named ``shutdown_event`` so it no longer reuses (and shadows) the name
    of the start-up handler; the handler is registered through its
    decorator, not looked up by name.
    """
    await worker_manager.worker_manager.stop()
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
@@ -609,22 +620,23 @@ def _parse_worker_params(
|
|||||||
def _create_local_model_manager(
|
def _create_local_model_manager(
|
||||||
worker_params: ModelWorkerParameters,
|
worker_params: ModelWorkerParameters,
|
||||||
) -> LocalWorkerManager:
|
) -> LocalWorkerManager:
|
||||||
if not worker_params.register or not worker_params.controller_addr:
|
|
||||||
logger.info(
|
|
||||||
f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
|
|
||||||
)
|
|
||||||
return LocalWorkerManager()
|
|
||||||
else:
|
|
||||||
from pilot.model.controller.controller import ModelRegistryClient
|
|
||||||
from pilot.utils.net_utils import _get_ip_address
|
from pilot.utils.net_utils import _get_ip_address
|
||||||
|
|
||||||
client = ModelRegistryClient(worker_params.controller_addr)
|
|
||||||
host = (
|
host = (
|
||||||
worker_params.worker_register_host
|
worker_params.worker_register_host
|
||||||
if worker_params.worker_register_host
|
if worker_params.worker_register_host
|
||||||
else _get_ip_address()
|
else _get_ip_address()
|
||||||
)
|
)
|
||||||
port = worker_params.port
|
port = worker_params.port
|
||||||
|
if not worker_params.register or not worker_params.controller_addr:
|
||||||
|
logger.info(
|
||||||
|
f"Not register current to controller, register: {worker_params.register}, controller_addr: {worker_params.controller_addr}"
|
||||||
|
)
|
||||||
|
return LocalWorkerManager(host=host, port=port)
|
||||||
|
else:
|
||||||
|
from pilot.model.cluster.controller.controller import ModelRegistryClient
|
||||||
|
|
||||||
|
client = ModelRegistryClient(worker_params.controller_addr)
|
||||||
|
|
||||||
async def register_func(worker_run_data: WorkerRunData):
|
async def register_func(worker_run_data: WorkerRunData):
|
||||||
instance = ModelInstance(
|
instance = ModelInstance(
|
||||||
@@ -648,31 +660,33 @@ def _create_local_model_manager(
|
|||||||
register_func=register_func,
|
register_func=register_func,
|
||||||
deregister_func=deregister_func,
|
deregister_func=deregister_func,
|
||||||
send_heartbeat_func=send_heartbeat_func,
|
send_heartbeat_func=send_heartbeat_func,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _start_local_worker(
|
def _build_worker(worker_params: ModelWorkerParameters):
|
||||||
worker_manager: WorkerManagerAdapter,
|
if worker_params.worker_class:
|
||||||
worker_params: ModelWorkerParameters,
|
|
||||||
embedded_mod=True,
|
|
||||||
):
|
|
||||||
from pilot.utils.module_utils import import_from_checked_string
|
from pilot.utils.module_utils import import_from_checked_string
|
||||||
|
|
||||||
if worker_params.worker_class:
|
|
||||||
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
|
worker_cls = import_from_checked_string(worker_params.worker_class, ModelWorker)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Import worker class from {worker_params.worker_class} successfully"
|
f"Import worker class from {worker_params.worker_class} successfully"
|
||||||
)
|
)
|
||||||
worker: ModelWorker = worker_cls()
|
worker: ModelWorker = worker_cls()
|
||||||
else:
|
else:
|
||||||
from pilot.model.worker.default_worker import DefaultModelWorker
|
from pilot.model.cluster.worker.default_worker import DefaultModelWorker
|
||||||
|
|
||||||
worker = DefaultModelWorker()
|
worker = DefaultModelWorker()
|
||||||
|
return worker
|
||||||
|
|
||||||
|
|
||||||
|
def _start_local_worker(
|
||||||
|
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
|
||||||
|
):
|
||||||
|
worker = _build_worker(worker_params)
|
||||||
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
worker_manager.worker_manager = _create_local_model_manager(worker_params)
|
||||||
worker_manager.worker_manager.add_worker(
|
worker_manager.worker_manager.add_worker(worker, worker_params)
|
||||||
worker, worker_params, embedded_mod=embedded_mod
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_worker_manager_in_client(
|
def initialize_worker_manager_in_client(
|
||||||
@@ -713,16 +727,13 @@ def initialize_worker_manager_in_client(
|
|||||||
worker_params.port = local_port
|
worker_params.port = local_port
|
||||||
logger.info(f"Worker params: {worker_params}")
|
logger.info(f"Worker params: {worker_params}")
|
||||||
_setup_fastapi(worker_params, app)
|
_setup_fastapi(worker_params, app)
|
||||||
_start_local_worker(worker_manager, worker_params, True)
|
_start_local_worker(worker_manager, worker_params)
|
||||||
# loop = asyncio.get_event_loop()
|
|
||||||
# loop.run_until_complete(
|
|
||||||
# worker_manager.worker_manager._start_all_worker(apply_req=None)
|
|
||||||
# )
|
|
||||||
else:
|
else:
|
||||||
from pilot.model.controller.controller import (
|
from pilot.model.cluster.controller.controller import (
|
||||||
initialize_controller,
|
|
||||||
ModelRegistryClient,
|
ModelRegistryClient,
|
||||||
|
initialize_controller,
|
||||||
)
|
)
|
||||||
|
from pilot.model.cluster.worker.remote_manager import RemoteWorkerManager
|
||||||
|
|
||||||
if not worker_params.controller_addr:
|
if not worker_params.controller_addr:
|
||||||
raise ValueError("Controller can`t be None")
|
raise ValueError("Controller can`t be None")
|
||||||
@@ -758,13 +769,11 @@ def run_worker_manager(
|
|||||||
# Run worker manager independently
|
# Run worker manager independently
|
||||||
embedded_mod = False
|
embedded_mod = False
|
||||||
app = _setup_fastapi(worker_params)
|
app = _setup_fastapi(worker_params)
|
||||||
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
|
_start_local_worker(worker_manager, worker_params)
|
||||||
else:
|
else:
|
||||||
_start_local_worker(worker_manager, worker_params, embedded_mod=False)
|
_start_local_worker(worker_manager, worker_params)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.run_until_complete(
|
loop.run_until_complete(worker_manager.worker_manager.start())
|
||||||
worker_manager.worker_manager._start_all_worker(apply_req=None)
|
|
||||||
)
|
|
||||||
|
|
||||||
if include_router:
|
if include_router:
|
||||||
app.include_router(router, prefix="/api")
|
app.include_router(router, prefix="/api")
|
0
pilot/model/cluster/worker/ray_worker.py
Normal file
0
pilot/model/cluster/worker/ray_worker.py
Normal file
169
pilot/model/cluster/worker/remote_manager.py
Normal file
169
pilot/model/cluster/worker/remote_manager.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
from typing import Callable, Any
|
||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
from pilot.model.cluster.registry import ModelRegistry
|
||||||
|
from pilot.model.cluster.worker.manager import LocalWorkerManager, WorkerRunData, logger
|
||||||
|
from pilot.model.cluster.base import *
|
||||||
|
from pilot.model.base import (
|
||||||
|
ModelInstance,
|
||||||
|
WorkerApplyOutput,
|
||||||
|
WorkerSupportedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteWorkerManager(LocalWorkerManager):
    """Worker manager that forwards requests to worker-manager services.

    Target services are looked up through ``self.model_registry`` and are
    contacted over HTTP with ``httpx``.
    """

    def __init__(self, model_registry: ModelRegistry = None) -> None:
        super().__init__(model_registry=model_registry)

    async def start(self):
        """Do nothing; this override performs no start-up work."""

    async def stop(self):
        """Do nothing; this override performs no shutdown work."""

    async def _fetch_from_worker(
        self,
        worker_run_data: WorkerRunData,
        endpoint: str,
        method: str = "GET",
        json: dict = None,
        params: dict = None,
        additional_headers: dict = None,
        success_handler: Callable = None,
        error_handler: Callable = None,
    ) -> Any:
        """Send one HTTP request to the worker behind ``worker_run_data``.

        Args:
            worker_run_data: Supplies ``worker.worker_addr``, ``worker.headers``
                and ``worker.timeout`` for the request.
            endpoint: Path appended to the worker address.
            method: HTTP method.
            json: Request body, sent as JSON.
            params: Query parameters.
            additional_headers: Headers merged over the worker's own headers.
            success_handler: Called with the response when the status is 200;
                its result is returned.
            error_handler: Called with the response when the status is not
                200; its result is returned.

        Returns:
            The handler's result, or the decoded JSON body when the status is
            200 and no ``success_handler`` is given.

        Raises:
            Exception: If the status is not 200 and no ``error_handler`` is
                given.
        """
        remote = worker_run_data.worker
        url = remote.worker_addr + endpoint
        headers = {**remote.headers, **(additional_headers or {})}
        timeout = remote.timeout

        async with httpx.AsyncClient() as client:
            response = await client.send(
                client.build_request(
                    method,
                    url,
                    json=json,  # json= makes httpx send application/json
                    params=params,
                    headers=headers,
                    timeout=timeout,
                )
            )
            if response.status_code == 200:
                if success_handler:
                    return success_handler(response)
                return response.json()
            if error_handler:
                return error_handler(response)
            error_msg = f"Request to {url} failed, error: {response.text}"
            raise Exception(error_msg)

    async def _apply_to_worker_manager_instances(self):
        """Do nothing; placeholder with no implementation."""

    async def supported_models(self) -> List[WorkerSupportedModel]:
        """Collect the supported models reported by every worker-manager instance."""
        manager_instances = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )

        def _parse(response):
            return [WorkerSupportedModel.from_dict(item) for item in response.json()]

        async def _query(run_data) -> List[WorkerSupportedModel]:
            return await self._fetch_from_worker(
                run_data, "/models/supports", success_handler=_parse
            )

        # Query all instances concurrently, then flatten in instance order.
        per_instance = await asyncio.gather(
            *(_query(run_data) for run_data in manager_instances)
        )
        return [model for chunk in per_instance for model in chunk]

    async def _get_worker_service_instance(
        self, host: str = None, port: int = None
    ) -> List[WorkerRunData]:
        """Return worker-manager instances, filtered by host/port when both are given.

        Raises:
            Exception: If no instance remains after the lookup/filter.
        """
        candidates = await self.get_model_instances(
            WORKER_MANAGER_SERVICE_TYPE, WORKER_MANAGER_SERVICE_NAME
        )
        if host and port:
            candidates = [
                run_data
                for run_data in candidates
                if run_data.host == host and run_data.port == port
            ]
            error_msg = f"Cound not found worker instances for host {host} port {port}"
        else:
            error_msg = "Cound not found worker instances"
        if not candidates:
            raise Exception(error_msg)
        return candidates

    async def model_startup(self, startup_req: WorkerStartupRequest) -> bool:
        """POST the startup request to the first matching worker-manager instance."""
        matches = await self._get_worker_service_instance(
            startup_req.host, startup_req.port
        )
        target = matches[0]
        logger.info(f"Start model remote, startup_req: {startup_req}")

        def _accepted(_response):
            return True

        return await self._fetch_from_worker(
            target,
            "/models/startup",
            method="POST",
            json=startup_req.dict(),
            success_handler=_accepted,
        )

    async def model_shutdown(self, shutdown_req: WorkerStartupRequest) -> bool:
        """POST the shutdown request to the first matching worker-manager instance."""
        matches = await self._get_worker_service_instance(
            shutdown_req.host, shutdown_req.port
        )
        target = matches[0]
        logger.info(f"Shutdown model remote, shutdown_req: {shutdown_req}")

        def _accepted(_response):
            return True

        return await self._fetch_from_worker(
            target,
            "/models/shutdown",
            method="POST",
            json=shutdown_req.dict(),
            success_handler=_accepted,
        )

    async def get_model_instances(
        self, worker_type: str, model_name: str, healthy_only: bool = True
    ) -> List[WorkerRunData]:
        """Build a ``WorkerRunData`` wrapper for each registry instance of a model.

        Each registry entry gets its own ``RemoteModelWorker`` pointed at the
        entry's host and port.
        """
        from pilot.model.cluster.worker.remote_worker import RemoteModelWorker

        worker_key = self._worker_key(worker_type, model_name)
        instances: List[ModelInstance] = await self.model_registry.get_all_instances(
            worker_key, healthy_only
        )

        def _wrap(ins: ModelInstance) -> WorkerRunData:
            worker = RemoteModelWorker()
            worker.load_worker(model_name, model_name, host=ins.host, port=ins.port)
            return WorkerRunData(
                host=ins.host,
                port=ins.port,
                worker_key=ins.model_name,
                worker=worker,
                worker_params=None,
                model_params=None,
                stop_event=asyncio.Event(),
                semaphore=asyncio.Semaphore(100),  # Not limit in client
            )

        return [_wrap(ins) for ins in instances]

    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
        """Forward an apply request to remote workers via ``self._apply_worker``.

        Returns:
            The first result, or ``None`` when ``_apply_worker`` yields nothing.
        """

        def _to_output(res):
            return WorkerApplyOutput(**res.json())

        def _to_failure(res):
            return WorkerApplyOutput(message=res.text, success=False)

        async def _remote_apply_func(worker_run_data: WorkerRunData):
            return await self._fetch_from_worker(
                worker_run_data,
                "/apply",
                method="POST",
                json=apply_req.dict(),
                success_handler=_to_output,
                error_handler=_to_failure,
            )

        results = await self._apply_worker(apply_req, _remote_apply_func)
        return results[0] if results else None
|
@@ -3,7 +3,7 @@ from typing import Dict, Iterator, List
|
|||||||
import logging
|
import logging
|
||||||
from pilot.model.base import ModelOutput
|
from pilot.model.base import ModelOutput
|
||||||
from pilot.model.parameter import ModelParameters
|
from pilot.model.parameter import ModelParameters
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.cluster.worker_base import ModelWorker
|
||||||
|
|
||||||
|
|
||||||
class RemoteModelWorker(ModelWorker):
|
class RemoteModelWorker(ModelWorker):
|
@@ -2,14 +2,18 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Dict
|
||||||
|
import cachetools
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.model.base import Message
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
|
||||||
|
from pilot.model.base import Message, SupportedModel
|
||||||
|
from pilot.utils.parameter_utils import _get_parameter_descriptions
|
||||||
|
|
||||||
|
|
||||||
def create_chat_completion(
|
def create_chat_completion(
|
||||||
@@ -128,3 +132,49 @@ def is_partial_stop(output: str, stop_str: str):
|
|||||||
if stop_str.startswith(output[-i:]):
|
if stop_str.startswith(output[-i:]):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=60 * 5))
def list_supported_models():
    """List the LLM entries followed by the embedding entries.

    The result is cached by the decorator's TTL cache (ttl of 60 * 5).
    """
    from pilot.model.parameter import WorkerType

    supported = _list_supported_models(WorkerType.LLM.value, LLM_MODEL_CONFIG)
    supported.extend(
        _list_supported_models(WorkerType.TEXT2VEC.value, EMBEDDING_MODEL_CONFIG)
    )
    return supported
|
||||||
|
|
||||||
|
|
||||||
|
def _list_supported_models(
    worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
    """Describe every model in ``model_config`` as a ``SupportedModel``.

    Args:
        worker_type: Value stored in each entry's ``worker_type``.
        model_config: Mapping of model name to configured model path.

    Returns:
        One entry per configured model, in mapping order. ``enabled`` and
        ``params`` are only filled in when an adapter can be resolved.
    """
    from pilot.model.adapter import get_llm_model_adapter
    from pilot.model.parameter import ModelParameters
    from pilot.model.loader import _get_model_real_path

    supported = []
    for name, configured_path in model_config.items():
        real_path = _get_model_real_path(name, configured_path)
        entry = SupportedModel(
            model=name,
            path=real_path,
            worker_type=worker_type,
            path_exist=False,
            proxy=False,
            enabled=False,
            params=None,
        )
        # Proxy models are flagged by name and their path is not checked.
        if "proxyllm" not in name:
            entry.path_exist = Path(real_path).exists()
        else:
            entry.proxy = True
        try:
            adapter = get_llm_model_adapter(name, real_path)
            param_cls = adapter.model_param_class()
            entry.enabled = True
            entry.params = _get_parameter_descriptions(param_cls)
        except Exception:
            # Best effort: any failure leaves the entry as populated so far.
            pass
        supported.append(entry)
    return supported
|
||||||
|
@@ -353,5 +353,6 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
|
|||||||
def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
|
def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
|
||||||
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||||
|
|
||||||
|
logger.info("Load proxyllm")
|
||||||
model = ProxyModel(model_params)
|
model = ProxyModel(model_params)
|
||||||
return model, model
|
return model, model
|
||||||
|
@@ -33,9 +33,13 @@ class ModelControllerParameters(BaseParameters):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelWorkerParameters(BaseParameters):
|
class BaseModelParameters(BaseParameters):
|
||||||
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
||||||
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelWorkerParameters(BaseModelParameters):
|
||||||
worker_type: Optional[str] = field(
|
worker_type: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
|
metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
|
||||||
@@ -84,9 +88,7 @@ class ModelWorkerParameters(BaseParameters):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingModelParameters(BaseParameters):
|
class EmbeddingModelParameters(BaseModelParameters):
|
||||||
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
|
||||||
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
|
||||||
device: Optional[str] = field(
|
device: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -114,12 +116,6 @@ class EmbeddingModelParameters(BaseParameters):
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BaseModelParameters(BaseParameters):
|
|
||||||
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
|
|
||||||
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelParameters(BaseModelParameters):
|
class ModelParameters(BaseModelParameters):
|
||||||
device: Optional[str] = field(
|
device: Optional[str] = field(
|
||||||
|
@@ -139,7 +139,7 @@ class BaseChat(ABC):
|
|||||||
logger.info(f"Requert: \n{payload}")
|
logger.info(f"Requert: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
from pilot.model.worker.manager import worker_manager
|
from pilot.model.cluster import worker_manager
|
||||||
|
|
||||||
async for output in worker_manager.generate_stream(payload):
|
async for output in worker_manager.generate_stream(payload):
|
||||||
yield output
|
yield output
|
||||||
@@ -157,7 +157,7 @@ class BaseChat(ABC):
|
|||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
ai_response_text = ""
|
||||||
try:
|
try:
|
||||||
from pilot.model.worker.manager import worker_manager
|
from pilot.model.cluster import worker_manager
|
||||||
|
|
||||||
model_output = await worker_manager.generate(payload)
|
model_output = await worker_manager.generate(payload)
|
||||||
|
|
||||||
|
@@ -6,7 +6,7 @@ from pilot.configs.config import Config
|
|||||||
|
|
||||||
from pilot.configs.model_config import (
|
from pilot.configs.model_config import (
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
LLM_MODEL_CONFIG,
|
EMBEDDING_MODEL_CONFIG,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||||
@@ -48,7 +48,7 @@ class ChatKnowledge(BaseChat):
|
|||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
self.knowledge_embedding_client = EmbeddingEngine(
|
self.knowledge_embedding_client = EmbeddingEngine(
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -23,7 +23,7 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
|
|||||||
from pilot.openapi.base import validation_exception_handler
|
from pilot.openapi.base import validation_exception_handler
|
||||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||||
from pilot.model.worker.manager import initialize_worker_manager_in_client
|
from pilot.model.cluster import initialize_worker_manager_in_client
|
||||||
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
||||||
|
|
||||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||||
|
@@ -7,7 +7,10 @@ from fastapi import APIRouter, File, UploadFile, Form
|
|||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import (
|
||||||
|
EMBEDDING_MODEL_CONFIG,
|
||||||
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
|
)
|
||||||
|
|
||||||
from pilot.openapi.api_view_model import Result
|
from pilot.openapi.api_view_model import Result
|
||||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||||
@@ -29,7 +32,9 @@ CFG = Config()
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL])
|
embeddings = HuggingFaceEmbeddings(
|
||||||
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||||
|
)
|
||||||
|
|
||||||
knowledge_space_service = KnowledgeService()
|
knowledge_space_service = KnowledgeService()
|
||||||
|
|
||||||
|
@@ -5,7 +5,10 @@ from datetime import datetime
|
|||||||
from pilot.vector_store.connector import VectorStoreConnector
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import (
|
||||||
|
EMBEDDING_MODEL_CONFIG,
|
||||||
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
|
)
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.server.knowledge.chunk_db import (
|
from pilot.server.knowledge.chunk_db import (
|
||||||
DocumentChunkEntity,
|
DocumentChunkEntity,
|
||||||
@@ -204,7 +207,7 @@ class KnowledgeService:
|
|||||||
client = EmbeddingEngine(
|
client = EmbeddingEngine(
|
||||||
knowledge_source=doc.content,
|
knowledge_source=doc.content,
|
||||||
knowledge_type=doc.doc_type.upper(),
|
knowledge_type=doc.doc_type.upper(),
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config={
|
vector_store_config={
|
||||||
"vector_store_name": space_name,
|
"vector_store_name": space_name,
|
||||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||||
@@ -341,7 +344,7 @@ class KnowledgeService:
|
|||||||
"topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
"topk": CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||||
"recall_score": 0.0,
|
"recall_score": 0.0,
|
||||||
"recall_type": "TopK",
|
"recall_type": "TopK",
|
||||||
"model": LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
|
"model": EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL].rsplit("/", 1)[-1],
|
||||||
"chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
|
"chunk_size": CFG.KNOWLEDGE_CHUNK_SIZE,
|
||||||
"chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
|
"chunk_overlap": CFG.KNOWLEDGE_CHUNK_OVERLAP,
|
||||||
},
|
},
|
||||||
|
@@ -9,7 +9,7 @@ sys.path.append(ROOT_PATH)
|
|||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||||
from pilot.model.worker.manager import run_worker_manager
|
from pilot.model.cluster import run_worker_manager
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ from pilot.common.schema import DBType
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import (
|
from pilot.configs.model_config import (
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
LLM_MODEL_CONFIG,
|
EMBEDDING_MODEL_CONFIG,
|
||||||
LOGDIR,
|
LOGDIR,
|
||||||
)
|
)
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
@@ -36,7 +36,7 @@ class DBSummaryClient:
|
|||||||
|
|
||||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||||
embeddings = HuggingFaceEmbeddings(
|
embeddings = HuggingFaceEmbeddings(
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||||
)
|
)
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
"vector_store_name": dbname + "_summary",
|
"vector_store_name": dbname + "_summary",
|
||||||
@@ -90,7 +90,7 @@ class DBSummaryClient:
|
|||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
knowledge_embedding_client = EmbeddingEngine(
|
knowledge_embedding_client = EmbeddingEngine(
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
table_docs = knowledge_embedding_client.similar_search(query, topk)
|
table_docs = knowledge_embedding_client.similar_search(query, topk)
|
||||||
@@ -108,7 +108,7 @@ class DBSummaryClient:
|
|||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
knowledge_embedding_client = EmbeddingEngine(
|
knowledge_embedding_client = EmbeddingEngine(
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
if CFG.SUMMARY_CONFIG == "FAST":
|
if CFG.SUMMARY_CONFIG == "FAST":
|
||||||
@@ -134,7 +134,7 @@ class DBSummaryClient:
|
|||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
}
|
}
|
||||||
knowledge_embedding_client = EmbeddingEngine(
|
knowledge_embedding_client = EmbeddingEngine(
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=vector_store_config,
|
||||||
)
|
)
|
||||||
table_summery = knowledge_embedding_client.similar_search(query, 1)
|
table_summery = knowledge_embedding_client.similar_search(query, 1)
|
||||||
|
@@ -4,6 +4,7 @@ import subprocess
|
|||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
import psutil
|
import psutil
|
||||||
import platform
|
import platform
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
def _get_abspath_of_current_command(command_path: str):
|
def _get_abspath_of_current_command(command_path: str):
|
||||||
@@ -137,6 +138,7 @@ def _get_ports_by_cmdline_part(service_keys: List[str]) -> List[int]:
|
|||||||
return ports
|
return ports
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def _detect_controller_address() -> str:
|
def _detect_controller_address() -> str:
|
||||||
controller_addr = os.getenv("CONTROLLER_ADDRESS")
|
controller_addr = os.getenv("CONTROLLER_ADDRESS")
|
||||||
if controller_addr:
|
if controller_addr:
|
||||||
|
@@ -2,7 +2,7 @@ from typing import Type
|
|||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
|
|
||||||
def import_from_string(module_path: str):
|
def import_from_string(module_path: str, ignore_import_error: bool = False):
|
||||||
try:
|
try:
|
||||||
module_path, class_name = module_path.rsplit(".", 1)
|
module_path, class_name = module_path.rsplit(".", 1)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -12,6 +12,8 @@ def import_from_string(module_path: str):
|
|||||||
try:
|
try:
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
if ignore_import_error:
|
||||||
|
return None
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
f'Module "{module_path}" does not define a "{class_name}" attribute/class'
|
||||||
)
|
)
|
||||||
|
@@ -1,21 +1,50 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, fields, MISSING
|
from dataclasses import dataclass, fields, MISSING, asdict, field
|
||||||
from typing import Any, List, Optional, Type, Union, Callable, Dict
|
from typing import Any, List, Optional, Type, Union, Callable, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParameterDescription:
|
class ParameterDescription:
|
||||||
|
param_class: str
|
||||||
param_name: str
|
param_name: str
|
||||||
param_type: str
|
param_type: str
|
||||||
description: str
|
|
||||||
default_value: Optional[Any]
|
default_value: Optional[Any]
|
||||||
|
description: str
|
||||||
valid_values: Optional[List[Any]]
|
valid_values: Optional[List[Any]]
|
||||||
|
ext_metadata: Dict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseParameters:
|
class BaseParameters:
|
||||||
|
@classmethod
def from_dict(
    cls, data: dict, ignore_extra_fields: bool = False
) -> "BaseParameters":
    """Create an instance of the dataclass from a dictionary.

    Args:
        data: A dictionary containing values for the dataclass fields.
        ignore_extra_fields: If True, any extra fields in the data dictionary that are
            not part of the dataclass will be ignored.
            If False, extra fields will raise an error. Defaults to False.
    Returns:
        An instance of the dataclass with values populated from the given dictionary.

    Raises:
        TypeError: If `ignore_extra_fields` is False and there are fields in the
            dictionary that aren't present in the dataclass.
    """
    all_field_names = {f.name for f in fields(cls)}
    if ignore_extra_fields:
        data = {key: value for key, value in data.items() if key in all_field_names}
    else:
        # Stringify and sort the unknown keys so the error message is
        # deterministic and non-string keys are reported instead of making
        # str.join fail with an unrelated error.
        extra_fields = sorted(str(key) for key in data if key not in all_field_names)
        if extra_fields:
            raise TypeError(f"Unexpected fields: {', '.join(extra_fields)}")
    return cls(**data)
|
||||||
|
|
||||||
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
def update_from(self, source: Union["BaseParameters", dict]) -> bool:
|
||||||
"""
|
"""
|
||||||
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
Update the attributes of this object using the values from another object (of the same or parent type) or a dictionary.
|
||||||
@@ -68,6 +97,35 @@ class BaseParameters:
|
|||||||
)
|
)
|
||||||
return "\n".join(parameters)
|
return "\n".join(parameters)
|
||||||
|
|
||||||
|
def to_command_args(self, args_prefix: str = "--") -> List[str]:
    """Render this dataclass's fields as command line arguments.

    Args:
        args_prefix: Text placed in front of every field name.
    Returns:
        The list produced by ``_dict_to_command_args`` for ``asdict(self)``:
        a prefixed field name followed by its value, per field.
    """
    field_values = asdict(self)
    return _dict_to_command_args(field_values, args_prefix=args_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
|
||||||
|
"""Convert dict to a list of command line arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: dict
|
||||||
|
Returns:
|
||||||
|
A list of strings where each field is represented by two items:
|
||||||
|
one for the field name prefixed by args_prefix, and one for its value.
|
||||||
|
"""
|
||||||
|
args = []
|
||||||
|
for key, value in obj.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
args.append(f"{args_prefix}{key}")
|
||||||
|
args.append(str(value))
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
def _get_simple_privacy_field_value(obj, field_info):
|
def _get_simple_privacy_field_value(obj, field_info):
|
||||||
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
|
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
|
||||||
@@ -81,9 +139,9 @@ def _get_simple_privacy_field_value(obj, field_info):
|
|||||||
- str: if length > 5, masks the middle part and returns first and last char;
|
- str: if length > 5, masks the middle part and returns first and last char;
|
||||||
otherwise, returns "******"
|
otherwise, returns "******"
|
||||||
|
|
||||||
Parameters:
|
Args:
|
||||||
- obj: The dataclass instance.
|
obj: The dataclass instance.
|
||||||
- field_info: A Field object that contains information about the dataclass field.
|
field_info: A Field object that contains information about the dataclass field.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The original or modified value of the field based on the privacy rules.
|
The original or modified value of the field based on the privacy rules.
|
||||||
@@ -203,25 +261,9 @@ class EnvArgumentParser:
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_click_option(
|
def _create_click_option_from_field(field_name: str, field: Type, is_func=True):
|
||||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
|
||||||
):
|
|
||||||
import click
|
import click
|
||||||
import functools
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
combined_fields = OrderedDict()
|
|
||||||
if _dynamic_factory:
|
|
||||||
_types = _dynamic_factory()
|
|
||||||
if _types:
|
|
||||||
dataclass_types = list(_types)
|
|
||||||
for dataclass_type in dataclass_types:
|
|
||||||
for field in fields(dataclass_type):
|
|
||||||
if field.name not in combined_fields:
|
|
||||||
combined_fields[field.name] = field
|
|
||||||
|
|
||||||
def decorator(func):
|
|
||||||
for field_name, field in reversed(combined_fields.items()):
|
|
||||||
help_text = field.metadata.get("help", "")
|
help_text = field.metadata.get("help", "")
|
||||||
valid_values = field.metadata.get("valid_values", None)
|
valid_values = field.metadata.get("valid_values", None)
|
||||||
cli_params = {
|
cli_params = {
|
||||||
@@ -241,12 +283,37 @@ class EnvArgumentParser:
|
|||||||
cli_params["type"] = click.STRING
|
cli_params["type"] = click.STRING
|
||||||
elif real_type is bool:
|
elif real_type is bool:
|
||||||
cli_params["is_flag"] = True
|
cli_params["is_flag"] = True
|
||||||
|
name = f"--{field_name}"
|
||||||
option_decorator = click.option(
|
if is_func:
|
||||||
# f"--{field_name.replace('_', '-')}", **cli_params
|
return click.option(
|
||||||
f"--{field_name}",
|
name,
|
||||||
**cli_params,
|
**cli_params,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
return click.Option([name], **cli_params)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_click_option(
|
||||||
|
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||||
|
):
|
||||||
|
import functools
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
combined_fields = OrderedDict()
|
||||||
|
if _dynamic_factory:
|
||||||
|
_types = _dynamic_factory()
|
||||||
|
if _types:
|
||||||
|
dataclass_types = list(_types)
|
||||||
|
for dataclass_type in dataclass_types:
|
||||||
|
for field in fields(dataclass_type):
|
||||||
|
if field.name not in combined_fields:
|
||||||
|
combined_fields[field.name] = field
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
for field_name, field in reversed(combined_fields.items()):
|
||||||
|
option_decorator = EnvArgumentParser._create_click_option_from_field(
|
||||||
|
field_name, field
|
||||||
|
)
|
||||||
func = option_decorator(func)
|
func = option_decorator(func)
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
@@ -257,6 +324,23 @@ class EnvArgumentParser:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
@staticmethod
def _create_raw_click_option(
    *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
):
    """Build one option object per field of the merged dataclass types.

    The fields are collected with ``_merge_dataclass_types`` and walked in
    reverse order; each one is converted by
    ``_create_click_option_from_field`` with ``is_func=False`` (the branch
    that returns a ``click.Option`` instead of a decorator).

    Args:
        *dataclass_types: Dataclass types whose fields become options.
        _dynamic_factory: Optional callable forwarded unchanged to
            ``_merge_dataclass_types``.

    Returns:
        A list of option objects, in reverse field order.
    """
    merged_fields = _merge_dataclass_types(
        *dataclass_types, _dynamic_factory=_dynamic_factory
    )
    return [
        EnvArgumentParser._create_click_option_from_field(
            name, field_spec, is_func=False
        )
        for name, field_spec in reversed(merged_fields.items())
    ]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_argparse_option(
|
def create_argparse_option(
|
||||||
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
|
||||||
@@ -366,21 +450,70 @@ def _merge_dataclass_types(
|
|||||||
return combined_fields
|
return combined_fields
|
||||||
|
|
||||||
|
|
||||||
|
def _type_str_to_python_type(type_str: str) -> Type:
|
||||||
|
type_mapping: Dict[str, Type] = {
|
||||||
|
"int": int,
|
||||||
|
"float": float,
|
||||||
|
"bool": bool,
|
||||||
|
"str": str,
|
||||||
|
}
|
||||||
|
return type_mapping.get(type_str, str)
|
||||||
|
|
||||||
|
|
||||||
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
    """Describe every field of a dataclass as a ``ParameterDescription``.

    Args:
        dataclass_type: The dataclass type whose fields are inspected.

    Returns:
        One ``ParameterDescription`` per field, in field order. Each carries
        the fully qualified class path, the field name, the type string from
        ``EnvArgumentParser._get_argparse_type_str``, the ``help`` and
        ``valid_values`` metadata entries, the field default (``None`` when
        the field declares no plain default) and all remaining metadata
        entries as ``ext_metadata``.
    """
    # The class path is the same for every field, so compute it once.
    param_class = f"{dataclass_type.__module__}.{dataclass_type.__name__}"
    descriptions = []
    for field in fields(dataclass_type):
        # Everything except the two well-known keys is passed through as-is.
        ext_metadata = {
            k: v for k, v in field.metadata.items() if k not in ("help", "valid_values")
        }

        descriptions.append(
            ParameterDescription(
                param_class=param_class,
                param_name=field.name,
                param_type=EnvArgumentParser._get_argparse_type_str(field.type),
                description=field.metadata.get("help", None),
                # MISSING is a sentinel: test identity, not equality, so the
                # default value's own __ne__ is never invoked.
                default_value=field.default if field.default is not MISSING else None,
                valid_values=field.metadata.get("valid_values", None),
                ext_metadata=ext_metadata,
            )
        )
    return descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
    """Resolve, or synthesize, the parameter class named by ``desc``.

    The class path stored in ``desc[0].param_class`` is first passed to
    ``import_from_string`` (with ``ignore_import_error=True``); a truthy
    result is returned unchanged. Otherwise a new dataclass is generated,
    named after the last dotted component of the class path, with one field
    per description.

    Args:
        desc: Non-empty list of parameter descriptions. Only the first
            element's ``param_class`` is consulted for the class path.

    Returns:
        The imported class, or a freshly generated dataclass whose fields
        use ``default_value`` as default and carry ``help``,
        ``valid_values`` and any ``ext_metadata`` entries as field metadata.

    Raises:
        ValueError: If ``desc`` is empty.
    """
    from pilot.utils.module_utils import import_from_string

    if not desc:
        raise ValueError("Parameter descriptions cant be empty")
    param_class_str = desc[0].param_class
    param_class = import_from_string(param_class_str, ignore_import_error=True)
    if param_class:
        return param_class
    # Only the trailing class name is needed for the generated type.
    class_name = param_class_str.rpartition(".")[2]

    fields_dict = {}  # field name -> dataclasses.field(...) carrying default and metadata
    annotations = {}  # field name -> Python type, required for dataclass() to see the field

    for d in desc:
        # Work on a copy: writing "help"/"valid_values" straight into
        # d.ext_metadata would silently modify the caller's description.
        metadata = dict(d.ext_metadata) if d.ext_metadata else {}
        metadata["help"] = d.description
        metadata["valid_values"] = d.valid_values

        annotations[d.param_name] = _type_str_to_python_type(d.param_type)
        fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata)

    # Create the plain class first; __annotations__ supplies the type hints
    # that the dataclass decorator uses to discover fields.
    new_class = type(
        class_name, (object,), {**fields_dict, "__annotations__": annotations}
    )
    return dataclass(new_class)
|
||||||
|
|
||||||
|
|
||||||
class _SimpleArgParser:
|
class _SimpleArgParser:
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
self.params = {arg.replace("_", "-"): None for arg in args}
|
self.params = {arg.replace("_", "-"): None for arg in args}
|
||||||
@@ -422,3 +555,24 @@ class _SimpleArgParser:
|
|||||||
return "\n".join(
|
return "\n".join(
|
||||||
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
|
[f'{key.replace("-", "_")}: {value}' for key, value in self.params.items()]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_lazy_click_command(*dataclass_types: Type, _dynamic_factory=None):
    """Create a ``click.Command`` subclass that adds its options on demand.

    The returned class does not compute the dataclass-derived options at
    construction time. They are generated by
    ``EnvArgumentParser._create_raw_click_option`` and appended to
    ``self.params`` the first time ``get_params`` is called with a truthy
    context; later calls reuse the already extended parameter list.

    Args:
        *dataclass_types: Dataclass types whose fields become options.
        _dynamic_factory: Optional callable forwarded unchanged to
            ``EnvArgumentParser._create_raw_click_option``.

    Returns:
        The ``LazyCommand`` class (not an instance).
    """
    import click

    class LazyCommand(click.Command):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Flipped to True once the dynamic options have been appended.
            self.dynamic_params_added = False

        def get_params(self, ctx):
            # Nothing to add without a context or after the first expansion.
            if not ctx or self.dynamic_params_added:
                return super().get_params(ctx)

            extra_options = EnvArgumentParser._create_raw_click_option(
                *dataclass_types, _dynamic_factory=_dynamic_factory
            )
            # The raw options arrive in reverse field order; reversing again
            # appends them in the order the fields were declared.
            self.params.extend(reversed(extra_options))
            self.dynamic_params_added = True
            return super().get_params(ctx)

    return LazyCommand
|
||||||
|
@@ -29,10 +29,10 @@ def knownledge_tovec_st(filename):
|
|||||||
"""Use sentence transformers to embedding the document.
|
"""Use sentence transformers to embedding the document.
|
||||||
https://github.com/UKPLab/sentence-transformers
|
https://github.com/UKPLab/sentence-transformers
|
||||||
"""
|
"""
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||||
|
|
||||||
embeddings = HuggingFaceEmbeddings(
|
embeddings = HuggingFaceEmbeddings(
|
||||||
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
|
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(filename, "r") as f:
|
with open(filename, "r") as f:
|
||||||
@@ -57,10 +57,10 @@ def load_knownledge_from_doc():
|
|||||||
"Not Exists Local DataSets, We will answers the Question use model default."
|
"Not Exists Local DataSets, We will answers the Question use model default."
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||||
|
|
||||||
embeddings = HuggingFaceEmbeddings(
|
embeddings = HuggingFaceEmbeddings(
|
||||||
model_name=LLM_MODEL_CONFIG["sentence-transforms"]
|
model_name=EMBEDDING_MODEL_CONFIG["sentence-transforms"]
|
||||||
)
|
)
|
||||||
|
|
||||||
files = os.listdir(DATASETS_DIR)
|
files = os.listdir(DATASETS_DIR)
|
||||||
|
@@ -16,7 +16,7 @@ from langchain.vectorstores import Chroma
|
|||||||
|
|
||||||
from pilot.configs.model_config import (
|
from pilot.configs.model_config import (
|
||||||
DATASETS_DIR,
|
DATASETS_DIR,
|
||||||
LLM_MODEL_CONFIG,
|
EMBEDDING_MODEL_CONFIG,
|
||||||
VECTORE_PATH,
|
VECTORE_PATH,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class KnownLedge2Vector:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
embeddings: object = None
|
embeddings: object = None
|
||||||
model_name = LLM_MODEL_CONFIG["sentence-transforms"]
|
model_name = EMBEDDING_MODEL_CONFIG["sentence-transforms"]
|
||||||
|
|
||||||
def __init__(self, model_name=None) -> None:
|
def __init__(self, model_name=None) -> None:
|
||||||
if not model_name:
|
if not model_name:
|
||||||
|
@@ -80,3 +80,4 @@ duckdb-engine
|
|||||||
|
|
||||||
# cli
|
# cli
|
||||||
prettytable
|
prettytable
|
||||||
|
cachetools
|
Reference in New Issue
Block a user