feat: Multi-model command line

FangYin Cheng 2023-08-30 11:07:35 +08:00
parent d467092766
commit dd86fb86b1
20 changed files with 317 additions and 35 deletions

pilot/model/cli.py Normal file
View File

@@ -0,0 +1,151 @@
import click
import functools
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
RemoteWorkerManager,
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.utils import get_or_create_event_loop
@click.group("model")
def model_cli_group():
pass
@model_cli_group.command()
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to. "
"Only the light deployment mode is supported."
),
)
@click.option(
"--model-name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
"--model-type", type=str, default="llm", required=False, help=("The type of model")
)
def list(address: str, model_name: str, model_type: str):
"""List model instances"""
from prettytable import PrettyTable
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
if not model_name:
instances = loop.run_until_complete(registry.get_all_model_instances())
else:
if not model_type:
model_type = "llm"
register_model_name = f"{model_name}@{model_type}"
instances = loop.run_until_complete(
registry.get_all_instances(register_model_name)
)
table = PrettyTable()
table.field_names = [
"Model Name",
"Model Type",
"Host",
"Port",
"Healthy",
"Enabled",
"Prompt Template",
"Last Heartbeat",
]
for instance in instances:
model_name, model_type = instance.model_name.split("@")
table.add_row(
[
model_name,
model_type,
instance.host,
instance.port,
instance.healthy,
instance.enabled,
instance.prompt_template,
instance.last_heartbeat,
]
)
print(table)
def add_model_options(func):
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to. "
"Only the light deployment mode is supported."
),
)
@click.option(
"--model-name",
type=str,
default=None,
required=True,
help=("The name of model"),
)
@click.option(
"--model-type",
type=str,
default="llm",
required=False,
help=("The type of model"),
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@model_cli_group.command()
@add_model_options
def stop(address: str, model_name: str, model_type: str):
"""Stop model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
@model_cli_group.command()
@add_model_options
def start(address: str, model_name: str, model_type: str):
"""Start model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.START)
@model_cli_group.command()
@add_model_options
def restart(address: str, model_name: str, model_type: str):
"""Restart model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
# """Restart model instances"""
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
apply_req = WorkerApplyRequest(
model=model_name, worker_type=model_type, apply_type=apply_type
)
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
print(res)
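
With the console script defined later in this commit, the new group can be exercised roughly as follows (a usage sketch; "vicuna-13b" is a placeholder model name and the Model Controller is assumed to be running on its default address):

    dbgpt model list
    dbgpt model list --model-name vicuna-13b --model-type llm
    dbgpt model stop --model-name vicuna-13b
    dbgpt model restart --model-name vicuna-13b --address http://127.0.0.1:8000

Note that --model-name is required for stop/start/restart but optional for list.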

View File

@@ -27,6 +27,9 @@ class ModelController:
)
return await self.registry.get_all_instances(model_name, healthy_only)
async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.registry.get_all_model_instances()
async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.registry.send_heartbeat(instance)
@@ -51,10 +54,12 @@ async def api_deregister_instance(request: ModelInstance):
@router.get("/controller/models")
async def api_get_all_instances(model_name: str, healthy_only: bool = False):
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
if not model_name:
return await controller.get_all_model_instances()
return await controller.get_all_instances(model_name, healthy_only=healthy_only)
@router.post("/controller/heartbeat")
async def api_get_all_instances(request: ModelInstance):
async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)
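
The reworked GET endpoint stays backward compatible: omitting model_name now returns every registered instance. A minimal sketch of calling it with httpx (the default address comes from the CLI above, and the /api prefix is assumed from the registry client path later in this commit):

    import asyncio
    import httpx

    async def main():
        async with httpx.AsyncClient() as client:
            # No model_name query parameter: the new branch returns all instances.
            resp = await client.get("http://127.0.0.1:8000/api/controller/models")
            print(resp.json())

    asyncio.run(main())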

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import itertools
from pilot.model.base import ModelInstance
@@ -57,6 +58,15 @@ class ModelRegistry(ABC):
- List[ModelInstance]: A list of instances for the given model.
"""
@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
"""
Fetch all instances of all models
Returns:
- List[ModelInstance]: A list of instances for the all models.
"""
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
"""
Selects one healthy and enabled instance for a given model.
@@ -154,12 +164,15 @@ class EmbeddedModelRegistry(ModelRegistry):
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
print(self.registry)
instances = self.registry[model_name]
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
return instances
async def get_all_model_instances(self) -> List[ModelInstance]:
print(self.registry)
return list(itertools.chain(*self.registry.values()))
async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(
instance.model_name, instance.host, instance.port, healthy_only=False
@@ -194,6 +207,10 @@ class ModelRegistryClient(ModelRegistry):
) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/models")
async def get_all_model_instances(self) -> List[ModelInstance]:
pass
@api_remote(path="/api/controller/models")
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
instances = await self.get_all_instances(model_name, healthy_only=True)
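
The embedded registry keeps a dict mapping model name to a list of instances, so get_all_model_instances is just a flatten. The itertools.chain idiom in miniature (hypothetical registry contents):

    import itertools

    registry = {
        "vicuna-13b@llm": ["instance-1", "instance-2"],  # placeholder entries
        "text2vec@text2vec": ["instance-3"],
    }
    # chain(*lists) concatenates the per-model lists into one flat sequence
    flat = list(itertools.chain(*registry.values()))
    assert flat == ["instance-1", "instance-2", "instance-3"]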

View File

@@ -118,7 +118,6 @@ class ModelLoader:
def loader_with_params(self, model_params: ModelParameters):
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = llm_adapter.model_type()
param_cls = llm_adapter.model_param_class(model_type)
self.prompt_template = model_params.prompt_template
logger.info(f"model_params:\n{model_params}")
if model_type == ModelType.HF:

View File

@@ -184,6 +184,19 @@ class BaseParameters:
return updated
def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = getattr(self, field_info.name)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
@dataclass
class ModelWorkerParameters(BaseParameters):
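
The effect of the new __str__ is that logger.info(f"model_params:\n{model_params}") in the loader now prints one field per line between banner rules, roughly like this (illustrative field values, not taken from a real run):

    =========================== ModelParameters ===========================

    model_name: vicuna-13b
    model_path: /data/models/vicuna-13b
    device: cuda

    ======================================================================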

View File

@@ -4,12 +4,12 @@ from typing import Dict, Iterator, List
import torch
from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
logger = logging.getLogger("model_worker")
@@ -20,6 +20,8 @@ class DefaultModelWorker(ModelWorker):
self.model = None
self.tokenizer = None
self._model_params = None
self.llm_adapter: BaseLLMAdaper = None
self.llm_chat_adapter: BaseChatAdpter = None
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
@@ -28,9 +30,9 @@ class DefaultModelWorker(ModelWorker):
self.model_name = model_name
self.model_path = model_path
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = llm_adapter.model_type()
self.param_cls = llm_adapter.model_param_class(model_type)
self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = self.llm_adapter.model_type()
self.param_cls = self.llm_adapter.model_param_class(model_type)
self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
@@ -50,12 +52,14 @@ class DefaultModelWorker(ModelWorker):
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_type = self.llm_adapter.model_type()
model_params: ModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
model_type=model_type,
)
if not model_params.device:
model_params.device = DEVICE

View File

@@ -1,4 +1,5 @@
import asyncio
import httpx
import itertools
import json
import os
@@ -147,7 +148,7 @@ class LocalWorkerManager(WorkerManager):
self.model_registry = model_registry
def _worker_key(self, worker_type: str, model_name: str) -> str:
return f"$${worker_type}_$$_{model_name}"
return f"{model_name}@{worker_type}"
def add_worker(
self,
@@ -311,7 +312,9 @@ class LocalWorkerManager(WorkerManager):
# Apply to all workers
worker_instances = list(itertools.chain(*self.workers.values()))
logger.info(f"Apply to all workers: {worker_instances}")
await asyncio.gather(*(apply_func(worker) for worker in worker_instances))
return await asyncio.gather(
*(apply_func(worker) for worker in worker_instances)
)
async def _start_all_worker(
self, apply_req: WorkerApplyRequest
@@ -423,6 +426,28 @@ class RemoteWorkerManager(LocalWorkerManager):
worker_instances.append(wr)
return worker_instances
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
async def _remote_apply_func(worker_run_data: WorkerRunData):
worker_addr = worker_run_data.worker.worker_addr
async with httpx.AsyncClient() as client:
response = await client.post(
worker_addr + "/apply",
headers=worker_run_data.worker.headers,
json=apply_req.dict(),
timeout=worker_run_data.worker.timeout,
)
if response.status_code == 200:
output = WorkerApplyOutput(**response.json())
logger.info(f"worker_apply success: {output}")
else:
output = WorkerApplyOutput(message=response.text)
logger.warn(f"worker_apply failed: {output}")
return output
results = await self._apply_worker(apply_req, _remote_apply_func)
if results:
return results[0]
class WorkerManagerAdapter(WorkerManager):
def __init__(self, worker_manager: WorkerManager = None) -> None:
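
On the wire, _remote_apply_func POSTs apply_req.dict() to each worker's /apply endpoint. For a stop request the JSON body presumably looks like the dict below (field names come from WorkerApplyRequest above; the exact serialization of WorkerApplyType is an assumption):

    # Hypothetical payload mirroring apply_req.dict() for a stop request
    payload = {
        "model": "vicuna-13b",   # placeholder model name
        "worker_type": "llm",
        "apply_type": "stop",    # WorkerApplyType.STOP; wire value assumed
    }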

View File

@@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import (
build_logger,
server_error_msg,
)
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
@@ -222,13 +219,10 @@ class BaseChat(ABC):
return self.current_ai_response()
def _blocking_stream_call(self):
import asyncio
logger.warn(
"_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = get_or_create_event_loop()
async_gen = self.stream_call()
while True:
try:
@@ -238,13 +232,10 @@ class BaseChat(ABC):
break
def _blocking_nostream_call(self):
import asyncio
logger.warn(
"_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance"
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.nostream_call())
def call(self):

View File

@@ -0,0 +1,45 @@
import sys
import click
import os
import copy
import logging
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)
@click.group()
@click.option(
"--log-level",
required=False,
type=str,
default="warn",
help="Log level",
)
@click.version_option()
def cli(log_level: str):
# TODO not working now
logging.basicConfig(level=log_level, encoding="utf-8")
def add_command_alias(command, name: str, hidden: bool = False):
new_command = copy.deepcopy(command)
new_command.hidden = hidden
cli.add_command(new_command, name=name)
try:
from pilot.model.cli import model_cli_group
add_command_alias(model_cli_group, name="model")
except ImportError as e:
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
def main():
return cli()
if __name__ == "__main__":
main()
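
The TODO above is plausibly because logging.basicConfig rejects lower-case level names such as "warn". One minimal fix sketch (an assumption, untested against this codebase):

    logging.basicConfig(
        level=getattr(logging, log_level.upper(), logging.WARNING),
        encoding="utf-8",
    )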

View File

@@ -75,14 +75,20 @@ app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1)
def mount_static_files(app):
os.makedirs(static_message_img_path, exist_ok=True)
app.mount(
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
"/images",
StaticFiles(directory=static_message_img_path, html=True),
name="static2",
)
app.mount(
"/_next/static", StaticFiles(directory=static_file_path + "/_next/static")
)
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":

View File

@@ -2,7 +2,6 @@ import logging
import os
import requests
from playsound import playsound
from pilot.speech.base import VoiceBase
@@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase):
Returns:
bool: True if the request was successful, False otherwise
"""
from playsound import playsound
tts_url = (
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
)

View File

@@ -2,7 +2,6 @@
import os
import requests
from playsound import playsound
from pilot.configs.config import Config
from pilot.speech.base import VoiceBase
@@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase):
bool: True if the request was successful, False otherwise
"""
from pilot.logs import logger
from playsound import playsound
tts_url = (
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"

View File

@@ -2,7 +2,6 @@
import os
import gtts
from playsound import playsound
from pilot.speech.base import VoiceBase
@@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase):
def _speech(self, text: str, _: int = 0) -> bool:
"""Play the given text."""
from playsound import playsound
tts = gtts.gTTS(text)
tts.save("speech.mp3")
playsound("speech.mp3", True)

View File

@@ -5,4 +5,5 @@ from .utils import (
disable_torch_init,
pretty_print_semaphore,
server_error_msg,
get_or_create_event_loop,
)

View File

@@ -1,6 +1,7 @@
import httpx
from inspect import signature
import typing_inspect
import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict
@@ -21,7 +22,9 @@ def _api_remote(path, method="GET"):
raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type)
print(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
logging.debug(
f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
)
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
@@ -57,7 +60,9 @@ def _api_remote(path, method="GET"):
else: # For GET, DELETE, etc.
request_params["params"] = request_data
print(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
logging.info(
f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
)
async with httpx.AsyncClient() as client:
response = await client.request(**request_params)
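
For reference, the decorator is consumed as in the registry client hunk earlier in this commit: an async stub whose body is just pass, where the path and the annotated return type tell the wrapper what to request and which dataclass to rehydrate:

    @api_remote(path="/api/controller/models")
    async def get_all_model_instances(self) -> List[ModelInstance]:
        pass  # body never runs; the decorator issues the HTTP call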

View File

@@ -5,8 +5,8 @@ import logging
import logging.handlers
import os
import sys
import asyncio
import torch
from pilot.configs.model_config import LOGDIR
server_error_msg = (
@@ -17,6 +17,8 @@ handler = None
def get_gpu_memory(max_gpus=None):
import torch
gpu_memory = []
num_gpus = (
torch.cuda.device_count()
@@ -130,3 +132,15 @@ def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
def get_or_create_event_loop() -> asyncio.BaseEventLoop:
try:
loop = asyncio.get_event_loop()
except Exception as e:
if not "no running event loop" in str(e):
raise e
logging.warning("Cant not get running event loop, create new event loop now")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
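
This helper is what lets the synchronous entry points above (the CLI commands and the blocking chat calls) drive coroutines, e.g. as in pilot/model/cli.py:

    loop = get_or_create_event_loop()
    instances = loop.run_until_complete(registry.get_all_model_instances())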

View File

@@ -76,4 +76,7 @@ bardapi==0.1.29
# TODO moved to optional dependencies
pymysql
duckdb
duckdb-engine
duckdb-engine
# cli
prettytable

View File

@@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires():
llama_cpp_version = "0.1.77"
py_version = "cp310"
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
extra_index_url, _ = encode_url(extra_index_url)
print(f"Install llama_cpp_python_cuda from {extra_index_url}")
@@ -361,7 +361,7 @@ setuptools.setup(
extras_require=setup_spec.extras,
entry_points={
"console_scripts": [
"dbgpt_server=pilot.server:webserver",
"dbgpt=pilot.scripts.cli_scripts:main",
],
},
)
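
With the new console script in place (the old dbgpt_server entry is dropped in the same hunk), an editable install presumably exposes the combined CLI under a single dbgpt entry point (sketch):

    pip install -e .
    dbgpt --help
    dbgpt model list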

View File

@@ -25,7 +25,6 @@ sys.path.append(
from pilot.configs.model_config import DATASETS_DIR
from tools.cli.knowledge_client import knowledge_init
API_ADDRESS: str = "http://127.0.0.1:5000"
@@ -97,6 +96,8 @@ def knowledge(
verbose: bool,
):
"""Knowledge command line tool"""
from tools.cli.knowledge_client import knowledge_init
knowledge_init(
API_ADDRESS,
vector_name,