feat: Multi-model command line

commit dd86fb86b1 (parent d467092766)
pilot/model/cli.py (new file, 151 lines)

@@ -0,0 +1,151 @@
+import click
+import functools
+
+from pilot.model.controller.registry import ModelRegistryClient
+from pilot.model.worker.manager import (
+    RemoteWorkerManager,
+    WorkerApplyRequest,
+    WorkerApplyType,
+)
+from pilot.utils import get_or_create_event_loop
+
+
+@click.group("model")
+def model_cli_group():
+    pass
+
+
+@model_cli_group.command()
+@click.option(
+    "--address",
+    type=str,
+    default="http://127.0.0.1:8000",
+    required=False,
+    help=(
+        "Address of the Model Controller to connect to. "
+        "Only the lightweight deployment mode is supported."
+    ),
+)
+@click.option(
+    "--model-name", type=str, default=None, required=False, help="The name of the model"
+)
+@click.option(
+    "--model-type", type=str, default="llm", required=False, help="The type of the model"
+)
+def list(address: str, model_name: str, model_type: str):
+    """List model instances"""
+    from prettytable import PrettyTable
+
+    loop = get_or_create_event_loop()
+    registry = ModelRegistryClient(address)
+
+    if not model_name:
+        instances = loop.run_until_complete(registry.get_all_model_instances())
+    else:
+        if not model_type:
+            model_type = "llm"
+        register_model_name = f"{model_name}@{model_type}"
+        instances = loop.run_until_complete(
+            registry.get_all_instances(register_model_name)
+        )
+    table = PrettyTable()
+
+    table.field_names = [
+        "Model Name",
+        "Model Type",
+        "Host",
+        "Port",
+        "Healthy",
+        "Enabled",
+        "Prompt Template",
+        "Last Heartbeat",
+    ]
+    for instance in instances:
+        model_name, model_type = instance.model_name.split("@")
+        table.add_row(
+            [
+                model_name,
+                model_type,
+                instance.host,
+                instance.port,
+                instance.healthy,
+                instance.enabled,
+                instance.prompt_template,
+                instance.last_heartbeat,
+            ]
+        )
+
+    print(table)
+
+
+def add_model_options(func):
+    @click.option(
+        "--address",
+        type=str,
+        default="http://127.0.0.1:8000",
+        required=False,
+        help=(
+            "Address of the Model Controller to connect to. "
+            "Only the lightweight deployment mode is supported."
+        ),
+    )
+    @click.option(
+        "--model-name",
+        type=str,
+        default=None,
+        required=True,
+        help="The name of the model",
+    )
+    @click.option(
+        "--model-type",
+        type=str,
+        default="llm",
+        required=False,
+        help="The type of the model",
+    )
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@model_cli_group.command()
+@add_model_options
+def stop(address: str, model_name: str, model_type: str):
+    """Stop model instances"""
+    worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
+
+
+@model_cli_group.command()
+@add_model_options
+def start(address: str, model_name: str, model_type: str):
+    """Start model instances"""
+    worker_apply(address, model_name, model_type, WorkerApplyType.START)
+
+
+@model_cli_group.command()
+@add_model_options
+def restart(address: str, model_name: str, model_type: str):
+    """Restart model instances"""
+    worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
+
+
+# @model_cli_group.command()
+# @add_model_options
+# def modify(address: str, model_name: str, model_type: str):
+#     """Modify model instances"""
+#     worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)
+
+
+def worker_apply(
+    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
+):
+    loop = get_or_create_event_loop()
+    registry = ModelRegistryClient(address)
+    worker_manager = RemoteWorkerManager(registry)
+    apply_req = WorkerApplyRequest(
+        model=model_name, worker_type=model_type, apply_type=apply_type
+    )
+    res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
+    print(res)
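For orientation, a hedged usage sketch: once the `dbgpt` console script added in setup.py (end of this commit) is installed, `dbgpt model list` drives this group. The same commands can also be exercised in-process with Click's test runner; the model name below is a hypothetical example.

    # Usage sketch (assumes a Model Controller running at the default address)
    from click.testing import CliRunner
    from pilot.model.cli import model_cli_group

    runner = CliRunner()
    # Equivalent to `dbgpt model list` on the command line
    result = runner.invoke(model_cli_group, ["list"])
    print(result.output)
    # Equivalent to `dbgpt model stop --model-name vicuna-13b-v1.5` (hypothetical name)
    result = runner.invoke(model_cli_group, ["stop", "--model-name", "vicuna-13b-v1.5"])
    print(result.output)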
@@ -27,6 +27,9 @@ class ModelController:
         )
         return await self.registry.get_all_instances(model_name, healthy_only)
 
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        return await self.registry.get_all_model_instances()
+
     async def send_heartbeat(self, instance: ModelInstance) -> bool:
         return await self.registry.send_heartbeat(instance)

@@ -51,10 +54,12 @@ async def api_deregister_instance(request: ModelInstance):
 
 
 @router.get("/controller/models")
-async def api_get_all_instances(model_name: str, healthy_only: bool = False):
+async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
+    if not model_name:
+        return await controller.get_all_model_instances()
     return await controller.get_all_instances(model_name, healthy_only=healthy_only)
 
 
 @router.post("/controller/heartbeat")
-async def api_get_all_instances(request: ModelInstance):
+async def api_model_heartbeat(request: ModelInstance):
     return await controller.send_heartbeat(request)
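Since model_name is now optional on the listing endpoint, a minimal sketch of calling the controller's REST API directly (the default light-deploy address from the CLI is assumed; the /api prefix matches the paths used by ModelRegistryClient below):

    import httpx

    # No model_name -> the new get_all_model_instances() path
    print(httpx.get("http://127.0.0.1:8000/api/controller/models").json())
    # With a registered model name -> per-model instances, optionally healthy only
    print(
        httpx.get(
            "http://127.0.0.1:8000/api/controller/models",
            params={"model_name": "vicuna-13b-v1.5@llm", "healthy_only": True},  # hypothetical name
        ).json()
    )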
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from datetime import datetime, timedelta
 from typing import Dict, List, Tuple
+import itertools
 
 from pilot.model.base import ModelInstance
 

@@ -57,6 +58,15 @@ class ModelRegistry(ABC):
         - List[ModelInstance]: A list of instances for the given model.
         """
 
+    @abstractmethod
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        """
+        Fetch all instances of all models.
+
+        Returns:
+        - List[ModelInstance]: A list of instances for all models.
+        """
+
     async def select_one_health_instance(self, model_name: str) -> ModelInstance:
         """
         Selects one healthy and enabled instance for a given model.

@@ -154,12 +164,15 @@ class EmbeddedModelRegistry(ModelRegistry):
     async def get_all_instances(
         self, model_name: str, healthy_only: bool = False
     ) -> List[ModelInstance]:
-        print(self.registry)
         instances = self.registry[model_name]
         if healthy_only:
             instances = [ins for ins in instances if ins.healthy == True]
         return instances
 
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        print(self.registry)
+        return list(itertools.chain(*self.registry.values()))
+
     async def send_heartbeat(self, instance: ModelInstance) -> bool:
         _, exist_ins = self._get_instances(
             instance.model_name, instance.host, instance.port, healthy_only=False

@@ -194,6 +207,10 @@ class ModelRegistryClient(ModelRegistry):
     ) -> List[ModelInstance]:
         pass
 
+    @api_remote(path="/api/controller/models")
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        pass
+
     @api_remote(path="/api/controller/models")
     async def select_one_health_instance(self, model_name: str) -> ModelInstance:
         instances = await self.get_all_instances(model_name, healthy_only=True)
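The embedded registry maps registered model names to lists of instances, so flattening every value list yields all instances. A toy sketch of the itertools.chain expression above (keys and values hypothetical):

    import itertools

    # Toy stand-in for EmbeddedModelRegistry.registry
    registry = {
        "vicuna-13b-v1.5@llm": ["instance-a"],
        "text2vec@text2vec": ["instance-b", "instance-c"],
    }
    # Same expression as get_all_model_instances() above
    print(list(itertools.chain(*registry.values())))  # ['instance-a', 'instance-b', 'instance-c']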
|
@ -118,7 +118,6 @@ class ModelLoader:
|
|||||||
def loader_with_params(self, model_params: ModelParameters):
|
def loader_with_params(self, model_params: ModelParameters):
|
||||||
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
|
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
|
||||||
model_type = llm_adapter.model_type()
|
model_type = llm_adapter.model_type()
|
||||||
param_cls = llm_adapter.model_param_class(model_type)
|
|
||||||
self.prompt_template = model_params.prompt_template
|
self.prompt_template = model_params.prompt_template
|
||||||
logger.info(f"model_params:\n{model_params}")
|
logger.info(f"model_params:\n{model_params}")
|
||||||
if model_type == ModelType.HF:
|
if model_type == ModelType.HF:
|
||||||
|
@@ -184,6 +184,19 @@ class BaseParameters:
 
         return updated
 
+    def __str__(self) -> str:
+        class_name = self.__class__.__name__
+        parameters = [
+            f"\n\n=========================== {class_name} ===========================\n"
+        ]
+        for field_info in fields(self):
+            value = getattr(self, field_info.name)
+            parameters.append(f"{field_info.name}: {value}")
+        parameters.append(
+            "\n======================================================================\n\n"
+        )
+        return "\n".join(parameters)
+
 
 @dataclass
 class ModelWorkerParameters(BaseParameters):
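A quick sketch of what the new __str__ walks over, using a hypothetical stand-in dataclass (real parameter classes live in pilot.model.parameter):

    from dataclasses import dataclass, fields

    @dataclass
    class DemoParameters:  # hypothetical stand-in for a BaseParameters subclass
        model_name: str = "demo-model"
        device: str = "cpu"

    # The same field loop as __str__ above: one "name: value" line per
    # dataclass field, framed by the banner lines
    demo = DemoParameters()
    for f in fields(demo):
        print(f"{f.name}: {getattr(demo, f.name)}")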
@@ -4,12 +4,12 @@ from typing import Dict, Iterator, List
 
 import torch
 from pilot.configs.model_config import DEVICE
-from pilot.model.adapter import get_llm_model_adapter
+from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
 from pilot.model.base import ModelOutput
 from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.model.parameter import EnvArgumentParser, ModelParameters
 from pilot.model.worker.base import ModelWorker
-from pilot.server.chat_adapter import get_llm_chat_adapter
+from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
 from pilot.utils.model_utils import _clear_torch_cache
 
 logger = logging.getLogger("model_worker")

@@ -20,6 +20,8 @@ class DefaultModelWorker(ModelWorker):
         self.model = None
         self.tokenizer = None
         self._model_params = None
+        self.llm_adapter: BaseLLMAdaper = None
+        self.llm_chat_adapter: BaseChatAdpter = None
 
     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
         if model_path.endswith("/"):

@@ -28,9 +30,9 @@ class DefaultModelWorker(ModelWorker):
         self.model_name = model_name
         self.model_path = model_path
 
-        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
-        model_type = llm_adapter.model_type()
-        self.param_cls = llm_adapter.model_param_class(model_type)
+        self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
+        model_type = self.llm_adapter.model_type()
+        self.param_cls = self.llm_adapter.model_param_class(model_type)
 
         self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(

@@ -50,12 +52,14 @@ class DefaultModelWorker(ModelWorker):
         param_cls = self.model_param_class()
         model_args = EnvArgumentParser()
         env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
+        model_type = self.llm_adapter.model_type()
         model_params: ModelParameters = model_args.parse_args_into_dataclass(
             param_cls,
             env_prefix=env_prefix,
             command_args=command_args,
             model_name=self.model_name,
             model_path=self.model_path,
+            model_type=model_type,
         )
         if not model_params.device:
             model_params.device = DEVICE
@@ -1,4 +1,5 @@
 import asyncio
+import httpx
 import itertools
 import json
 import os

@@ -147,7 +148,7 @@ class LocalWorkerManager(WorkerManager):
         self.model_registry = model_registry
 
     def _worker_key(self, worker_type: str, model_name: str) -> str:
-        return f"$${worker_type}_$$_{model_name}"
+        return f"{model_name}@{worker_type}"
 
     def add_worker(
         self,

@@ -311,7 +312,9 @@ class LocalWorkerManager(WorkerManager):
             # Apply to all workers
             worker_instances = list(itertools.chain(*self.workers.values()))
             logger.info(f"Apply to all workers: {worker_instances}")
-            await asyncio.gather(*(apply_func(worker) for worker in worker_instances))
+            return await asyncio.gather(
+                *(apply_func(worker) for worker in worker_instances)
+            )
 
     async def _start_all_worker(
         self, apply_req: WorkerApplyRequest

@@ -423,6 +426,28 @@ class RemoteWorkerManager(LocalWorkerManager):
             worker_instances.append(wr)
         return worker_instances
 
+    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
+        async def _remote_apply_func(worker_run_data: WorkerRunData):
+            worker_addr = worker_run_data.worker.worker_addr
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    worker_addr + "/apply",
+                    headers=worker_run_data.worker.headers,
+                    json=apply_req.dict(),
+                    timeout=worker_run_data.worker.timeout,
+                )
+                if response.status_code == 200:
+                    output = WorkerApplyOutput(**response.json())
+                    logger.info(f"worker_apply success: {output}")
+                else:
+                    output = WorkerApplyOutput(message=response.text)
+                    logger.warn(f"worker_apply failed: {output}")
+                return output
+
+        results = await self._apply_worker(apply_req, _remote_apply_func)
+        if results:
+            return results[0]
+
 
 class WorkerManagerAdapter(WorkerManager):
     def __init__(self, worker_manager: WorkerManager = None) -> None:
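The new worker key format deliberately matches the registered model name that the CLI's list command splits apart. A toy round-trip sketch (model name hypothetical):

    # Same format as LocalWorkerManager._worker_key above
    def worker_key(worker_type: str, model_name: str) -> str:
        return f"{model_name}@{worker_type}"

    key = worker_key("llm", "vicuna-13b-v1.5")  # 'vicuna-13b-v1.5@llm'
    # The inverse operation used in `dbgpt model list`
    model_name, worker_type = key.split("@")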
@@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
 from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
 
 from pilot.configs.model_config import LOGDIR, DATASETS_DIR
-from pilot.utils import (
-    build_logger,
-    server_error_msg,
-)
+from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
 from pilot.scene.base_message import (
     BaseMessage,
     SystemMessage,

@@ -222,13 +219,10 @@ class BaseChat(ABC):
         return self.current_ai_response()
 
     def _blocking_stream_call(self):
-        import asyncio
-
         logger.warn(
             "_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
         )
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = get_or_create_event_loop()
         async_gen = self.stream_call()
         while True:
             try:

@@ -238,13 +232,10 @@ class BaseChat(ABC):
                 break
 
     def _blocking_nostream_call(self):
-        import asyncio
-
         logger.warn(
             "_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance"
        )
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = get_or_create_event_loop()
         return loop.run_until_complete(self.nostream_call())
 
     def call(self):
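For context, the blocking wrappers drain an async generator from synchronous code; a self-contained toy sketch of that loop shape (the chunks here are hypothetical stand-ins for model output):

    from pilot.utils import get_or_create_event_loop

    async def stream():  # hypothetical stand-in for BaseChat.stream_call()
        for chunk in ("Hello", ", ", "world"):
            yield chunk

    loop = get_or_create_event_loop()
    gen = stream()
    while True:
        try:
            print(loop.run_until_complete(gen.__anext__()), end="")
        except StopAsyncIteration:
            break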
pilot/scripts/__init__.py (new file, empty)
pilot/scripts/cli_scripts.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+import sys
+import click
+import os
+import copy
+import logging
+
+sys.path.append(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+)
+
+
+@click.group()
+@click.option(
+    "--log-level",
+    required=False,
+    type=str,
+    default="warn",
+    help="Log level",
+)
+@click.version_option()
+def cli(log_level: str):
+    # TODO not working now
+    logging.basicConfig(level=log_level, encoding="utf-8")
+
+
+def add_command_alias(command, name: str, hidden: bool = False):
+    new_command = copy.deepcopy(command)
+    new_command.hidden = hidden
+    cli.add_command(new_command, name=name)
+
+
+try:
+    from pilot.model.cli import model_cli_group
+
+    add_command_alias(model_cli_group, name="model")
+except ImportError as e:
+    logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
+
+
+def main():
+    return cli()
+
+
+if __name__ == "__main__":
+    main()
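add_command_alias deep-copies a command group so it can be mounted under another name, optionally hidden from help output. A sketch of registering a hypothetical extra alias the same way:

    from pilot.model.cli import model_cli_group
    from pilot.scripts.cli_scripts import add_command_alias

    # Hypothetical second alias: invokable as `dbgpt models ...` but
    # hidden from `dbgpt --help`
    add_command_alias(model_cli_group, name="models", hidden=True)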
@@ -75,14 +75,20 @@ app.include_router(api_editor_route_v1, prefix="/api")
 app.include_router(knowledge_router)
 # app.include_router(api_editor_route_v1)
+
+
 def mount_static_files(app):
     os.makedirs(static_message_img_path, exist_ok=True)
     app.mount(
-        "/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
+        "/images",
+        StaticFiles(directory=static_message_img_path, html=True),
+        name="static2",
+    )
+    app.mount(
+        "/_next/static", StaticFiles(directory=static_file_path + "/_next/static")
     )
-    app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
     app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
 
 
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
 
 if __name__ == "__main__":
@@ -2,7 +2,6 @@ import logging
 import os
 
 import requests
-from playsound import playsound
 
 from pilot.speech.base import VoiceBase
 

@@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase):
         Returns:
             bool: True if the request was successful, False otherwise
         """
+        from playsound import playsound
+
         tts_url = (
             f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
         )
@@ -2,7 +2,6 @@
 import os
 
 import requests
-from playsound import playsound
 
 from pilot.configs.config import Config
 from pilot.speech.base import VoiceBase

@@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase):
             bool: True if the request was successful, False otherwise
         """
         from pilot.logs import logger
+        from playsound import playsound
 
         tts_url = (
             f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
         )
@@ -2,7 +2,6 @@
 import os
 
 import gtts
-from playsound import playsound
 
 from pilot.speech.base import VoiceBase
 

@@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase):
 
     def _speech(self, text: str, _: int = 0) -> bool:
         """Play the given text."""
+        from playsound import playsound
+
         tts = gtts.gTTS(text)
         tts.save("speech.mp3")
         playsound("speech.mp3", True)
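All three speech backends get the same treatment: playsound moves from a module-level import to a function-local one, so importing the speech modules no longer requires the optional dependency at load time. The idiom in isolation (function name hypothetical):

    def play_file(path: str) -> None:
        # Deferred import: the optional dependency is resolved on first call,
        # so the module imports cleanly even without playsound installed
        from playsound import playsound

        playsound(path, True)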
@@ -5,4 +5,5 @@ from .utils import (
     disable_torch_init,
     pretty_print_semaphore,
     server_error_msg,
+    get_or_create_event_loop,
 )
@@ -1,6 +1,7 @@
 import httpx
 from inspect import signature
 import typing_inspect
+import logging
 from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
 from dataclasses import is_dataclass, asdict
 

@@ -21,7 +22,9 @@ def _api_remote(path, method="GET"):
     raise TypeError("Return type must be annotated in the decorated function.")
 
     actual_dataclass = _extract_dataclass_from_generic(return_type)
-    print(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
+    logging.debug(
+        f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
+    )
     if not actual_dataclass:
         actual_dataclass = return_type
     sig = signature(func)

@@ -57,7 +60,9 @@ def _api_remote(path, method="GET"):
         else:  # For GET, DELETE, etc.
             request_params["params"] = request_data
 
-        print(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
+        logging.info(
+            f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
+        )
 
         async with httpx.AsyncClient() as client:
             response = await client.request(**request_params)
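Replacing print with logging makes the request tracing opt-in; a two-line sketch of surfacing it while debugging:

    import logging

    # Surfaces the logging.debug/logging.info lines emitted by _api_remote above
    logging.basicConfig(level=logging.DEBUG)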
@@ -5,8 +5,8 @@ import logging
 import logging.handlers
 import os
 import sys
+import asyncio
 
-import torch
 from pilot.configs.model_config import LOGDIR
 
 server_error_msg = (

@@ -17,6 +17,8 @@ handler = None
 
 
 def get_gpu_memory(max_gpus=None):
+    import torch
+
     gpu_memory = []
     num_gpus = (
         torch.cuda.device_count()

@@ -130,3 +132,15 @@ def pretty_print_semaphore(semaphore):
     if semaphore is None:
         return "None"
     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+def get_or_create_event_loop() -> asyncio.BaseEventLoop:
+    try:
+        loop = asyncio.get_event_loop()
+    except Exception as e:
+        if not "no running event loop" in str(e):
+            raise e
+        logging.warning("Cannot get a running event loop, creating a new one now")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    return loop
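A brief usage sketch of the new helper from synchronous code (assumes a Python version where asyncio.get_event_loop() still returns the main-thread loop outside a running coroutine):

    from pilot.utils import get_or_create_event_loop

    async def ping():
        return "pong"

    loop = get_or_create_event_loop()
    print(loop.run_until_complete(ping()))  # 'pong'
    # Subsequent callers reuse the same loop rather than creating one per call
    assert loop is get_or_create_event_loop()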
@@ -77,3 +77,6 @@ bardapi==0.1.29
 pymysql
 duckdb
 duckdb-engine
+
+# cli
+prettytable
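prettytable is the new CLI's only extra dependency; a minimal sketch of the table rendering that `dbgpt model list` relies on (row values hypothetical):

    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["Model Name", "Model Type", "Host", "Port"]
    table.add_row(["vicuna-13b-v1.5", "llm", "127.0.0.1", 8001])
    print(table)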
setup.py (4 lines changed)

@@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires():
     llama_cpp_version = "0.1.77"
     py_version = "cp310"
     os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
-    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
+    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
     extra_index_url, _ = encode_url(extra_index_url)
     print(f"Install llama_cpp_python_cuda from {extra_index_url}")

@@ -361,7 +361,7 @@ setuptools.setup(
     extras_require=setup_spec.extras,
     entry_points={
         "console_scripts": [
-            "dbgpt_server=pilot.server:webserver",
+            "dbgpt=pilot.scripts.cli_scripts:main",
         ],
     },
 )
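The console-script change is what makes the whole commit usable from a shell: after installation, `dbgpt` resolves to cli_scripts.main. The equivalent invocation in Python, for reference:

    # What the `dbgpt` console script runs after the package is installed
    from pilot.scripts.cli_scripts import main

    main()  # same as typing `dbgpt` on the command line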
@@ -25,7 +25,6 @@ sys.path.append(
 
 from pilot.configs.model_config import DATASETS_DIR
 
-from tools.cli.knowledge_client import knowledge_init
 
 API_ADDRESS: str = "http://127.0.0.1:5000"
 

@@ -97,6 +96,8 @@ def knowledge(
     verbose: bool,
 ):
     """Knowledge command line tool"""
+    from tools.cli.knowledge_client import knowledge_init
+
     knowledge_init(
         API_ADDRESS,
         vector_name,