feat: Multi-model command line

FangYin Cheng 2023-08-30 11:07:35 +08:00
parent d467092766
commit dd86fb86b1
20 changed files with 317 additions and 35 deletions

pilot/model/cli.py Normal file
View File

@@ -0,0 +1,151 @@
import click
import functools

from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
    RemoteWorkerManager,
    WorkerApplyRequest,
    WorkerApplyType,
)
from pilot.utils import get_or_create_event_loop


@click.group("model")
def model_cli_group():
    pass


@model_cli_group.command()
@click.option(
    "--address",
    type=str,
    default="http://127.0.0.1:8000",
    required=False,
    help=(
        "Address of the Model Controller to connect to. "
        "Only the lightweight deployment mode is currently supported."
    ),
)
@click.option(
    "--model-name", type=str, default=None, required=False, help="The name of the model"
)
@click.option(
    "--model-type", type=str, default="llm", required=False, help="The type of the model"
)
def list(address: str, model_name: str, model_type: str):
    """List model instances"""
    from prettytable import PrettyTable

    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(address)

    if not model_name:
        instances = loop.run_until_complete(registry.get_all_model_instances())
    else:
        if not model_type:
            model_type = "llm"
        register_model_name = f"{model_name}@{model_type}"
        instances = loop.run_until_complete(
            registry.get_all_instances(register_model_name)
        )

    table = PrettyTable()
    table.field_names = [
        "Model Name",
        "Model Type",
        "Host",
        "Port",
        "Healthy",
        "Enabled",
        "Prompt Template",
        "Last Heartbeat",
    ]
    for instance in instances:
        model_name, model_type = instance.model_name.split("@")
        table.add_row(
            [
                model_name,
                model_type,
                instance.host,
                instance.port,
                instance.healthy,
                instance.enabled,
                instance.prompt_template,
                instance.last_heartbeat,
            ]
        )

    print(table)


def add_model_options(func):
    @click.option(
        "--address",
        type=str,
        default="http://127.0.0.1:8000",
        required=False,
        help=(
            "Address of the Model Controller to connect to. "
            "Only the lightweight deployment mode is currently supported."
        ),
    )
    @click.option(
        "--model-name",
        type=str,
        default=None,
        required=True,
        help="The name of the model",
    )
    @click.option(
        "--model-type",
        type=str,
        default="llm",
        required=False,
        help="The type of the model",
    )
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@model_cli_group.command()
@add_model_options
def stop(address: str, model_name: str, model_type: str):
    """Stop model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.STOP)


@model_cli_group.command()
@add_model_options
def start(address: str, model_name: str, model_type: str):
    """Start model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.START)


@model_cli_group.command()
@add_model_options
def restart(address: str, model_name: str, model_type: str):
    """Restart model instances"""
    worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)


# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
#     """Modify model instances"""
#     worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)


def worker_apply(
    address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
    loop = get_or_create_event_loop()
    registry = ModelRegistryClient(address)
    worker_manager = RemoteWorkerManager(registry)
    apply_req = WorkerApplyRequest(
        model=model_name, worker_type=model_type, apply_type=apply_type
    )
    res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
    print(res)
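
A quick way to exercise the new command group before wiring up the console script is Click's built-in test runner. This is a minimal sketch, assuming the pilot package is importable; vicuna-13b-v1.5 is only an illustrative model name:

from click.testing import CliRunner
from pilot.model.cli import model_cli_group

runner = CliRunner()
# Equivalent to running: dbgpt model list --model-name vicuna-13b-v1.5
result = runner.invoke(model_cli_group, ["list", "--model-name", "vicuna-13b-v1.5"])
print(result.output)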

View File

@@ -27,6 +27,9 @@ class ModelController:
        )
        return await self.registry.get_all_instances(model_name, healthy_only)

+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        return await self.registry.get_all_model_instances()
+
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        return await self.registry.send_heartbeat(instance)
@@ -51,10 +54,12 @@ async def api_deregister_instance(request: ModelInstance):

@router.get("/controller/models")
-async def api_get_all_instances(model_name: str, healthy_only: bool = False):
+async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
+    if not model_name:
+        return await controller.get_all_model_instances()
    return await controller.get_all_instances(model_name, healthy_only=healthy_only)


@router.post("/controller/heartbeat")
-async def api_get_all_instances(request: ModelInstance):
+async def api_model_heartbeat(request: ModelInstance):
    return await controller.send_heartbeat(request)

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
+import itertools

from pilot.model.base import ModelInstance
@@ -57,6 +58,15 @@ class ModelRegistry(ABC):
        - List[ModelInstance]: A list of instances for the given model.
        """

+    @abstractmethod
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        """
+        Fetch all instances of all models.
+
+        Returns:
+        - List[ModelInstance]: A list of instances across all models.
+        """
+
    async def select_one_health_instance(self, model_name: str) -> ModelInstance:
        """
        Selects one healthy and enabled instance for a given model.
@@ -154,12 +164,15 @@ class EmbeddedModelRegistry(ModelRegistry):
    async def get_all_instances(
        self, model_name: str, healthy_only: bool = False
    ) -> List[ModelInstance]:
-        print(self.registry)
        instances = self.registry[model_name]
        if healthy_only:
            instances = [ins for ins in instances if ins.healthy == True]
        return instances

+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        print(self.registry)
+        return list(itertools.chain(*self.registry.values()))
+
    async def send_heartbeat(self, instance: ModelInstance) -> bool:
        _, exist_ins = self._get_instances(
            instance.model_name, instance.host, instance.port, healthy_only=False
@@ -194,6 +207,10 @@ class ModelRegistryClient(ModelRegistry):
    ) -> List[ModelInstance]:
        pass

+    @api_remote(path="/api/controller/models")
+    async def get_all_model_instances(self) -> List[ModelInstance]:
+        pass
+
    @api_remote(path="/api/controller/models")
    async def select_one_health_instance(self, model_name: str) -> ModelInstance:
        instances = await self.get_all_instances(model_name, healthy_only=True)
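
The new get_all_model_instances flattens the per-model lists held in the registry dict with itertools.chain. A minimal sketch of that flattening, using placeholder strings instead of real ModelInstance objects:

import itertools

# The registry maps "model_name@model_type" keys to lists of instances.
registry = {
    "vicuna-13b@llm": ["instance-1", "instance-2"],
    "text2vec@text2vec": ["instance-3"],
}
# chain(*values) concatenates the per-model lists into one flat list.
print(list(itertools.chain(*registry.values())))
# ['instance-1', 'instance-2', 'instance-3']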

View File

@@ -118,7 +118,6 @@ class ModelLoader:
    def loader_with_params(self, model_params: ModelParameters):
        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
        model_type = llm_adapter.model_type()
-        param_cls = llm_adapter.model_param_class(model_type)
        self.prompt_template = model_params.prompt_template
        logger.info(f"model_params:\n{model_params}")
        if model_type == ModelType.HF:

View File

@@ -184,6 +184,19 @@ class BaseParameters:
        return updated

+    def __str__(self) -> str:
+        class_name = self.__class__.__name__
+        parameters = [
+            f"\n\n=========================== {class_name} ===========================\n"
+        ]
+        for field_info in fields(self):
+            value = getattr(self, field_info.name)
+            parameters.append(f"{field_info.name}: {value}")
+        parameters.append(
+            "\n======================================================================\n\n"
+        )
+        return "\n".join(parameters)
+

@dataclass
class ModelWorkerParameters(BaseParameters):
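
The __str__ added to BaseParameters prints each dataclass field on its own line between banner rules, which is what makes the logger.info(f"model_params:\n{model_params}") call in the loader readable. A self-contained sketch of the same pattern on a hypothetical two-field dataclass:

from dataclasses import dataclass, fields

@dataclass
class DemoParameters:
    model_name: str = "vicuna-13b"
    device: str = "cuda"

    def __str__(self) -> str:
        # Same field-per-line pattern as BaseParameters.__str__, with a shorter banner.
        lines = [f"=== {self.__class__.__name__} ==="]
        for field_info in fields(self):
            lines.append(f"{field_info.name}: {getattr(self, field_info.name)}")
        return "\n".join(lines)

print(DemoParameters())
# === DemoParameters ===
# model_name: vicuna-13b
# device: cuda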

View File

@@ -4,12 +4,12 @@ from typing import Dict, Iterator, List
import torch

from pilot.configs.model_config import DEVICE
-from pilot.model.adapter import get_llm_model_adapter
+from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker
-from pilot.server.chat_adapter import get_llm_chat_adapter
+from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache

logger = logging.getLogger("model_worker")
@@ -20,6 +20,8 @@ class DefaultModelWorker(ModelWorker):
        self.model = None
        self.tokenizer = None
        self._model_params = None
+        self.llm_adapter: BaseLLMAdaper = None
+        self.llm_chat_adapter: BaseChatAdpter = None

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        if model_path.endswith("/"):
@@ -28,9 +30,9 @@ class DefaultModelWorker(ModelWorker):
        self.model_name = model_name
        self.model_path = model_path

-        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
-        model_type = llm_adapter.model_type()
-        self.param_cls = llm_adapter.model_param_class(model_type)
+        self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
+        model_type = self.llm_adapter.model_type()
+        self.param_cls = self.llm_adapter.model_param_class(model_type)

        self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
@@ -50,12 +52,14 @@ class DefaultModelWorker(ModelWorker):
        param_cls = self.model_param_class()
        model_args = EnvArgumentParser()
        env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
+        model_type = self.llm_adapter.model_type()
        model_params: ModelParameters = model_args.parse_args_into_dataclass(
            param_cls,
            env_prefix=env_prefix,
            command_args=command_args,
            model_name=self.model_name,
            model_path=self.model_path,
+            model_type=model_type,
        )
        if not model_params.device:
            model_params.device = DEVICE

View File

@@ -1,4 +1,5 @@
import asyncio
+import httpx
import itertools
import json
import os
@@ -147,7 +148,7 @@ class LocalWorkerManager(WorkerManager):
        self.model_registry = model_registry

    def _worker_key(self, worker_type: str, model_name: str) -> str:
-        return f"$${worker_type}_$$_{model_name}"
+        return f"{model_name}@{worker_type}"

    def add_worker(
        self,
@@ -311,7 +312,9 @@ class LocalWorkerManager(WorkerManager):
            # Apply to all workers
            worker_instances = list(itertools.chain(*self.workers.values()))
            logger.info(f"Apply to all workers: {worker_instances}")
-            await asyncio.gather(*(apply_func(worker) for worker in worker_instances))
+            return await asyncio.gather(
+                *(apply_func(worker) for worker in worker_instances)
+            )

    async def _start_all_worker(
        self, apply_req: WorkerApplyRequest
@@ -423,6 +426,28 @@ class RemoteWorkerManager(LocalWorkerManager):
            worker_instances.append(wr)
        return worker_instances

+    async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
+        async def _remote_apply_func(worker_run_data: WorkerRunData):
+            worker_addr = worker_run_data.worker.worker_addr
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    worker_addr + "/apply",
+                    headers=worker_run_data.worker.headers,
+                    json=apply_req.dict(),
+                    timeout=worker_run_data.worker.timeout,
+                )
+                if response.status_code == 200:
+                    output = WorkerApplyOutput(**response.json())
+                    logger.info(f"worker_apply success: {output}")
+                else:
+                    output = WorkerApplyOutput(message=response.text)
+                    logger.warn(f"worker_apply failed: {output}")
+                return output
+
+        results = await self._apply_worker(apply_req, _remote_apply_func)
+        if results:
+            return results[0]
+

class WorkerManagerAdapter(WorkerManager):
    def __init__(self, worker_manager: WorkerManager = None) -> None:
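
_apply_worker now returns the gathered results so that RemoteWorkerManager.worker_apply can return results[0]. This works because asyncio.gather yields results in the order its awaitables were passed, regardless of completion order; a minimal illustration:

import asyncio

async def apply(worker_id: int) -> str:
    return f"worker-{worker_id}: ok"

async def main():
    # Results come back in submission order, so results[0] is the first worker's output.
    results = await asyncio.gather(*(apply(i) for i in range(3)))
    print(results)  # ['worker-0: ok', 'worker-1: ok', 'worker-2: ok']

asyncio.run(main())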

View File

@@ -28,10 +28,7 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
-from pilot.utils import (
-    build_logger,
-    server_error_msg,
-)
+from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
from pilot.scene.base_message import (
    BaseMessage,
    SystemMessage,
@@ -222,13 +219,10 @@ class BaseChat(ABC):
        return self.current_ai_response()

    def _blocking_stream_call(self):
-        import asyncio
-
        logger.warn(
            "_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
        )
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = get_or_create_event_loop()
        async_gen = self.stream_call()
        while True:
            try:
@@ -238,13 +232,10 @@ class BaseChat(ABC):
                break

    def _blocking_nostream_call(self):
-        import asyncio
-
        logger.warn(
            "_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance"
        )
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
+        loop = get_or_create_event_loop()
        return loop.run_until_complete(self.nostream_call())

    def call(self):

View File

View File

@@ -0,0 +1,45 @@
import sys
import click
import os
import copy
import logging

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)


@click.group()
@click.option(
    "--log-level",
    required=False,
    type=str,
    default="warn",
    help="Log level",
)
@click.version_option()
def cli(log_level: str):
    # TODO not working now
    logging.basicConfig(level=log_level, encoding="utf-8")


def add_command_alias(command, name: str, hidden: bool = False):
    new_command = copy.deepcopy(command)
    new_command.hidden = hidden
    cli.add_command(new_command, name=name)


try:
    from pilot.model.cli import model_cli_group

    add_command_alias(model_cli_group, name="model")
except ImportError as e:
    logging.warning(f"Integrating dbgpt model command line tool failed: {e}")


def main():
    return cli()


if __name__ == "__main__":
    main()

View File

@@ -75,14 +75,20 @@ app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1)


def mount_static_files(app):
    os.makedirs(static_message_img_path, exist_ok=True)
    app.mount(
-        "/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
+        "/images",
+        StaticFiles(directory=static_message_img_path, html=True),
+        name="static2",
    )
-    app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
+    app.mount(
+        "/_next/static", StaticFiles(directory=static_file_path + "/_next/static")
+    )
    app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")

app.add_exception_handler(RequestValidationError, validation_exception_handler)

if __name__ == "__main__":

View File

@@ -2,7 +2,6 @@ import logging
import os

import requests
-from playsound import playsound

from pilot.speech.base import VoiceBase
@@ -23,6 +22,8 @@ class BrianSpeech(VoiceBase):
        Returns:
            bool: True if the request was successful, False otherwise
        """
+        from playsound import playsound
+
        tts_url = (
            f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
        )

View File

@@ -2,7 +2,6 @@
import os

import requests
-from playsound import playsound

from pilot.configs.config import Config
from pilot.speech.base import VoiceBase
@@ -70,6 +69,7 @@ class ElevenLabsSpeech(VoiceBase):
            bool: True if the request was successful, False otherwise
        """
        from pilot.logs import logger
+        from playsound import playsound

        tts_url = (
            f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"

View File

@@ -2,7 +2,6 @@
import os

import gtts
-from playsound import playsound

from pilot.speech.base import VoiceBase
@@ -15,6 +14,8 @@ class GTTSVoice(VoiceBase):
    def _speech(self, text: str, _: int = 0) -> bool:
        """Play the given text."""
+        from playsound import playsound
+
        tts = gtts.gTTS(text)
        tts.save("speech.mp3")
        playsound("speech.mp3", True)
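
All three speech backends move the playsound import from module level into the method body, so importing the package no longer hard-requires the optional dependency. The general pattern, sketched on a hypothetical helper (the try/except fallback is an addition for illustration, not part of this commit):

import logging

def play_audio_file(path: str) -> bool:
    # The optional dependency is imported only when playback is requested,
    # so environments without playsound can still import this module.
    try:
        from playsound import playsound
    except ImportError:
        logging.warning("playsound is not installed; skipping audio playback")
        return False
    playsound(path, True)
    return True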

View File

@@ -5,4 +5,5 @@ from .utils import (
    disable_torch_init,
    pretty_print_semaphore,
    server_error_msg,
+    get_or_create_event_loop,
)

View File

@@ -1,6 +1,7 @@
import httpx
from inspect import signature
import typing_inspect
+import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict
@@ -21,7 +22,9 @@ def _api_remote(path, method="GET"):
        raise TypeError("Return type must be annotated in the decorated function.")

    actual_dataclass = _extract_dataclass_from_generic(return_type)
-    print(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
+    logging.debug(
+        f"return_type: {return_type}, actual_dataclass: {actual_dataclass}"
+    )
    if not actual_dataclass:
        actual_dataclass = return_type
    sig = signature(func)
@@ -57,7 +60,9 @@ def _api_remote(path, method="GET"):
            else:  # For GET, DELETE, etc.
                request_params["params"] = request_data

-            print(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
+            logging.info(
+                f"request_params: {request_params}, args: {args}, kwargs: {kwargs}"
+            )
            async with httpx.AsyncClient() as client:
                response = await client.request(**request_params)

View File

@@ -5,8 +5,8 @@ import logging
import logging.handlers
import os
import sys
+import asyncio

-import torch
from pilot.configs.model_config import LOGDIR

server_error_msg = (
@@ -17,6 +17,8 @@ handler = None

def get_gpu_memory(max_gpus=None):
+    import torch
+
    gpu_memory = []
    num_gpus = (
        torch.cuda.device_count()
@@ -130,3 +132,15 @@ def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+def get_or_create_event_loop() -> asyncio.BaseEventLoop:
+    try:
+        loop = asyncio.get_event_loop()
+    except Exception as e:
+        if "no running event loop" not in str(e):
+            raise e
+        logging.warning("Cannot get a running event loop, creating a new one now")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    return loop
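
get_or_create_event_loop lets synchronous callers, such as the CLI commands and the _blocking_* helpers above, drive coroutines without caring whether the current thread already has an event loop. Typical usage, mirroring the calls in pilot/model/cli.py (fetch_instances is a stand-in for a real registry call):

async def fetch_instances():
    # Stand-in for ModelRegistryClient.get_all_model_instances()
    return ["vicuna-13b@llm"]

loop = get_or_create_event_loop()
instances = loop.run_until_complete(fetch_instances())
print(instances)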

View File

@@ -77,3 +77,6 @@ bardapi==0.1.29
pymysql
duckdb
duckdb-engine
+
+# cli
+prettytable

View File

@@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires():
    llama_cpp_version = "0.1.77"
    py_version = "cp310"
    os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
-    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
+    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
    extra_index_url, _ = encode_url(extra_index_url)
    print(f"Install llama_cpp_python_cuda from {extra_index_url}")
@@ -361,7 +361,7 @@ setuptools.setup(
    extras_require=setup_spec.extras,
    entry_points={
        "console_scripts": [
-            "dbgpt_server=pilot.server:webserver",
+            "dbgpt=pilot.scripts.cli_scripts:main",
        ],
    },
)

View File

@@ -25,7 +25,6 @@ sys.path.append(
from pilot.configs.model_config import DATASETS_DIR

-from tools.cli.knowledge_client import knowledge_init

API_ADDRESS: str = "http://127.0.0.1:5000"
@@ -97,6 +96,8 @@ def knowledge(
    verbose: bool,
):
    """Knowledge command line tool"""
+    from tools.cli.knowledge_client import knowledge_init
+
    knowledge_init(
        API_ADDRESS,
        vector_name,