Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-31 07:34:07 +00:00)
feat: Multi-model command line
commit dd86fb86b1 (parent d467092766)
pilot/model/cli.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import click
import functools

from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
    RemoteWorkerManager,
    WorkerApplyRequest,
    WorkerApplyType,
)
from pilot.utils import get_or_create_event_loop


@click.group("model")
def model_cli_group():
    pass


@model_cli_group.command()
@click.option(
    "--address",
    type=str,
    default="http://127.0.0.1:8000",
    required=False,
    help=(
        "Address of the Model Controller to connect to."
        "Just support light deploy model"
    ),
)
@click.option(
    "--model-name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
    "--model-type", type=str, default="llm", required=False, help=("The type of model")
)
def list(address: str, model_name: str, model_type: str):
    """List model instances"""
    from prettytable import PrettyTable

    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(address)

    if not model_name:
        instances = loop.run_until_complete(registry.get_all_model_instances())
    else:
        if not model_type:
            model_type = "llm"
        register_model_name = f"{model_name}@{model_type}"
        instances = loop.run_until_complete(
            registry.get_all_instances(register_model_name)
        )
    table = PrettyTable()

    table.field_names = [
        "Model Name",
        "Model Type",
        "Host",
        "Port",
        "Healthy",
        "Enabled",
        "Prompt Template",
        "Last Heartbeat",
    ]
    for instance in instances:
        model_name, model_type = instance.model_name.split("@")
        table.add_row(
            [
                model_name,
                model_type,
                instance.host,
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template,
                instance.last_heartbeat,
            ]
        )

    print(table)


def add_model_options(func):
    @click.option(
        "--address",
        type=str,
        default="http://127.0.0.1:8000",
        required=False,
        help=(
            "Address of the Model Controller to connect to."
            "Just support light deploy model"
        ),
    )
    @click.option(
        "--model-name",
        type=str,
        default=None,
        required=True,
        help=("The name of model"),
    )
    @click.option(
        "--model-type",
        type=str,
        default="llm",
        required=False,
        help=("The type of model"),
    )
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@model_cli_group.command()
@add_model_options
def stop(address: str, model_name: str, model_type: str):
    """Stop model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.STOP)


@model_cli_group.command()
@add_model_options
def start(address: str, model_name: str, model_type: str):
    """Start model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.START)


@model_cli_group.command()
@add_model_options
def restart(address: str, model_name: str, model_type: str):
    """Restart model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)


# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
#     """Restart model instances"""
#     worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)


def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(address)
    worker_manager = RemoteWorkerManager(registry)
    apply_req = WorkerApplyRequest(
        model=model_name, worker_type=model_type, apply_type=apply_type
    )
    res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
    print(res)
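The new `model` group is also wired up as the `dbgpt` console script later in this commit (see setup.py), so the commands above become `dbgpt model list|start|stop|restart`. Below is a minimal sketch (not part of the commit) of driving the group in-process with click's test runner; it assumes a Model Controller is reachable at the default http://127.0.0.1:8000.

from click.testing import CliRunner

from pilot.model.cli import model_cli_group

runner = CliRunner()
# Equivalent to `dbgpt model list`; add "--model-name", "vicuna-13b" to filter.
result = runner.invoke(model_cli_group, ["list"])
print(result.output)  # PrettyTable of registered instances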
@@ -27,6 +27,9 @@ class ModelController:
        )
        return await self.registry.get_all_instances(model_name, healthy_only)

+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        return await self.registry.get_all_model_instances()
+
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        return await self.registry.send_heartbeat(instance)
@@ -51,10 +54,12 @@ async def api_deregister_instance(request: ModelInstance):


@router.get("/controller/models")
-async def api_get_all_instances(model_name: str, healthy_only: bool = False):
+async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
+    if not model_name:
+        return await controller.get_all_model_instances()
    return await controller.get_all_instances(model_name, healthy_only=healthy_only)


@router.post("/controller/heartbeat")
-async def api_get_all_instances(request: ModelInstance):
+async def api_model_heartbeat(request: ModelInstance):
    return await controller.send_heartbeat(request)
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
+import itertools

from pilot.model.base import ModelInstance
@@ -57,6 +58,15 @@ class ModelRegistry(ABC):
        - List[ModelInstance]: A list of instances for the given model.
        """

+    @abstractmethod
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        """
+        Fetch all instances of all models
+
+        Returns:
+        - List[ModelInstance]: A list of instances for the all models.
+        """
+
    async def select_one_health_instance(self, model_name: str) -> ModelInstance:
        """
        Selects one healthy and enabled instance for a given model.
@@ -154,12 +164,15 @@ class EmbeddedModelRegistry(ModelRegistry):
    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
        print(self.registry)
        instances = self.registry[model_name]
        if healthy_only:
            instances = [ins for ins in instances if ins.healthy == True]
        return instances

+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        print(self.registry)
+        return list(itertools.chain(*self.registry.values()))
+
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        _, exist_ins = self._get_instances(
            instance.model_name, instance.host, instance.port, healthy_only=False
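For reference, a tiny standalone sketch (toy data, not real ModelInstance objects) of the flattening that get_all_model_instances performs over the registry dict:

import itertools

# The registry maps "model@type" keys to lists of instances; chaining the
# dict's values yields one flat list across all models.
registry = {
    "vicuna-13b@llm": ["ins1", "ins2"],
    "text2vec@text2vec": ["ins3"],
}
print(list(itertools.chain(*registry.values())))  # ['ins1', 'ins2', 'ins3']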
@@ -194,6 +207,10 @@ class ModelRegistryClient(ModelRegistry):
    ) -> List[ModelInstance]:
        pass

+    @api_remote(path="/api/controller/models")
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        pass
+
    @api_remote(path="/api/controller/models")
    async def select_one_health_instance(self, model_name: str) -> ModelInstance:
        instances = await self.get_all_instances(model_name, healthy_only=True)
@@ -118,7 +118,6 @@ class ModelLoader:
    def loader_with_params(self, model_params: ModelParameters):
        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
        model_type = llm_adapter.model_type()
        param_cls = llm_adapter.model_param_class(model_type)
-        self.prompt_template = model_params.prompt_template
        logger.info(f"model_params:\n{model_params}")
        if model_type == ModelType.HF:
@@ -184,6 +184,19 @@ class BaseParameters:

        return updated

+    def __str__(self) -> str:
+        class_name = self.__class__.__name__
+        parameters = [
+            f"\n\n=========================== {class_name} ===========================\n"
+        ]
+        for field_info in fields(self):
+            value = getattr(self, field_info.name)
+            parameters.append(f"{field_info.name}: {value}")
+        parameters.append(
+            "\n======================================================================\n\n"
+        )
+        return "\n".join(parameters)
+
+
@dataclass
class ModelWorkerParameters(BaseParameters):
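The __str__ above relies on dataclasses.fields() to enumerate whatever fields a subclass declares, so every parameter class inherits a readable dump for free. A self-contained sketch of the same technique with a toy dataclass:

from dataclasses import dataclass, fields


@dataclass
class DemoParams:
    model_name: str = "vicuna-13b"
    device: str = "cuda"

    def __str__(self) -> str:
        lines = [f"=== {self.__class__.__name__} ==="]
        # fields() works on any dataclass instance, including subclasses.
        lines += [f"{f.name}: {getattr(self, f.name)}" for f in fields(self)]
        return "\n".join(lines)


print(DemoParams())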
@@ -4,12 +4,12 @@ from typing import Dict, Iterator, List

import torch
from pilot.configs.model_config import DEVICE
-from pilot.model.adapter import get_llm_model_adapter
+from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker
-from pilot.server.chat_adapter import get_llm_chat_adapter
+from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache

logger = logging.getLogger("model_worker")
@@ -20,6 +20,8 @@ class DefaultModelWorker(ModelWorker):
        self.model = None
        self.tokenizer = None
        self._model_params = None
+        self.llm_adapter: BaseLLMAdaper = None
+        self.llm_chat_adapter: BaseChatAdpter = None

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        if model_path.endswith("/"):
@@ -28,9 +30,9 @@ class DefaultModelWorker(ModelWorker):
        self.model_name = model_name
        self.model_path = model_path

-        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
-        model_type = llm_adapter.model_type()
-        self.param_cls = llm_adapter.model_param_class(model_type)
+        self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
+        model_type = self.llm_adapter.model_type()
+        self.param_cls = self.llm_adapter.model_param_class(model_type)

        self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
@@ -50,12 +52,14 @@ class DefaultModelWorker(ModelWorker):
        param_cls = self.model_param_class()
        model_args = EnvArgumentParser()
        env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
+        model_type = self.llm_adapter.model_type()
        model_params: ModelParameters = model_args.parse_args_into_dataclass(
            param_cls,
            env_prefix=env_prefix,
            command_args=command_args,
            model_name=self.model_name,
            model_path=self.model_path,
+            model_type=model_type,
        )
        if not model_params.device:
            model_params.device = DEVICE
@@ -1,4 +1,5 @@
import asyncio
+import httpx
import itertools
import json
import os
@@ -147,7 +148,7 @@ class LocalWorkerManager(WorkerManager):
        self.model_registry = model_registry

    def _worker_key(self, worker_type: str, model_name: str) -> str:
-        return f"$${worker_type}_$$_{model_name}"
+        return f"{model_name}@{worker_type}"

    def add_worker(
        self,
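The worker key now reuses the "{model_name}@{worker_type}" naming already used by the model registry and the CLI's list output, so a single string identifies a model everywhere. Standalone check of the format:

def _worker_key(worker_type: str, model_name: str) -> str:
    return f"{model_name}@{worker_type}"


assert _worker_key("llm", "vicuna-13b") == "vicuna-13b@llm"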
@@ -311,7 +312,9 @@ class LocalWorkerManager(WorkerManager):
        # Apply to all workers
        worker_instances = list(itertools.chain(*self.workers.values()))
        logger.info(f"Apply to all workers: {worker_instances}")
-        await asyncio.gather(*(apply_func(worker) for worker in worker_instances))
+        return await asyncio.gather(
+            *(apply_func(worker) for worker in worker_instances)
+        )

    async def _start_all_worker(
        self, apply_req: WorkerApplyRequest
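Returning the asyncio.gather result matters because gather preserves per-task results in input order, which lets the new RemoteWorkerManager.worker_apply below surface each worker's WorkerApplyOutput. A minimal standalone illustration:

import asyncio


async def apply(worker: str) -> str:
    return f"applied:{worker}"


async def main():
    results = await asyncio.gather(*(apply(w) for w in ["w1", "w2"]))
    print(results)  # ['applied:w1', 'applied:w2'] -- order matches the input


asyncio.run(main())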
@@ -423,6 +426,28 @@ class RemoteWorkerManager(LocalWorkerManager):
            worker_instances.append(wr)
        return worker_instances

+    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
+        async def _remote_apply_func(worker_run_data: WorkerRunData):
+            worker_addr = worker_run_data.worker.worker_addr
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    worker_addr + "/apply",
+                    headers=worker_run_data.worker.headers,
+                    json=apply_req.dict(),
+                    timeout=worker_run_data.worker.timeout,
+                )
+                if response.status_code == 200:
+                    output = WorkerApplyOutput(**response.json())
+                    logger.info(f"worker_apply success: {output}")
+                else:
+                    output = WorkerApplyOutput(message=response.text)
+                    logger.warn(f"worker_apply failed: {output}")
+            return output
+
+        results = await self._apply_worker(apply_req, _remote_apply_func)
+        if results:
+            return results[0]
+
+
class WorkerManagerAdapter(WorkerManager):
    def __init__(self, worker_manager: WorkerManager = None) -> None:
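A stripped-down sketch of the remote apply call above (illustrative URL and payload; uses raise_for_status instead of the commit's explicit status-code branch):

import httpx


async def remote_apply(worker_addr: str, payload: dict) -> dict:
    # POST the apply request as JSON to the worker's /apply endpoint and
    # parse the JSON body of the response.
    async with httpx.AsyncClient() as client:
        response = await client.post(worker_addr + "/apply", json=payload, timeout=120)
        response.raise_for_status()
        return response.json()

# e.g. asyncio.run(remote_apply("http://127.0.0.1:8001", {"apply_type": "start"}))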
@@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory

from pilot.configs.model_config import LOGDIR, DATASETS_DIR
-from pilot.utils import (
-    build_logger,
-    server_error_msg,
-)
+from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
from pilot.scene.base_message import (
    BaseMessage,
    SystemMessage,
@@ -222,13 +219,10 @@ class BaseChat(ABC):
        return self.current_ai_response()

    def _blocking_stream_call(self):
-        import asyncio
-
        logger.warn(
            "_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
        )
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = get_or_create_event_loop()
        async_gen = self.stream_call()
        while True:
            try:
@@ -238,13 +232,10 @@ class BaseChat(ABC):
            break

    def _blocking_nostream_call(self):
-        import asyncio
-
        logger.warn(
            "_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance"
        )
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = get_or_create_event_loop()
        return loop.run_until_complete(self.nostream_call())

    def call(self):
pilot/scripts/__init__.py (new file, empty)

pilot/scripts/cli_scripts.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import sys
import click
import os
import copy
import logging

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)


@click.group()
@click.option(
    "--log-level",
    required=False,
    type=str,
    default="warn",
    help="Log level",
)
@click.version_option()
def cli(log_level: str):
    # TODO not working now
    logging.basicConfig(level=log_level, encoding="utf-8")


def add_command_alias(command, name: str, hidden: bool = False):
    new_command = copy.deepcopy(command)
    new_command.hidden = hidden
    cli.add_command(new_command, name=name)


try:
    from pilot.model.cli import model_cli_group

    add_command_alias(model_cli_group, name="model")
except ImportError as e:
    logging.warning(f"Integrating dbgpt model command line tool failed: {e}")


def main():
    return cli()


if __name__ == "__main__":
    main()
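add_command_alias deep-copies a click command so the same group can be mounted under several names (optionally hidden from --help) without the registrations sharing mutable state. A small self-contained sketch of the pattern:

import copy

import click


@click.group()
def cli():
    pass


@click.command()
def hello():
    click.echo("hello")


cli.add_command(hello, name="hello")

alias = copy.deepcopy(hello)
alias.hidden = True  # still invocable, just not shown in `cli --help`
cli.add_command(alias, name="hi")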
@@ -75,14 +75,20 @@ app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1)


def mount_static_files(app):
    os.makedirs(static_message_img_path, exist_ok=True)
    app.mount(
-        "/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
+        "/images",
+        StaticFiles(directory=static_message_img_path, html=True),
+        name="static2",
    )
-    app.mount(
-        "/_next/static", StaticFiles(directory=static_file_path + "/_next/static")
-    )
+    app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
    app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")


app.add_exception_handler(RequestValidationError, validation_exception_handler)

if __name__ == "__main__":
@@ -2,7 +2,6 @@ import logging
import os

import requests
-from playsound import playsound

from pilot.speech.base import VoiceBase
@@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase):
        Returns:
            bool: True if the request was successful, False otherwise
        """
+        from playsound import playsound
+
        tts_url = (
            f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
        )
@@ -2,7 +2,6 @@
import os

import requests
-from playsound import playsound

from pilot.configs.config import Config
from pilot.speech.base import VoiceBase
@@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase):
            bool: True if the request was successful, False otherwise
        """
        from pilot.logs import logger
+        from playsound import playsound

        tts_url = (
            f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
@@ -2,7 +2,6 @@
import os

import gtts
-from playsound import playsound

from pilot.speech.base import VoiceBase
@@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase):

    def _speech(self, text: str, _: int = 0) -> bool:
        """Play the given text."""
+        from playsound import playsound
+
        tts = gtts.gTTS(text)
        tts.save("speech.mp3")
        playsound("speech.mp3", True)
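All three speech backends get the same treatment here: the optional playsound dependency is imported inside the method instead of at module import time, so the modules stay importable even when the package is absent. A sketch of the pattern (hypothetical helper, not from the commit):

def play_audio(path: str) -> bool:
    # Deferred import: only fails if audio playback is actually requested.
    from playsound import playsound

    playsound(path, True)
    return True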
@@ -5,4 +5,5 @@ from .utils import (
    disable_torch_init,
    pretty_print_semaphore,
    server_error_msg,
+    get_or_create_event_loop,
)
@@ -1,6 +1,7 @@
import httpx
from inspect import signature
import typing_inspect
+import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict
@@ -21,7 +22,9 @@ def _api_remote(path, method="GET"):
            raise TypeError("Return type must be annotated in the decorated function.")

        actual_dataclass = _extract_dataclass_from_generic(return_type)
-        print(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
+        logging.debug(
+            f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
+        )
        if not actual_dataclass:
            actual_dataclass = return_type
        sig = signature(func)
@@ -57,7 +60,9 @@ def _api_remote(path, method="GET"):
            else:  # For GET, DELETE, etc.
                request_params["params"] = request_data

-            print(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
+            logging.info(
+                f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
+            )

            async with httpx.AsyncClient() as client:
                response = await client.request(**request_params)
@@ -5,8 +5,8 @@ import logging
import logging.handlers
import os
import sys
+import asyncio

-import torch
from pilot.configs.model_config import LOGDIR

server_error_msg = (
@@ -17,6 +17,8 @@ handler = None


def get_gpu_memory(max_gpus=None):
+    import torch
+
    gpu_memory = []
    num_gpus = (
        torch.cuda.device_count()
@@ -130,3 +132,15 @@ def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+def get_or_create_event_loop() -> asyncio.BaseEventLoop:
+    try:
+        loop = asyncio.get_event_loop()
+    except Exception as e:
+        if not "no running event loop" in str(e):
+            raise e
+        logging.warning("Cant not get running event loop, create new event loop now")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    return loop
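get_or_create_event_loop exists so synchronous call sites (the click commands, _blocking_stream_call, _blocking_nostream_call) can drive coroutines without each creating a throwaway loop. A self-contained sketch of the pattern it wraps; note that asyncio.get_event_loop behavior varies across Python versions, which is what the helper papers over:

import asyncio


async def fetch_instances():
    await asyncio.sleep(0)  # stand-in for an async registry call
    return ["vicuna-13b@llm"]


try:
    loop = asyncio.get_event_loop()
except RuntimeError:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

print(loop.run_until_complete(fetch_instances()))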
@@ -76,4 +76,7 @@ bardapi==0.1.29
# TODO moved to optional dependencies
pymysql
duckdb
duckdb-engine
+
+# cli
+prettytable
setup.py (4 changes)
@@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires():
    llama_cpp_version = "0.1.77"
    py_version = "cp310"
    os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
-    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
+    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
    extra_index_url, _ = encode_url(extra_index_url)
    print(f"Install llama_cpp_python_cuda from {extra_index_url}")
@@ -361,7 +361,7 @@ setuptools.setup(
    extras_require=setup_spec.extras,
    entry_points={
        "console_scripts": [
-            "dbgpt_server=pilot.server:webserver",
+            "dbgpt=pilot.scripts.cli_scripts:main",
        ],
    },
)
@@ -25,7 +25,6 @@ sys.path.append(

from pilot.configs.model_config import DATASETS_DIR

-from tools.cli.knowledge_client import knowledge_init

API_ADDRESS: str = "http://127.0.0.1:5000"
@@ -97,6 +96,8 @@ def knowledge(
    verbose: bool,
):
    """Knowledge command line tool"""
+    from tools.cli.knowledge_client import knowledge_init
+
    knowledge_init(
        API_ADDRESS,
        vector_name,