mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-12 13:42:23 +00:00
feat: Optimize code import time
This commit is contained in:
parent
0bc5134a07
commit
f19551a7cd
@ -143,3 +143,9 @@ SUMMARY_CONFIG=FAST
|
|||||||
# CUDA_VISIBLE_DEVICES=0
|
# CUDA_VISIBLE_DEVICES=0
|
||||||
## You can configure the maximum memory used by each GPU.
|
## You can configure the maximum memory used by each GPU.
|
||||||
# MAX_GPU_MEMORY=16Gib
|
# MAX_GPU_MEMORY=16Gib
|
||||||
|
|
||||||
|
#*******************************************************************#
|
||||||
|
#** LOG **#
|
||||||
|
#*******************************************************************#
|
||||||
|
# FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET
|
||||||
|
DBGPT_LOG_LEVEL=INFO
|
@ -1,4 +1,12 @@
|
|||||||
from pilot.embedding_engine import SourceEmbedding, register
|
# Old packages
|
||||||
from pilot.embedding_engine import EmbeddingEngine, KnowledgeType
|
# __all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
|
||||||
|
|
||||||
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
|
__all__ = ["embedding_engine"]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
if name in ["embedding_engine"]:
|
||||||
|
return importlib.import_module("." + name, __name__)
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
@ -12,7 +12,7 @@ from pilot.json_utils.json_fix_general import (
|
|||||||
fix_invalid_escape,
|
fix_invalid_escape,
|
||||||
)
|
)
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.speech import say_text
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -87,6 +87,8 @@ def correct_json(json_to_load: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
|
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
|
||||||
|
from pilot.speech.say import say_text
|
||||||
|
|
||||||
if CFG.speak_mode and CFG.debug_mode:
|
if CFG.speak_mode and CFG.debug_mode:
|
||||||
say_text(
|
say_text(
|
||||||
"I have received an invalid JSON response from the OpenAI API. "
|
"I have received an invalid JSON response from the OpenAI API. "
|
||||||
|
@ -7,7 +7,6 @@ from typing import Dict
|
|||||||
from pilot.commands.exception_not_commands import NotCommands
|
from pilot.commands.exception_not_commands import NotCommands
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.prompts.generator import PluginPromptGenerator
|
from pilot.prompts.generator import PluginPromptGenerator
|
||||||
from pilot.speech import say_text
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_pathlike_command_args(command_args):
|
def _resolve_pathlike_command_args(command_args):
|
||||||
@ -37,6 +36,8 @@ def execute_ai_response_json(
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from pilot.speech.say import say_text
|
||||||
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
command_name, arguments = get_command(ai_response)
|
command_name, arguments = get_command(ai_response)
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import markdown2
|
import markdown2
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
|
||||||
def datas_to_table_html(data):
|
def datas_to_table_html(data):
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
df = pd.DataFrame(data[1:], columns=data[0])
|
df = pd.DataFrame(data[1:], columns=data[0])
|
||||||
table_style = """<style>
|
table_style = """<style>
|
||||||
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
|
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import List, Any
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,8 +3,6 @@ import sqlparse
|
|||||||
import regex as re
|
import regex as re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
MetaData,
|
MetaData,
|
||||||
@ -14,7 +12,7 @@ from sqlalchemy import (
|
|||||||
select,
|
select,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.engine import CursorResult, Engine
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||||
from sqlalchemy.schema import CreateTable
|
from sqlalchemy.schema import CreateTable
|
||||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||||
|
@ -4,12 +4,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import nltk
|
|
||||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
|
||||||
|
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
from pilot.common.sql_database import Database
|
|
||||||
from pilot.prompts.prompt_registry import PromptTemplateRegistry
|
|
||||||
|
|
||||||
|
|
||||||
class Config(metaclass=Singleton):
|
class Config(metaclass=Singleton):
|
||||||
@ -78,6 +73,8 @@ class Config(metaclass=Singleton):
|
|||||||
)
|
)
|
||||||
self.speak_mode = False
|
self.speak_mode = False
|
||||||
|
|
||||||
|
from pilot.prompts.prompt_registry import PromptTemplateRegistry
|
||||||
|
|
||||||
self.prompt_template_registry = PromptTemplateRegistry()
|
self.prompt_template_registry = PromptTemplateRegistry()
|
||||||
### Related configuration of built-in commands
|
### Related configuration of built-in commands
|
||||||
self.command_registry = []
|
self.command_registry = []
|
||||||
@ -98,6 +95,8 @@ class Config(metaclass=Singleton):
|
|||||||
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
||||||
|
|
||||||
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
||||||
|
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||||
|
|
||||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||||
self.plugins_openai = []
|
self.plugins_openai = []
|
||||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
||||||
@ -183,6 +182,9 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
|
self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
|
||||||
|
|
||||||
|
### Log level
|
||||||
|
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
def set_debug_mode(self, value: bool) -> None:
|
def set_debug_mode(self, value: bool) -> None:
|
||||||
"""Set the debug mode value"""
|
"""Set the debug mode value"""
|
||||||
self.debug_mode = value
|
self.debug_mode = value
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import nltk
|
# import nltk
|
||||||
import torch
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||||
@ -13,7 +12,7 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
|||||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||||
|
|
||||||
@ -22,13 +21,19 @@ current_directory = os.getcwd()
|
|||||||
new_directory = PILOT_PATH
|
new_directory = PILOT_PATH
|
||||||
os.chdir(new_directory)
|
os.chdir(new_directory)
|
||||||
|
|
||||||
DEVICE = (
|
|
||||||
"cuda"
|
def get_device() -> str:
|
||||||
if torch.cuda.is_available()
|
import torch
|
||||||
else "mps"
|
|
||||||
if torch.backends.mps.is_available()
|
return (
|
||||||
else "cpu"
|
"cuda"
|
||||||
)
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
LLM_MODEL_CONFIG = {
|
LLM_MODEL_CONFIG = {
|
||||||
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
|
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
|
||||||
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
|
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
"""We need to design a base class. That other connector can Write with this"""
|
"""We need to design a base class. That other connector can Write with this"""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import duckdb
|
import duckdb
|
||||||
from typing import List
|
|
||||||
|
|
||||||
default_db_path = os.path.join(os.getcwd(), "message")
|
default_db_path = os.path.join(os.getcwd(), "message")
|
||||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
|
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import uuid
|
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
from pilot.language.translation_handler import get_lang_text
|
from pilot.language.translation_handler import get_lang_text
|
||||||
|
@ -12,9 +12,6 @@ class JsonFileHandler(logging.FileHandler):
|
|||||||
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
class JsonFormatter(logging.Formatter):
|
class JsonFormatter(logging.Formatter):
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
return record.msg
|
return record.msg
|
||||||
|
@ -8,9 +8,7 @@ from typing import Any
|
|||||||
|
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
|
|
||||||
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
|
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
from pilot.speech import say_text
|
|
||||||
|
|
||||||
|
|
||||||
class Logger(metaclass=Singleton):
|
class Logger(metaclass=Singleton):
|
||||||
@ -86,6 +84,8 @@ class Logger(metaclass=Singleton):
|
|||||||
def typewriter_log(
|
def typewriter_log(
|
||||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||||
):
|
):
|
||||||
|
from pilot.speech.say import say_text
|
||||||
|
|
||||||
if speak_text and self.speak_mode:
|
if speak_text and self.speak_mode:
|
||||||
say_text(f"{title}. {content}")
|
say_text(f"{title}. {content}")
|
||||||
|
|
||||||
@ -159,6 +159,8 @@ class Logger(metaclass=Singleton):
|
|||||||
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
||||||
|
|
||||||
def log_json(self, data: Any, file_name: str) -> None:
|
def log_json(self, data: Any, file_name: str) -> None:
|
||||||
|
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
|
||||||
|
|
||||||
# Define log directory
|
# Define log directory
|
||||||
this_files_dir_path = os.path.dirname(__file__)
|
this_files_dir_path = os.path.dirname(__file__)
|
||||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||||
@ -255,6 +257,8 @@ def print_assistant_thoughts(
|
|||||||
assistant_reply_json_valid: object,
|
assistant_reply_json_valid: object,
|
||||||
speak_mode: bool = False,
|
speak_mode: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
from pilot.speech.say import say_text
|
||||||
|
|
||||||
assistant_thoughts_reasoning = None
|
assistant_thoughts_reasoning = None
|
||||||
assistant_thoughts_plan = None
|
assistant_thoughts_plan = None
|
||||||
assistant_thoughts_speak = None
|
assistant_thoughts_speak = None
|
||||||
|
@ -1,18 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import List
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Generic,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pilot.scene.message import OnceConversation
|
from pilot.scene.message import OnceConversation
|
||||||
|
|
||||||
|
@ -7,9 +7,7 @@ from pilot.configs.config import Config
|
|||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
from pilot.scene.message import (
|
from pilot.scene.message import (
|
||||||
OnceConversation,
|
OnceConversation,
|
||||||
conversation_from_dict,
|
|
||||||
_conversation_to_dic,
|
_conversation_to_dic,
|
||||||
conversations_to_dict,
|
|
||||||
)
|
)
|
||||||
from pilot.common.formatting import MyEncoder
|
from pilot.common.formatting import MyEncoder
|
||||||
|
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import datetime
|
|
||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.message import (
|
from pilot.scene.message import OnceConversation
|
||||||
OnceConversation,
|
from pilot.common.custom_data_structure import FixedSizeDict
|
||||||
conversation_from_dict,
|
|
||||||
conversations_to_dict,
|
|
||||||
)
|
|
||||||
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import torch
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -14,7 +13,7 @@ from transformers import (
|
|||||||
LlamaTokenizer,
|
LlamaTokenizer,
|
||||||
)
|
)
|
||||||
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
|
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
|
||||||
from pilot.configs.model_config import DEVICE
|
from pilot.configs.model_config import get_device
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
|
|
||||||
@ -147,9 +146,11 @@ class ChatGLMAdapater(BaseLLMAdaper):
|
|||||||
return "chatglm" in model_path
|
return "chatglm" in model_path
|
||||||
|
|
||||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||||
|
import torch
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
if DEVICE != "cuda":
|
if get_device() != "cuda":
|
||||||
model = AutoModel.from_pretrained(
|
model = AutoModel.from_pretrained(
|
||||||
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
||||||
).float()
|
).float()
|
||||||
|
3
pilot/model/cache/base.py
vendored
3
pilot/model/cache/base.py
vendored
@ -1,6 +1,3 @@
|
|||||||
import json
|
|
||||||
import hashlib
|
|
||||||
from typing import Any, Dict
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,11 +2,7 @@ import click
|
|||||||
import functools
|
import functools
|
||||||
|
|
||||||
from pilot.model.controller.registry import ModelRegistryClient
|
from pilot.model.controller.registry import ModelRegistryClient
|
||||||
from pilot.model.worker.manager import (
|
from pilot.model.base import WorkerApplyType
|
||||||
RemoteWorkerManager,
|
|
||||||
WorkerApplyRequest,
|
|
||||||
WorkerApplyType,
|
|
||||||
)
|
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
ModelControllerParameters,
|
ModelControllerParameters,
|
||||||
ModelWorkerParameters,
|
ModelWorkerParameters,
|
||||||
@ -15,12 +11,14 @@ from pilot.model.parameter import (
|
|||||||
from pilot.utils import get_or_create_event_loop
|
from pilot.utils import get_or_create_event_loop
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||||
|
|
||||||
|
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
||||||
|
|
||||||
|
|
||||||
@click.group("model")
|
@click.group("model")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--address",
|
"--address",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://127.0.0.1:8000",
|
default=MODEL_CONTROLLER_ADDRESS,
|
||||||
required=False,
|
required=False,
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help=(
|
help=(
|
||||||
@ -28,24 +26,25 @@ from pilot.utils.parameter_utils import EnvArgumentParser
|
|||||||
"Just support light deploy model"
|
"Just support light deploy model"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def model_cli_group():
|
def model_cli_group(address: str):
|
||||||
"""Clients that manage model serving"""
|
"""Clients that manage model serving"""
|
||||||
pass
|
global MODEL_CONTROLLER_ADDRESS
|
||||||
|
MODEL_CONTROLLER_ADDRESS = address
|
||||||
|
|
||||||
|
|
||||||
@model_cli_group.command()
|
@model_cli_group.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-name", type=str, default=None, required=False, help=("The name of model")
|
"--model_name", type=str, default=None, required=False, help=("The name of model")
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-type", type=str, default="llm", required=False, help=("The type of model")
|
"--model_type", type=str, default="llm", required=False, help=("The type of model")
|
||||||
)
|
)
|
||||||
def list(address: str, model_name: str, model_type: str):
|
def list(model_name: str, model_type: str):
|
||||||
"""List model instances"""
|
"""List model instances"""
|
||||||
from prettytable import PrettyTable
|
from prettytable import PrettyTable
|
||||||
|
|
||||||
loop = get_or_create_event_loop()
|
loop = get_or_create_event_loop()
|
||||||
registry = ModelRegistryClient(address)
|
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
|
||||||
|
|
||||||
if not model_name:
|
if not model_name:
|
||||||
instances = loop.run_until_complete(registry.get_all_model_instances())
|
instances = loop.run_until_complete(registry.get_all_model_instances())
|
||||||
@ -88,14 +87,14 @@ def list(address: str, model_name: str, model_type: str):
|
|||||||
|
|
||||||
def add_model_options(func):
|
def add_model_options(func):
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-name",
|
"--model_name",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
required=True,
|
||||||
help=("The name of model"),
|
help=("The name of model"),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-type",
|
"--model_type",
|
||||||
type=str,
|
type=str,
|
||||||
default="llm",
|
default="llm",
|
||||||
required=False,
|
required=False,
|
||||||
@ -110,23 +109,27 @@ def add_model_options(func):
|
|||||||
|
|
||||||
@model_cli_group.command()
|
@model_cli_group.command()
|
||||||
@add_model_options
|
@add_model_options
|
||||||
def stop(address: str, model_name: str, model_type: str):
|
def stop(model_name: str, model_type: str):
|
||||||
"""Stop model instances"""
|
"""Stop model instances"""
|
||||||
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
|
worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
|
||||||
|
|
||||||
|
|
||||||
@model_cli_group.command()
|
@model_cli_group.command()
|
||||||
@add_model_options
|
@add_model_options
|
||||||
def start(address: str, model_name: str, model_type: str):
|
def start(model_name: str, model_type: str):
|
||||||
"""Start model instances"""
|
"""Start model instances"""
|
||||||
worker_apply(address, model_name, model_type, WorkerApplyType.START)
|
worker_apply(
|
||||||
|
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@model_cli_group.command()
|
@model_cli_group.command()
|
||||||
@add_model_options
|
@add_model_options
|
||||||
def restart(address: str, model_name: str, model_type: str):
|
def restart(model_name: str, model_type: str):
|
||||||
"""Restart model instances"""
|
"""Restart model instances"""
|
||||||
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
|
worker_apply(
|
||||||
|
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.RESTART
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# @model_cli_group.command()
|
# @model_cli_group.command()
|
||||||
@ -139,6 +142,8 @@ def restart(address: str, model_name: str, model_type: str):
|
|||||||
def worker_apply(
|
def worker_apply(
|
||||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||||
):
|
):
|
||||||
|
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
|
||||||
|
|
||||||
loop = get_or_create_event_loop()
|
loop = get_or_create_event_loop()
|
||||||
registry = ModelRegistryClient(address)
|
registry = ModelRegistryClient(address)
|
||||||
worker_manager = RemoteWorkerManager(registry)
|
worker_manager = RemoteWorkerManager(registry)
|
||||||
|
@ -6,7 +6,7 @@ Conversation prompt templates.
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from enum import auto, IntEnum
|
from enum import auto, IntEnum
|
||||||
from typing import List, Any, Dict, Callable
|
from typing import List, Dict, Callable
|
||||||
|
|
||||||
|
|
||||||
class SeparatorStyle(IntEnum):
|
class SeparatorStyle(IntEnum):
|
||||||
|
@ -9,8 +9,6 @@ from typing import Iterable, Dict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from transformers.generation.logits_process import (
|
from transformers.generation.logits_process import (
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
RepetitionPenaltyLogitsProcessor,
|
RepetitionPenaltyLogitsProcessor,
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
|
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
from typing import Dict, Any
|
from typing import Dict
|
||||||
import torch
|
import torch
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
|
||||||
|
@ -7,13 +7,7 @@ import time
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.conversation import (
|
from pilot.conversation import Conversation
|
||||||
Conversation,
|
|
||||||
auto_dbgpt_one_shot,
|
|
||||||
conv_one_shot,
|
|
||||||
conv_templates,
|
|
||||||
)
|
|
||||||
from pilot.model.llm.base import Message
|
|
||||||
|
|
||||||
|
|
||||||
# TODO Rewrite this
|
# TODO Rewrite this
|
||||||
|
@ -3,11 +3,9 @@
|
|||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
import re
|
import re
|
||||||
import copy
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
|
||||||
from pilot.scene.base_message import ModelMessage, _parse_model_messages
|
from pilot.scene.base_message import ModelMessage, _parse_model_messages
|
||||||
|
|
||||||
# TODO move sep to scene prompt of model
|
# TODO move sep to scene prompt of model
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import copy
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||||
|
|
||||||
|
@ -2,16 +2,13 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from typing import Optional, Dict
|
from typing import Optional, Dict
|
||||||
import torch
|
|
||||||
|
|
||||||
from pilot.configs.model_config import DEVICE
|
from pilot.configs.model_config import get_device
|
||||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
||||||
from pilot.model.compression import compress_module
|
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
ModelParameters,
|
ModelParameters,
|
||||||
LlamaCppModelParameters,
|
LlamaCppModelParameters,
|
||||||
)
|
)
|
||||||
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
|
|
||||||
from pilot.utils import get_gpu_memory
|
from pilot.utils import get_gpu_memory
|
||||||
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
|
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
@ -67,7 +64,7 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_path: str, model_name: str = None) -> None:
|
def __init__(self, model_path: str, model_name: str = None) -> None:
|
||||||
self.device = DEVICE
|
self.device = get_device()
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.prompt_template: str = None
|
self.prompt_template: str = None
|
||||||
@ -127,6 +124,9 @@ class ModelLoader:
|
|||||||
|
|
||||||
|
|
||||||
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
|
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
|
||||||
|
import torch
|
||||||
|
from pilot.model.compression import compress_module
|
||||||
|
|
||||||
device = model_params.device
|
device = model_params.device
|
||||||
max_memory = None
|
max_memory = None
|
||||||
|
|
||||||
@ -156,6 +156,10 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
|
|||||||
|
|
||||||
elif device == "mps":
|
elif device == "mps":
|
||||||
kwargs = {"torch_dtype": torch.float16}
|
kwargs = {"torch_dtype": torch.float16}
|
||||||
|
from pilot.model.llm.monkey_patch import (
|
||||||
|
replace_llama_attn_with_non_inplace_operations,
|
||||||
|
)
|
||||||
|
|
||||||
replace_llama_attn_with_non_inplace_operations()
|
replace_llama_attn_with_non_inplace_operations()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid device: {device}")
|
raise ValueError(f"Invalid device: {device}")
|
||||||
@ -200,6 +204,8 @@ def load_huggingface_quantization_model(
|
|||||||
kwargs: Dict,
|
kwargs: Dict,
|
||||||
max_memory: Dict[int, str],
|
max_memory: Dict[int, str],
|
||||||
):
|
):
|
||||||
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.utils import infer_auto_device_map
|
from accelerate.utils import infer_auto_device_map
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from dataclasses import dataclass, field, fields, MISSING
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from pilot.model.conversation import conv_templates
|
from pilot.model.conversation import conv_templates
|
||||||
from pilot.utils.parameter_utils import BaseParameters
|
from pilot.utils.parameter_utils import BaseParameters
|
||||||
|
@ -2,8 +2,7 @@ import logging
|
|||||||
import platform
|
import platform
|
||||||
from typing import Dict, Iterator, List
|
from typing import Dict, Iterator, List
|
||||||
|
|
||||||
import torch
|
from pilot.configs.model_config import get_device
|
||||||
from pilot.configs.model_config import DEVICE
|
|
||||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
||||||
from pilot.model.base import ModelOutput
|
from pilot.model.base import ModelOutput
|
||||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||||
@ -63,7 +62,7 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
if not model_params.device:
|
if not model_params.device:
|
||||||
model_params.device = DEVICE
|
model_params.device = get_device()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
|
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
|
||||||
)
|
)
|
||||||
@ -88,6 +87,8 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
_clear_torch_cache(self._model_params.device)
|
_clear_torch_cache(self._model_params.device)
|
||||||
|
|
||||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||||
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# params adaptation
|
# params adaptation
|
||||||
params, model_context = self.llm_chat_adapter.model_adaptation(
|
params, model_context = self.llm_chat_adapter.model_adaptation(
|
||||||
@ -95,7 +96,7 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for output in self.generate_stream_func(
|
for output in self.generate_stream_func(
|
||||||
self.model, self.tokenizer, params, DEVICE, self.context_len
|
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||||
):
|
):
|
||||||
# Please do not open the output in production!
|
# Please do not open the output in production!
|
||||||
# The gpt4all thread shares stdout with the parent process,
|
# The gpt4all thread shares stdout with the parent process,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Type
|
from typing import Dict, List, Type
|
||||||
|
|
||||||
from pilot.configs.model_config import DEVICE
|
from pilot.configs.model_config import get_device
|
||||||
from pilot.model.loader import _get_model_real_path
|
from pilot.model.loader import _get_model_real_path
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import (
|
||||||
EmbeddingModelParameters,
|
EmbeddingModelParameters,
|
||||||
@ -55,7 +55,7 @@ class EmbeddingsModelWorker(ModelWorker):
|
|||||||
model_path=self.model_path,
|
model_path=self.model_path,
|
||||||
)
|
)
|
||||||
if not model_params.device:
|
if not model_params.device:
|
||||||
model_params.device = DEVICE
|
model_params.device = get_device()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import httpx
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -7,26 +6,21 @@ import random
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import APIRouter, FastAPI, Request
|
from fastapi import APIRouter, FastAPI, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.model.base import (
|
from pilot.model.base import (
|
||||||
ModelInstance,
|
ModelInstance,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
WorkerApplyType,
|
|
||||||
WorkerApplyOutput,
|
WorkerApplyOutput,
|
||||||
|
WorkerApplyType,
|
||||||
)
|
)
|
||||||
from pilot.model.controller.registry import ModelRegistry
|
from pilot.model.controller.registry import ModelRegistry
|
||||||
from pilot.model.parameter import (
|
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||||
ModelParameters,
|
|
||||||
ModelWorkerParameters,
|
|
||||||
WorkerType,
|
|
||||||
)
|
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.worker.base import ModelWorker
|
||||||
from pilot.scene.base_message import ModelMessage
|
from pilot.scene.base_message import ModelMessage
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
@ -431,6 +425,8 @@ class RemoteWorkerManager(LocalWorkerManager):
|
|||||||
return worker_instances
|
return worker_instances
|
||||||
|
|
||||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||||
|
import httpx
|
||||||
|
|
||||||
async def _remote_apply_func(worker_run_data: WorkerRunData):
|
async def _remote_apply_func(worker_run_data: WorkerRunData):
|
||||||
worker_addr = worker_run_data.worker.worker_addr
|
worker_addr = worker_run_data.worker.worker_addr
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
@ -700,6 +696,8 @@ def run_worker_manager(
|
|||||||
app.include_router(router, prefix="/api")
|
app.include_router(router, prefix="/api")
|
||||||
|
|
||||||
if not embedded_mod:
|
if not embedded_mod:
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, Iterator, List
|
from typing import Dict, Iterator, List
|
||||||
|
import logging
|
||||||
import httpx
|
|
||||||
from pilot.model.base import ModelOutput
|
from pilot.model.base import ModelOutput
|
||||||
from pilot.model.parameter import ModelParameters
|
from pilot.model.parameter import ModelParameters
|
||||||
from pilot.model.worker.base import ModelWorker
|
from pilot.model.worker.base import ModelWorker
|
||||||
@ -10,7 +9,8 @@ from pilot.model.worker.base import ModelWorker
|
|||||||
class RemoteModelWorker(ModelWorker):
|
class RemoteModelWorker(ModelWorker):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.headers = {}
|
self.headers = {}
|
||||||
self.timeout = 60
|
# TODO Configured by ModelParameters
|
||||||
|
self.timeout = 180
|
||||||
self.host = None
|
self.host = None
|
||||||
self.port = None
|
self.port = None
|
||||||
|
|
||||||
@ -44,7 +44,9 @@ class RemoteModelWorker(ModelWorker):
|
|||||||
|
|
||||||
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||||
"""Asynchronous generate stream"""
|
"""Asynchronous generate stream"""
|
||||||
print(f"Send async_generate_stream, params: {params}")
|
import httpx
|
||||||
|
|
||||||
|
logging.debug(f"Send async_generate_stream, params: {params}")
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
delimiter = b"\0"
|
delimiter = b"\0"
|
||||||
buffer = b""
|
buffer = b""
|
||||||
@ -71,8 +73,9 @@ class RemoteModelWorker(ModelWorker):
|
|||||||
|
|
||||||
async def async_generate(self, params: Dict) -> ModelOutput:
|
async def async_generate(self, params: Dict) -> ModelOutput:
|
||||||
"""Asynchronous generate non stream"""
|
"""Asynchronous generate non stream"""
|
||||||
print(f"Send async_generate_stream, params: {params}")
|
import httpx
|
||||||
|
|
||||||
|
logging.debug(f"Send async_generate_stream, params: {params}")
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self.worker_addr + "/generate",
|
self.worker_addr + "/generate",
|
||||||
@ -88,6 +91,8 @@ class RemoteModelWorker(ModelWorker):
|
|||||||
|
|
||||||
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
||||||
"""Asynchronous get embeddings for input"""
|
"""Asynchronous get embeddings for input"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self.worker_addr + "/embeddings",
|
self.worker_addr + "/embeddings",
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import TypeVar, Union, List, Generic, Any
|
from typing import TypeVar, Generic, Any
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
@ -1,24 +1,6 @@
|
|||||||
from fastapi import (
|
from fastapi import Request
|
||||||
APIRouter,
|
|
||||||
Request,
|
|
||||||
Body,
|
|
||||||
status,
|
|
||||||
HTTPException,
|
|
||||||
Response,
|
|
||||||
BackgroundTasks,
|
|
||||||
)
|
|
||||||
|
|
||||||
from fastapi.responses import JSONResponse, HTMLResponse
|
|
||||||
from fastapi.responses import StreamingResponse, FileResponse
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from pilot.openapi.api_view_model import Result
|
||||||
from pilot.openapi.api_view_model import (
|
|
||||||
Result,
|
|
||||||
ConversationVo,
|
|
||||||
MessageVo,
|
|
||||||
ChatSceneVo,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import TypeVar, Union, List, Generic, Any
|
from typing import List, Any
|
||||||
|
|
||||||
|
|
||||||
class DbField(BaseModel):
|
class DbField(BaseModel):
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
from abc import ABC
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Any, Dict, TypeVar, Union
|
from typing import Any, Dict, TypeVar, Union
|
||||||
|
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
import json
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from typing import List
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
|
||||||
|
|
||||||
import yaml
|
from pydantic import BaseModel
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
|
||||||
|
|
||||||
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
from typing import List
|
||||||
|
|
||||||
from pilot.common.schema import ExampleType
|
from pilot.common.schema import ExampleType
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
from pilot.common.formatting import formatter, no_strict_formatter
|
from pilot.common.formatting import formatter, no_strict_formatter
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
import json
|
|
||||||
|
|
||||||
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
|
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
|
||||||
_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"
|
_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"
|
||||||
|
@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
|||||||
|
|
||||||
from pilot.out_parser.base import BaseOutputParser
|
from pilot.out_parser.base import BaseOutputParser
|
||||||
from pilot.prompts.base import PromptValue
|
from pilot.prompts.base import PromptValue
|
||||||
from pilot.scene.base_message import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
from pilot.scene.base_message import HumanMessage, BaseMessage
|
||||||
from pilot.common.formatting import formatter
|
from pilot.common.formatting import formatter
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,44 +1,20 @@
|
|||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import datetime
|
import datetime
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
from abc import ABC, abstractmethod
|
||||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
from typing import Any, List
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Generic,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
import requests
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import pilot.configs.config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.message import OnceConversation
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
|
||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
||||||
from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
||||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
from pilot.scene.message import OnceConversation
|
||||||
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
|
from pilot.utils import build_logger, get_or_create_event_loop
|
||||||
from pilot.scene.base_message import (
|
from pydantic import Extra
|
||||||
BaseMessage,
|
|
||||||
SystemMessage,
|
|
||||||
HumanMessage,
|
|
||||||
AIMessage,
|
|
||||||
ViewMessage,
|
|
||||||
ModelMessage,
|
|
||||||
ModelMessageRoleType,
|
|
||||||
)
|
|
||||||
from pilot.configs.config import Config
|
|
||||||
|
|
||||||
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
|
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
|
||||||
headers = {"User-Agent": "dbgpt Client"}
|
headers = {"User-Agent": "dbgpt Client"}
|
||||||
|
@ -1,20 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import Any, Dict, List, Tuple, Optional
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Generic,
|
|
||||||
List,
|
|
||||||
Tuple,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Field, root_validator
|
||||||
|
|
||||||
|
|
||||||
class PromptValue(BaseModel, ABC):
|
class PromptValue(BaseModel, ABC):
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat, logger
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import json
|
from pydantic import BaseModel
|
||||||
from pydantic import BaseModel, Field
|
from typing import List, Any
|
||||||
from typing import TypeVar, Union, List, Generic, Any
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
|
|
||||||
|
|
||||||
class ValueItem(BaseModel):
|
class ValueItem(BaseModel):
|
||||||
|
@ -1,9 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
from typing import NamedTuple, List
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple, List
|
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
@ -3,17 +3,11 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
from typing import List, Any, Dict
|
from typing import List, Any, Dict
|
||||||
from pilot.scene.base_message import (
|
|
||||||
HumanMessage,
|
|
||||||
ViewMessage,
|
|
||||||
)
|
|
||||||
from pilot.scene.base_chat import BaseChat, logger
|
from pilot.scene.base_chat import BaseChat, logger
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.common.markdown_text import (
|
|
||||||
generate_htm_table,
|
|
||||||
)
|
|
||||||
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
|
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
|
||||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||||
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
||||||
|
@ -1,16 +1,7 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from pilot.scene.base_message import (
|
|
||||||
HumanMessage,
|
|
||||||
ViewMessage,
|
|
||||||
)
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.common.markdown_text import (
|
|
||||||
generate_htm_table,
|
|
||||||
)
|
|
||||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple
|
from typing import Dict, NamedTuple
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
@ -36,6 +33,8 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
return SqlAction(sql, thoughts)
|
return SqlAction(sql, thoughts)
|
||||||
|
|
||||||
def parse_view_response(self, speak, data) -> str:
|
def parse_view_response(self, speak, data) -> str:
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
### tool out data to table view
|
### tool out data to table view
|
||||||
data_loader = DbDataLoader()
|
data_loader = DbDataLoader()
|
||||||
if len(data) <= 1:
|
if len(data) <= 1:
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
|
||||||
from pilot.common.schema import SeparatorStyle
|
from pilot.common.schema import SeparatorStyle
|
||||||
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import json
|
|||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
|
||||||
from pilot.common.schema import SeparatorStyle
|
from pilot.common.schema import SeparatorStyle
|
||||||
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
|
||||||
class DbDataLoader:
|
class DbDataLoader:
|
||||||
def get_table_view_by_conn(self, data, speak):
|
def get_table_view_by_conn(self, data, speak):
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
### tool out data to table view
|
### tool out data to table view
|
||||||
if len(data) <= 1:
|
if len(data) <= 1:
|
||||||
data.insert(0, ["result"])
|
data.insert(0, ["result"])
|
||||||
|
@ -1,14 +1,7 @@
|
|||||||
from pilot.scene.base_message import (
|
|
||||||
HumanMessage,
|
|
||||||
ViewMessage,
|
|
||||||
)
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.common.markdown_text import (
|
|
||||||
generate_htm_table,
|
|
||||||
)
|
|
||||||
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
@ -1,12 +1,6 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple
|
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import json
|
|
||||||
import importlib
|
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
|
@ -1,11 +1,6 @@
|
|||||||
import requests
|
|
||||||
import datetime
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
from typing import List
|
from typing import List
|
||||||
import traceback
|
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.message import OnceConversation
|
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.commands.command import execute_command
|
from pilot.commands.command import execute_command
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple
|
from typing import Dict, NamedTuple
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
@ -1,20 +1,20 @@
|
|||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
import inspect
|
|
||||||
import importlib
|
|
||||||
from pilot.scene.chat_execution.chat import ChatWithPlugin
|
|
||||||
from pilot.scene.chat_normal.chat import ChatNormal
|
|
||||||
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
|
|
||||||
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
|
|
||||||
from pilot.scene.chat_dashboard.chat import ChatDashboard
|
|
||||||
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
|
|
||||||
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
|
||||||
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
|
||||||
|
|
||||||
|
|
||||||
class ChatFactory(metaclass=Singleton):
|
class ChatFactory(metaclass=Singleton):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_implementation(chat_mode, **kwargs):
|
def get_implementation(chat_mode, **kwargs):
|
||||||
|
# Lazy loading
|
||||||
|
from pilot.scene.chat_execution.chat import ChatWithPlugin
|
||||||
|
from pilot.scene.chat_normal.chat import ChatNormal
|
||||||
|
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
|
||||||
|
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
|
||||||
|
from pilot.scene.chat_dashboard.chat import ChatDashboard
|
||||||
|
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
|
||||||
|
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
||||||
|
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
||||||
|
|
||||||
chat_classes = BaseChat.__subclasses__()
|
chat_classes = BaseChat.__subclasses__()
|
||||||
implementation = None
|
implementation = None
|
||||||
for cls in chat_classes:
|
for cls in chat_classes:
|
||||||
|
@ -1,8 +1,3 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple
|
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import builtins
|
|
||||||
import importlib
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
|
@ -1,25 +1,15 @@
|
|||||||
from chromadb.errors import NoIndexException
|
from chromadb.errors import NoIndexException
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
from pilot.common.markdown_text import (
|
|
||||||
generate_markdown_table,
|
|
||||||
generate_htm_table,
|
|
||||||
datas_to_table_html,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pilot.configs.model_config import (
|
from pilot.configs.model_config import (
|
||||||
DATASETS_DIR,
|
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
LLM_MODEL_CONFIG,
|
LLM_MODEL_CONFIG,
|
||||||
LOGDIR,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
|
||||||
from pilot.server.knowledge.service import KnowledgeService
|
from pilot.server.knowledge.service import KnowledgeService
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -32,6 +22,8 @@ class ChatKnowledge(BaseChat):
|
|||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||||
""" """
|
""" """
|
||||||
|
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||||
|
|
||||||
self.knowledge_space = select_param
|
self.knowledge_space = select_param
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatKnowledge,
|
chat_mode=ChatScene.ChatKnowledge,
|
||||||
|
@ -1,8 +1,3 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple
|
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import builtins
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import builtins
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
|
@ -1,13 +1,7 @@
|
|||||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
from pilot.common.markdown_text import (
|
|
||||||
generate_markdown_table,
|
|
||||||
generate_htm_table,
|
|
||||||
datas_to_table_html,
|
|
||||||
)
|
|
||||||
from pilot.scene.chat_normal.prompt import prompt
|
from pilot.scene.chat_normal.prompt import prompt
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
@ -1,8 +1,3 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, NamedTuple
|
|
||||||
import pandas as pd
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
@ -1,6 +1,3 @@
|
|||||||
import builtins
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
from pilot.prompts.prompt_new import PromptTemplate
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
|
@ -1,13 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field, root_validator, validator
|
from typing import List
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Generic,
|
|
||||||
List,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pilot.scene.base_message import (
|
from pilot.scene.base_message import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
|
@ -1,28 +1,18 @@
|
|||||||
import signal
|
import signal
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
from pilot.commands.command_mange import CommandRegistry
|
from pilot.commands.command_mange import CommandRegistry
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
# from pilot.configs.model_config import (
|
from pilot.common.plugins import scan_plugins
|
||||||
# DATASETS_DIR,
|
|
||||||
# KNOWLEDGE_UPLOAD_ROOT_PATH,
|
|
||||||
# LLM_MODEL_CONFIG,
|
|
||||||
# LOGDIR,
|
|
||||||
# )
|
|
||||||
from pilot.common.plugins import scan_plugins, load_native_plugins
|
|
||||||
from pilot.utils import build_logger
|
|
||||||
from pilot.connections.manages.connection_manager import ConnectManager
|
from pilot.connections.manages.connection_manager import ConnectManager
|
||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
print("in order to avoid chroma db atexit problem")
|
print("in order to avoid chroma db atexit problem")
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple
|
||||||
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
|
||||||
from pilot.model.conversation import Conversation, get_conv_template
|
from pilot.model.conversation import Conversation, get_conv_template
|
||||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||||
|
|
||||||
@ -131,6 +130,8 @@ class VicunaChatAdapter(BaseChatAdpter):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_generate_stream_func(self, model_path: str):
|
def get_generate_stream_func(self, model_path: str):
|
||||||
|
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
||||||
|
|
||||||
if self._is_llama2_based(model_path):
|
if self._is_llama2_based(model_path):
|
||||||
return super().get_generate_stream_func(model_path)
|
return super().get_generate_stream_func(model_path)
|
||||||
return generate_stream
|
return generate_stream
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
import atexit
|
|
||||||
import traceback
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
@ -11,7 +8,6 @@ sys.path.append(ROOT_PATH)
|
|||||||
import signal
|
import signal
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||||
from pilot.utils import build_logger
|
|
||||||
|
|
||||||
from pilot.server.base import server_init
|
from pilot.server.base import server_init
|
||||||
|
|
||||||
@ -28,14 +24,11 @@ from pilot.openapi.base import validation_exception_handler
|
|||||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||||
from pilot.model.worker.manager import initialize_worker_manager_in_client
|
from pilot.model.worker.manager import initialize_worker_manager_in_client
|
||||||
|
from pilot.utils.utils import setup_logging
|
||||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
|
||||||
|
|
||||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||||
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
|
||||||
|
|
||||||
|
|
||||||
def signal_handler():
|
def signal_handler():
|
||||||
@ -102,7 +95,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--port", type=int, default=5000)
|
parser.add_argument("--port", type=int, default=5000)
|
||||||
parser.add_argument("--concurrency-count", type=int, default=10)
|
parser.add_argument("--concurrency-count", type=int, default=10)
|
||||||
parser.add_argument("--share", default=False, action="store_true")
|
parser.add_argument("--share", default=False, action="store_true")
|
||||||
parser.add_argument("--log-level", type=str, default="info")
|
parser.add_argument("--log-level", type=str, default=None)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-light",
|
"-light",
|
||||||
"--light",
|
"--light",
|
||||||
@ -113,6 +106,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# init server config
|
# init server config
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
setup_logging(logging_level=args.log_level)
|
||||||
server_init(args)
|
server_init(args)
|
||||||
|
|
||||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||||
@ -137,6 +131,5 @@ if __name__ == "__main__":
|
|||||||
mount_static_files(app)
|
mount_static_files(app)
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info")
|
||||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler())
|
signal.signal(signal.SIGINT, signal_handler())
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
|
|
||||||
from fastapi import APIRouter, File, UploadFile, Form
|
from fastapi import APIRouter, File, UploadFile, Form
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
|
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
from sqlalchemy.orm import declarative_base
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
from pilot.connections.rdbms.base_dao import BaseDao
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
|
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
from sqlalchemy.orm import declarative_base
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
from pilot.connections.rdbms.base_dao import BaseDao
|
||||||
|
@ -2,12 +2,10 @@ import json
|
|||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter
|
|
||||||
from pilot.vector_store.connector import VectorStoreConnector
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.server.knowledge.chunk_db import (
|
from pilot.server.knowledge.chunk_db import (
|
||||||
DocumentChunkEntity,
|
DocumentChunkEntity,
|
||||||
@ -152,6 +150,14 @@ class KnowledgeService:
|
|||||||
"""sync knowledge document chunk into vector store"""
|
"""sync knowledge document chunk into vector store"""
|
||||||
|
|
||||||
def sync_knowledge_document(self, space_name, doc_ids):
|
def sync_knowledge_document(self, space_name, doc_ids):
|
||||||
|
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||||
|
from langchain.text_splitter import (
|
||||||
|
RecursiveCharacterTextSplitter,
|
||||||
|
SpacyTextSplitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# import langchain is very very slow!!!
|
||||||
|
|
||||||
for doc_id in doc_ids:
|
for doc_id in doc_ids:
|
||||||
query = KnowledgeDocumentEntity(
|
query = KnowledgeDocumentEntity(
|
||||||
id=doc_id,
|
id=doc_id,
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
|
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||||
|
@ -4,9 +4,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
global_counter = 0
|
|
||||||
model_semaphore = None
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
@ -17,15 +14,6 @@ from pilot.model.worker.manager import run_worker_manager
|
|||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||||
# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
|
|
||||||
|
|
||||||
# @app.post("/embedding")
|
|
||||||
# def embeddings(prompt_request: EmbeddingRequest):
|
|
||||||
# params = {"prompt": prompt_request.prompt}
|
|
||||||
# print("Received prompt: ", params["prompt"])
|
|
||||||
# output = worker.get_embeddings(params["prompt"])
|
|
||||||
# return {"response": [float(x) for x in output]}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_worker_manager(
|
run_worker_manager(
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
from pilot.speech.say import say_text
|
|
||||||
|
|
||||||
__all__ = ["say_text"]
|
|
@ -1,21 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from langchain.embeddings import HuggingFaceEmbeddings, logger
|
from pilot.common.schema import DBType
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import (
|
||||||
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
|
LLM_MODEL_CONFIG,
|
||||||
|
LOGDIR,
|
||||||
|
)
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
|
||||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
|
||||||
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
|
||||||
from pilot.scene.chat_factory import ChatFactory
|
from pilot.scene.chat_factory import ChatFactory
|
||||||
from pilot.common.schema import DBType
|
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
||||||
from pilot.configs.model_config import LOGDIR
|
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
|
||||||
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
|
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
|
||||||
|
|
||||||
|
|
||||||
@ -33,6 +31,8 @@ class DBSummaryClient:
|
|||||||
|
|
||||||
def db_summary_embedding(self, dbname, db_type):
|
def db_summary_embedding(self, dbname, db_type):
|
||||||
"""put db profile and table profile summary into vector store"""
|
"""put db profile and table profile summary into vector store"""
|
||||||
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||||
|
|
||||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||||
embeddings = HuggingFaceEmbeddings(
|
embeddings = HuggingFaceEmbeddings(
|
||||||
@ -82,6 +82,8 @@ class DBSummaryClient:
|
|||||||
logger.info("db summary embedding success")
|
logger.info("db summary embedding success")
|
||||||
|
|
||||||
def get_db_summary(self, dbname, query, topk):
|
def get_db_summary(self, dbname, query, topk):
|
||||||
|
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||||
|
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
"vector_store_name": dbname + "_profile",
|
"vector_store_name": dbname + "_profile",
|
||||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||||
@ -97,6 +99,8 @@ class DBSummaryClient:
|
|||||||
|
|
||||||
def get_similar_tables(self, dbname, query, topk):
|
def get_similar_tables(self, dbname, query, topk):
|
||||||
"""get user query related tables info"""
|
"""get user query related tables info"""
|
||||||
|
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||||
|
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
"vector_store_name": dbname + "_summary",
|
"vector_store_name": dbname + "_summary",
|
||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
@ -149,6 +153,8 @@ class DBSummaryClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
||||||
|
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||||
|
|
||||||
profile_store_config = {
|
profile_store_config = {
|
||||||
"vector_store_name": dbname + "_profile",
|
"vector_store_name": dbname + "_profile",
|
||||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
import httpx
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
import typing_inspect
|
|
||||||
import logging
|
import logging
|
||||||
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
|
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
|
||||||
from dataclasses import is_dataclass, asdict
|
from dataclasses import is_dataclass, asdict
|
||||||
@ -9,6 +7,8 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||||
|
import typing_inspect
|
||||||
|
|
||||||
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
|
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
|
||||||
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
|
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
|
||||||
return typing_inspect.get_args(type_hint)[0]
|
return typing_inspect.get_args(type_hint)[0]
|
||||||
@ -30,6 +30,8 @@ def _api_remote(path, method="GET"):
|
|||||||
sig = signature(func)
|
sig = signature(func)
|
||||||
|
|
||||||
async def wrapper(self, *args, **kwargs):
|
async def wrapper(self, *args, **kwargs):
|
||||||
|
import httpx
|
||||||
|
|
||||||
base_url = self.base_url # Get base_url from class instance
|
base_url = self.base_url # Get base_url from class instance
|
||||||
|
|
||||||
bound = sig.bind(self, *args, **kwargs)
|
bound = sig.bind(self, *args, **kwargs)
|
||||||
|
@ -16,6 +16,22 @@ server_error_msg = (
|
|||||||
handler = None
|
handler = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_logging_level() -> str:
|
||||||
|
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(logging_level=None, logger_name: str = None):
|
||||||
|
if not logging_level:
|
||||||
|
logging_level = _get_logging_level()
|
||||||
|
if type(logging_level) is str:
|
||||||
|
logging_level = logging.getLevelName(logging_level.upper())
|
||||||
|
if logger_name:
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.setLevel(logging_level)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging_level, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
def get_gpu_memory(max_gpus=None):
|
def get_gpu_memory(max_gpus=None):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -47,7 +63,7 @@ def build_logger(logger_name, logger_filename):
|
|||||||
|
|
||||||
# Set the format of root handlers
|
# Set the format of root handlers
|
||||||
if not logging.getLogger().handlers:
|
if not logging.getLogger().handlers:
|
||||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
setup_logging()
|
||||||
logging.getLogger().handlers[0].setFormatter(formatter)
|
logging.getLogger().handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
# Redirect stdout and stderr to loggers
|
# Redirect stdout and stderr to loggers
|
||||||
@ -73,11 +89,11 @@ def build_logger(logger_name, logger_filename):
|
|||||||
for name, item in logging.root.manager.loggerDict.items():
|
for name, item in logging.root.manager.loggerDict.items():
|
||||||
if isinstance(item, logging.Logger):
|
if isinstance(item, logging.Logger):
|
||||||
item.addHandler(handler)
|
item.addHandler(handler)
|
||||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
setup_logging()
|
||||||
|
|
||||||
# Get logger
|
# Get logger
|
||||||
logger = logging.getLogger(logger_name)
|
logger = logging.getLogger(logger_name)
|
||||||
logger.setLevel(logging.INFO)
|
setup_logging(logger_name=logger_name)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from langchain.vectorstores import Chroma
|
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
from pilot.vector_store.base import VectorStoreBase
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
|
|
||||||
@ -11,6 +10,8 @@ class ChromaStore(VectorStoreBase):
|
|||||||
"""chroma database"""
|
"""chroma database"""
|
||||||
|
|
||||||
def __init__(self, ctx: {}) -> None:
|
def __init__(self, ctx: {}) -> None:
|
||||||
|
from langchain.vectorstores import Chroma
|
||||||
|
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.embeddings = ctx.get("embeddings", None)
|
self.embeddings = ctx.get("embeddings", None)
|
||||||
self.persist_dir = os.path.join(
|
self.persist_dir = os.path.join(
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Iterable, List, Optional, Tuple
|
from typing import Any, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
|
||||||
from pymilvus import Collection, DataType, connections, utility
|
from pymilvus import Collection, DataType, connections, utility
|
||||||
|
|
||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
@ -279,7 +280,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
round_decimal: int = -1,
|
round_decimal: int = -1,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
|
):
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
self.col.load()
|
self.col.load()
|
||||||
# use default index params.
|
# use default index params.
|
||||||
if param is None:
|
if param is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user