mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-12 13:42:23 +00:00
feat: Optimize code import time
This commit is contained in:
parent
0bc5134a07
commit
f19551a7cd
@ -143,3 +143,9 @@ SUMMARY_CONFIG=FAST
|
||||
# CUDA_VISIBLE_DEVICES=0
|
||||
## You can configure the maximum memory used by each GPU.
|
||||
# MAX_GPU_MEMORY=16Gib
|
||||
|
||||
#*******************************************************************#
|
||||
#** LOG **#
|
||||
#*******************************************************************#
|
||||
# FATAL, ERROR, WARNING, INFO, DEBUG, NOTSET
|
||||
DBGPT_LOG_LEVEL=INFO
|
@ -1,4 +1,12 @@
|
||||
from pilot.embedding_engine import SourceEmbedding, register
|
||||
from pilot.embedding_engine import EmbeddingEngine, KnowledgeType
|
||||
# Old packages
|
||||
# __all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
|
||||
|
||||
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
|
||||
__all__ = ["embedding_engine"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Resolve selected submodules lazily on attribute access (PEP 562).

    Called by Python only when ``name`` is not found as a normal attribute
    of this module. Names in the allow-list below are imported as
    submodules of this package at that point; anything else is rejected.

    Args:
        name: The attribute being looked up on this module.

    Returns:
        The imported submodule for an allow-listed name.

    Raises:
        AttributeError: If ``name`` is not an allow-listed submodule.
    """
    # Imported inside the function so nothing is loaded until it is needed.
    import importlib

    if name in ["embedding_engine"]:
        # Relative import anchored on this package's own name.
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
@ -12,7 +12,7 @@ from pilot.json_utils.json_fix_general import (
|
||||
fix_invalid_escape,
|
||||
)
|
||||
from pilot.logs import logger
|
||||
from pilot.speech import say_text
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -87,6 +87,8 @@ def correct_json(json_to_load: str) -> str:
|
||||
|
||||
|
||||
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
if CFG.speak_mode and CFG.debug_mode:
|
||||
say_text(
|
||||
"I have received an invalid JSON response from the OpenAI API. "
|
||||
|
@ -7,7 +7,6 @@ from typing import Dict
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.generator import PluginPromptGenerator
|
||||
from pilot.speech import say_text
|
||||
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
@ -37,6 +36,8 @@ def execute_ai_response_json(
|
||||
Returns:
|
||||
|
||||
"""
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
cfg = Config()
|
||||
|
||||
command_name, arguments = get_command(ai_response)
|
||||
|
@ -1,8 +1,9 @@
|
||||
import markdown2
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def datas_to_table_html(data):
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame(data[1:], columns=data[0])
|
||||
table_style = """<style>
|
||||
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}
|
||||
|
@ -1,5 +1,4 @@
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
import os
|
||||
|
||||
|
||||
|
@ -3,8 +3,6 @@ import sqlparse
|
||||
import regex as re
|
||||
import warnings
|
||||
from typing import Any, Iterable, List, Optional
|
||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||
from abc import ABC, abstractmethod
|
||||
import sqlalchemy
|
||||
from sqlalchemy import (
|
||||
MetaData,
|
||||
@ -14,7 +12,7 @@ from sqlalchemy import (
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import CursorResult, Engine
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
|
@ -4,12 +4,7 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.prompts.prompt_registry import PromptTemplateRegistry
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
@ -78,6 +73,8 @@ class Config(metaclass=Singleton):
|
||||
)
|
||||
self.speak_mode = False
|
||||
|
||||
from pilot.prompts.prompt_registry import PromptTemplateRegistry
|
||||
|
||||
self.prompt_template_registry = PromptTemplateRegistry()
|
||||
### Related configuration of built-in commands
|
||||
self.command_registry = []
|
||||
@ -98,6 +95,8 @@ class Config(metaclass=Singleton):
|
||||
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
||||
|
||||
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins_openai = []
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
||||
@ -183,6 +182,9 @@ class Config(metaclass=Singleton):
|
||||
|
||||
self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
|
||||
|
||||
### Log level
|
||||
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
"""Set the debug mode value"""
|
||||
self.debug_mode = value
|
||||
|
@ -3,8 +3,7 @@
|
||||
|
||||
import os
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
# import nltk
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||
@ -13,7 +12,7 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
|
||||
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
|
||||
|
||||
@ -22,13 +21,19 @@ current_directory = os.getcwd()
|
||||
new_directory = PILOT_PATH
|
||||
os.chdir(new_directory)
|
||||
|
||||
DEVICE = (
|
||||
|
||||
def get_device() -> str:
    """Return the name of the preferred torch device that is available.

    CUDA is checked first, then the MPS backend; ``"cpu"`` is the fallback
    when neither reports itself as available. ``torch`` is imported inside
    the function, so it is loaded only when the function is called.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    import torch

    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
|
||||
)
|
||||
|
||||
|
||||
LLM_MODEL_CONFIG = {
|
||||
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
|
||||
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
"""We need to design a base class. That other connector can Write with this"""
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import duckdb
|
||||
from typing import List
|
||||
|
||||
default_db_path = os.path.join(os.getcwd(), "message")
|
||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
|
||||
|
@ -2,7 +2,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import dataclasses
|
||||
import uuid
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
from pilot.language.translation_handler import get_lang_text
|
||||
|
@ -12,9 +12,6 @@ class JsonFileHandler(logging.FileHandler):
|
||||
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
return record.msg
|
||||
|
@ -8,9 +8,7 @@ from typing import Any
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.speech import say_text
|
||||
|
||||
|
||||
class Logger(metaclass=Singleton):
|
||||
@ -86,6 +84,8 @@ class Logger(metaclass=Singleton):
|
||||
def typewriter_log(
|
||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||
):
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
if speak_text and self.speak_mode:
|
||||
say_text(f"{title}. {content}")
|
||||
|
||||
@ -159,6 +159,8 @@ class Logger(metaclass=Singleton):
|
||||
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
||||
|
||||
def log_json(self, data: Any, file_name: str) -> None:
|
||||
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
|
||||
|
||||
# Define log directory
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
@ -255,6 +257,8 @@ def print_assistant_thoughts(
|
||||
assistant_reply_json_valid: object,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
assistant_thoughts_reasoning = None
|
||||
assistant_thoughts_plan = None
|
||||
assistant_thoughts_speak = None
|
||||
|
@ -1,18 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from pilot.scene.message import OnceConversation
|
||||
|
||||
|
@ -7,9 +7,7 @@ from pilot.configs.config import Config
|
||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||
from pilot.scene.message import (
|
||||
OnceConversation,
|
||||
conversation_from_dict,
|
||||
_conversation_to_dic,
|
||||
conversations_to_dict,
|
||||
)
|
||||
from pilot.common.formatting import MyEncoder
|
||||
|
||||
|
@ -1,17 +1,9 @@
|
||||
from typing import List
|
||||
import json
|
||||
import os
|
||||
import datetime
|
||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||
from pathlib import Path
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.message import (
|
||||
OnceConversation,
|
||||
conversation_from_dict,
|
||||
conversations_to_dict,
|
||||
)
|
||||
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.common.custom_data_structure import FixedSizeDict
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
@ -14,7 +13,7 @@ from transformers import (
|
||||
LlamaTokenizer,
|
||||
)
|
||||
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
@ -147,9 +146,11 @@ class ChatGLMAdapater(BaseLLMAdaper):
|
||||
return "chatglm" in model_path
|
||||
|
||||
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
if DEVICE != "cuda":
|
||||
if get_device() != "cuda":
|
||||
model = AutoModel.from_pretrained(
|
||||
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
||||
).float()
|
||||
|
3
pilot/model/cache/base.py
vendored
3
pilot/model/cache/base.py
vendored
@ -1,6 +1,3 @@
|
||||
import json
|
||||
import hashlib
|
||||
from typing import Any, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
@ -2,11 +2,7 @@ import click
|
||||
import functools
|
||||
|
||||
from pilot.model.controller.registry import ModelRegistryClient
|
||||
from pilot.model.worker.manager import (
|
||||
RemoteWorkerManager,
|
||||
WorkerApplyRequest,
|
||||
WorkerApplyType,
|
||||
)
|
||||
from pilot.model.base import WorkerApplyType
|
||||
from pilot.model.parameter import (
|
||||
ModelControllerParameters,
|
||||
ModelWorkerParameters,
|
||||
@ -15,12 +11,14 @@ from pilot.model.parameter import (
|
||||
from pilot.utils import get_or_create_event_loop
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
|
||||
|
||||
|
||||
@click.group("model")
|
||||
@click.option(
|
||||
"--address",
|
||||
type=str,
|
||||
default="http://127.0.0.1:8000",
|
||||
default=MODEL_CONTROLLER_ADDRESS,
|
||||
required=False,
|
||||
show_default=True,
|
||||
help=(
|
||||
@ -28,24 +26,25 @@ from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
"Just support light deploy model"
|
||||
),
|
||||
)
|
||||
def model_cli_group():
|
||||
def model_cli_group(address: str):
|
||||
"""Clients that manage model serving"""
|
||||
pass
|
||||
global MODEL_CONTROLLER_ADDRESS
|
||||
MODEL_CONTROLLER_ADDRESS = address
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@click.option(
|
||||
"--model-name", type=str, default=None, required=False, help=("The name of model")
|
||||
"--model_name", type=str, default=None, required=False, help=("The name of model")
|
||||
)
|
||||
@click.option(
|
||||
"--model-type", type=str, default="llm", required=False, help=("The type of model")
|
||||
"--model_type", type=str, default="llm", required=False, help=("The type of model")
|
||||
)
|
||||
def list(address: str, model_name: str, model_type: str):
|
||||
def list(model_name: str, model_type: str):
|
||||
"""List model instances"""
|
||||
from prettytable import PrettyTable
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
registry = ModelRegistryClient(address)
|
||||
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
|
||||
|
||||
if not model_name:
|
||||
instances = loop.run_until_complete(registry.get_all_model_instances())
|
||||
@ -88,14 +87,14 @@ def list(address: str, model_name: str, model_type: str):
|
||||
|
||||
def add_model_options(func):
|
||||
@click.option(
|
||||
"--model-name",
|
||||
"--model_name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help=("The name of model"),
|
||||
)
|
||||
@click.option(
|
||||
"--model-type",
|
||||
"--model_type",
|
||||
type=str,
|
||||
default="llm",
|
||||
required=False,
|
||||
@ -110,23 +109,27 @@ def add_model_options(func):
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def stop(address: str, model_name: str, model_type: str):
|
||||
def stop(model_name: str, model_type: str):
|
||||
"""Stop model instances"""
|
||||
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
|
||||
worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def start(address: str, model_name: str, model_type: str):
|
||||
def start(model_name: str, model_type: str):
|
||||
"""Start model instances"""
|
||||
worker_apply(address, model_name, model_type, WorkerApplyType.START)
|
||||
worker_apply(
|
||||
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
|
||||
)
|
||||
|
||||
|
||||
@model_cli_group.command()
|
||||
@add_model_options
|
||||
def restart(address: str, model_name: str, model_type: str):
|
||||
def restart(model_name: str, model_type: str):
|
||||
"""Restart model instances"""
|
||||
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
|
||||
worker_apply(
|
||||
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.RESTART
|
||||
)
|
||||
|
||||
|
||||
# @model_cli_group.command()
|
||||
@ -139,6 +142,8 @@ def restart(address: str, model_name: str, model_type: str):
|
||||
def worker_apply(
|
||||
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
|
||||
):
|
||||
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
|
||||
|
||||
loop = get_or_create_event_loop()
|
||||
registry = ModelRegistryClient(address)
|
||||
worker_manager = RemoteWorkerManager(registry)
|
||||
|
@ -6,7 +6,7 @@ Conversation prompt templates.
|
||||
|
||||
import dataclasses
|
||||
from enum import auto, IntEnum
|
||||
from typing import List, Any, Dict, Callable
|
||||
from typing import List, Dict, Callable
|
||||
|
||||
|
||||
class SeparatorStyle(IntEnum):
|
||||
|
@ -9,8 +9,6 @@ from typing import Iterable, Dict
|
||||
|
||||
import torch
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.generation.logits_process import (
|
||||
LogitsProcessorList,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
|
@ -2,7 +2,7 @@
|
||||
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
|
||||
"""
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
from typing import Dict
|
||||
import torch
|
||||
import llama_cpp
|
||||
|
||||
|
@ -7,13 +7,7 @@ import time
|
||||
from typing import Optional
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.conversation import (
|
||||
Conversation,
|
||||
auto_dbgpt_one_shot,
|
||||
conv_one_shot,
|
||||
conv_templates,
|
||||
)
|
||||
from pilot.model.llm.base import Message
|
||||
from pilot.conversation import Conversation
|
||||
|
||||
|
||||
# TODO Rewrite this
|
||||
|
@ -3,11 +3,9 @@
|
||||
|
||||
from typing import List
|
||||
import re
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
from pilot.scene.base_message import ModelMessage, _parse_model_messages
|
||||
|
||||
# TODO move sep to scene prompt of model
|
||||
|
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import copy
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
|
||||
|
@ -2,16 +2,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional, Dict
|
||||
import torch
|
||||
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
||||
from pilot.model.compression import compress_module
|
||||
from pilot.model.parameter import (
|
||||
ModelParameters,
|
||||
LlamaCppModelParameters,
|
||||
)
|
||||
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
|
||||
from pilot.utils import get_gpu_memory
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
|
||||
from pilot.logs import logger
|
||||
@ -67,7 +64,7 @@ class ModelLoader:
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, model_name: str = None) -> None:
|
||||
self.device = DEVICE
|
||||
self.device = get_device()
|
||||
self.model_path = model_path
|
||||
self.model_name = model_name
|
||||
self.prompt_template: str = None
|
||||
@ -127,6 +124,9 @@ class ModelLoader:
|
||||
|
||||
|
||||
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
|
||||
import torch
|
||||
from pilot.model.compression import compress_module
|
||||
|
||||
device = model_params.device
|
||||
max_memory = None
|
||||
|
||||
@ -156,6 +156,10 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
|
||||
|
||||
elif device == "mps":
|
||||
kwargs = {"torch_dtype": torch.float16}
|
||||
from pilot.model.llm.monkey_patch import (
|
||||
replace_llama_attn_with_non_inplace_operations,
|
||||
)
|
||||
|
||||
replace_llama_attn_with_non_inplace_operations()
|
||||
else:
|
||||
raise ValueError(f"Invalid device: {device}")
|
||||
@ -200,6 +204,8 @@ def load_huggingface_quantization_model(
|
||||
kwargs: Dict,
|
||||
max_memory: Dict[int, str],
|
||||
):
|
||||
import torch
|
||||
|
||||
try:
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import infer_auto_device_map
|
||||
|
@ -1,8 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from dataclasses import dataclass, field, fields, MISSING
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pilot.model.conversation import conv_templates
|
||||
from pilot.utils.parameter_utils import BaseParameters
|
||||
|
@ -2,8 +2,7 @@ import logging
|
||||
import platform
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
import torch
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
@ -63,7 +62,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
model_type=model_type,
|
||||
)
|
||||
if not model_params.device:
|
||||
model_params.device = DEVICE
|
||||
model_params.device = get_device()
|
||||
logger.info(
|
||||
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
)
|
||||
@ -88,6 +87,8 @@ class DefaultModelWorker(ModelWorker):
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
import torch
|
||||
|
||||
try:
|
||||
# params adaptation
|
||||
params, model_context = self.llm_chat_adapter.model_adaptation(
|
||||
@ -95,7 +96,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
)
|
||||
|
||||
for output in self.generate_stream_func(
|
||||
self.model, self.tokenizer, params, DEVICE, self.context_len
|
||||
self.model, self.tokenizer, params, get_device(), self.context_len
|
||||
):
|
||||
# Please do not open the output in production!
|
||||
# The gpt4all thread shares stdout with the parent process,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from pilot.configs.model_config import DEVICE
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.model.loader import _get_model_real_path
|
||||
from pilot.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
@ -55,7 +55,7 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
model_path=self.model_path,
|
||||
)
|
||||
if not model_params.device:
|
||||
model_params.device = DEVICE
|
||||
model_params.device = get_device()
|
||||
logger.info(
|
||||
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
|
||||
)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import httpx
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
@ -7,26 +6,21 @@ import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import (
|
||||
ModelInstance,
|
||||
ModelOutput,
|
||||
WorkerApplyType,
|
||||
WorkerApplyOutput,
|
||||
WorkerApplyType,
|
||||
)
|
||||
from pilot.model.controller.registry import ModelRegistry
|
||||
from pilot.model.parameter import (
|
||||
ModelParameters,
|
||||
ModelWorkerParameters,
|
||||
WorkerType,
|
||||
)
|
||||
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
from pilot.scene.base_message import ModelMessage
|
||||
from pilot.utils import build_logger
|
||||
@ -431,6 +425,8 @@ class RemoteWorkerManager(LocalWorkerManager):
|
||||
return worker_instances
|
||||
|
||||
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
|
||||
import httpx
|
||||
|
||||
async def _remote_apply_func(worker_run_data: WorkerRunData):
|
||||
worker_addr = worker_run_data.worker.worker_addr
|
||||
async with httpx.AsyncClient() as client:
|
||||
@ -700,6 +696,8 @@ def run_worker_manager(
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
if not embedded_mod:
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
app, host=worker_params.host, port=worker_params.port, log_level="info"
|
||||
)
|
||||
|
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
import httpx
|
||||
import logging
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.worker.base import ModelWorker
|
||||
@ -10,7 +9,8 @@ from pilot.model.worker.base import ModelWorker
|
||||
class RemoteModelWorker(ModelWorker):
|
||||
def __init__(self) -> None:
|
||||
self.headers = {}
|
||||
self.timeout = 60
|
||||
# TODO Configured by ModelParameters
|
||||
self.timeout = 180
|
||||
self.host = None
|
||||
self.port = None
|
||||
|
||||
@ -44,7 +44,9 @@ class RemoteModelWorker(ModelWorker):
|
||||
|
||||
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
"""Asynchronous generate stream"""
|
||||
print(f"Send async_generate_stream, params: {params}")
|
||||
import httpx
|
||||
|
||||
logging.debug(f"Send async_generate_stream, params: {params}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
delimiter = b"\0"
|
||||
buffer = b""
|
||||
@ -71,8 +73,9 @@ class RemoteModelWorker(ModelWorker):
|
||||
|
||||
async def async_generate(self, params: Dict) -> ModelOutput:
|
||||
"""Asynchronous generate non stream"""
|
||||
print(f"Send async_generate_stream, params: {params}")
|
||||
import httpx
|
||||
|
||||
logging.debug(f"Send async_generate_stream, params: {params}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.worker_addr + "/generate",
|
||||
@ -88,6 +91,8 @@ class RemoteModelWorker(ModelWorker):
|
||||
|
||||
async def async_embeddings(self, params: Dict) -> List[List[float]]:
|
||||
"""Asynchronous get embeddings for input"""
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.worker_addr + "/embeddings",
|
||||
|
@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Union, List, Generic, Any
|
||||
from typing import TypeVar, Generic, Any
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -1,24 +1,6 @@
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Request,
|
||||
Body,
|
||||
status,
|
||||
HTTPException,
|
||||
Response,
|
||||
BackgroundTasks,
|
||||
)
|
||||
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
||||
from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
ConversationVo,
|
||||
MessageVo,
|
||||
ChatSceneVo,
|
||||
)
|
||||
from pilot.openapi.api_view_model import Result
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
|
@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Union, List, Generic, Any
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class DbField(BaseModel):
|
||||
|
@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, TypeVar, Union
|
||||
|
||||
|
@ -1,10 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
from typing import List
|
||||
from pilot.common.schema import ExampleType
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from pilot.common.formatting import formatter, no_strict_formatter
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
import json
|
||||
|
||||
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
|
||||
_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"
|
||||
|
@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from pilot.out_parser.base import BaseOutputParser
|
||||
from pilot.prompts.base import PromptValue
|
||||
from pilot.scene.base_message import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from pilot.scene.base_message import HumanMessage, BaseMessage
|
||||
from pilot.common.formatting import formatter
|
||||
|
||||
|
||||
|
@ -1,44 +1,20 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
import datetime
|
||||
import traceback
|
||||
import warnings
|
||||
import json
|
||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
import requests
|
||||
from urllib.parse import urljoin
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
import pilot.configs.config
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
from pilot.memory.chat_history.file_history import FileHistoryMemory
|
||||
from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
|
||||
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
|
||||
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
|
||||
from pilot.scene.base_message import (
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
ViewMessage,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
)
|
||||
from pilot.configs.config import Config
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.utils import build_logger, get_or_create_event_loop
|
||||
from pydantic import Extra
|
||||
|
||||
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
|
@ -1,20 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Tuple,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
|
||||
class PromptValue(BaseModel, ABC):
|
||||
|
@ -3,7 +3,7 @@ import os
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||
|
@ -1,7 +1,5 @@
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Union, List, Generic, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class ValueItem(BaseModel):
|
||||
|
@ -1,9 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, asdict
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple, List
|
||||
import pandas as pd
|
||||
from typing import NamedTuple, List
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
@ -3,17 +3,11 @@ import os
|
||||
|
||||
|
||||
from typing import List, Any, Dict
|
||||
from pilot.scene.base_message import (
|
||||
HumanMessage,
|
||||
ViewMessage,
|
||||
)
|
||||
from pilot.scene.base_chat import BaseChat, logger
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.common.markdown_text import (
|
||||
generate_htm_table,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
|
||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
||||
|
@ -1,16 +1,7 @@
|
||||
import json
|
||||
|
||||
from pilot.scene.base_message import (
|
||||
HumanMessage,
|
||||
ViewMessage,
|
||||
)
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.common.markdown_text import (
|
||||
generate_htm_table,
|
||||
)
|
||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||
|
||||
CFG = Config()
|
||||
|
@ -1,8 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
@ -36,6 +33,8 @@ class DbChatOutputParser(BaseOutputParser):
|
||||
return SqlAction(sql, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
import pandas as pd
|
||||
|
||||
### tool out data to table view
|
||||
data_loader = DbDataLoader()
|
||||
if len(data) <= 1:
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
||||
|
||||
|
@ -5,7 +5,7 @@ import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
|
||||
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DbDataLoader:
|
||||
def get_table_view_by_conn(self, data, speak):
|
||||
import pandas as pd
|
||||
|
||||
### tool out data to table view
|
||||
if len(data) <= 1:
|
||||
data.insert(0, ["result"])
|
||||
|
@ -1,14 +1,7 @@
|
||||
from pilot.scene.base_message import (
|
||||
HumanMessage,
|
||||
ViewMessage,
|
||||
)
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.common.markdown_text import (
|
||||
generate_htm_table,
|
||||
)
|
||||
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
||||
|
||||
CFG = Config()
|
||||
|
@ -1,12 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.utils import build_logger
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
import json
|
||||
import importlib
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
|
@ -1,11 +1,6 @@
|
||||
import requests
|
||||
import datetime
|
||||
from urllib.parse import urljoin
|
||||
from typing import List
|
||||
import traceback
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command import execute_command
|
||||
|
@ -1,8 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
@ -1,20 +1,20 @@
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.singleton import Singleton
|
||||
import inspect
|
||||
import importlib
|
||||
from pilot.scene.chat_execution.chat import ChatWithPlugin
|
||||
from pilot.scene.chat_normal.chat import ChatNormal
|
||||
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
|
||||
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
|
||||
from pilot.scene.chat_dashboard.chat import ChatDashboard
|
||||
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
|
||||
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
||||
|
||||
|
||||
class ChatFactory(metaclass=Singleton):
|
||||
@staticmethod
|
||||
def get_implementation(chat_mode, **kwargs):
|
||||
# Lazy loading
|
||||
from pilot.scene.chat_execution.chat import ChatWithPlugin
|
||||
from pilot.scene.chat_normal.chat import ChatNormal
|
||||
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
|
||||
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
|
||||
from pilot.scene.chat_dashboard.chat import ChatDashboard
|
||||
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
|
||||
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
||||
|
||||
chat_classes = BaseChat.__subclasses__()
|
||||
implementation = None
|
||||
for cls in chat_classes:
|
||||
|
@ -1,8 +1,3 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
@ -1,5 +1,3 @@
|
||||
import builtins
|
||||
import importlib
|
||||
import json
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
|
@ -1,25 +1,15 @@
|
||||
from chromadb.errors import NoIndexException
|
||||
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
|
||||
from pilot.configs.model_config import (
|
||||
DATASETS_DIR,
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
)
|
||||
|
||||
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.server.knowledge.service import KnowledgeService
|
||||
|
||||
CFG = Config()
|
||||
@ -32,6 +22,8 @@ class ChatKnowledge(BaseChat):
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
""" """
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
|
||||
self.knowledge_space = select_param
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatKnowledge,
|
||||
|
@ -1,8 +1,3 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
@ -1,6 +1,3 @@
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
|
@ -1,6 +1,3 @@
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
|
@ -1,13 +1,7 @@
|
||||
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.common.markdown_text import (
|
||||
generate_markdown_table,
|
||||
generate_htm_table,
|
||||
datas_to_table_html,
|
||||
)
|
||||
from pilot.scene.chat_normal.prompt import prompt
|
||||
|
||||
CFG = Config()
|
||||
|
@ -1,8 +1,3 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, NamedTuple
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
@ -1,6 +1,3 @@
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
|
@ -1,13 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime, timedelta
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
)
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from pilot.scene.base_message import (
|
||||
BaseMessage,
|
||||
|
@ -1,28 +1,18 @@
|
||||
import signal
|
||||
import os
|
||||
import threading
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from pilot.summary.db_summary_client import DBSummaryClient
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.configs.config import Config
|
||||
|
||||
# from pilot.configs.model_config import (
|
||||
# DATASETS_DIR,
|
||||
# KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
# LLM_MODEL_CONFIG,
|
||||
# LOGDIR,
|
||||
# )
|
||||
from pilot.common.plugins import scan_plugins, load_native_plugins
|
||||
from pilot.utils import build_logger
|
||||
from pilot.common.plugins import scan_plugins
|
||||
from pilot.connections.manages.connection_manager import ConnectManager
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
print("in order to avoid chroma db atexit problem")
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
from functools import cache
|
||||
from typing import List, Dict, Tuple
|
||||
from pilot.model.llm_out.vicuna_base_llm import generate_stream
|
||||
from pilot.model.conversation import Conversation, get_conv_template
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
@ -131,6 +130,8 @@ class VicunaChatAdapter(BaseChatAdpter):
|
||||
return None
|
||||
|
||||
def get_generate_stream_func(self, model_path: str):
    """Pick the streaming generate function for the model at ``model_path``.

    When ``self._is_llama2_based(model_path)`` is true the parent class's
    ``get_generate_stream_func`` result is returned; otherwise the
    ``generate_stream`` function from ``vicuna_base_llm`` is returned.
    """
    # Imported inside the method (not at module level) so the module is only
    # loaded when this method is actually called.
    from pilot.model.llm_out.vicuna_base_llm import generate_stream

    if not self._is_llama2_based(model_path):
        return generate_stream
    return super().get_generate_stream_func(model_path)
|
||||
|
@ -1,7 +1,4 @@
|
||||
import atexit
|
||||
import traceback
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import sys
|
||||
import logging
|
||||
@ -11,7 +8,6 @@ sys.path.append(ROOT_PATH)
|
||||
import signal
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG
|
||||
from pilot.utils import build_logger
|
||||
|
||||
from pilot.server.base import server_init
|
||||
|
||||
@ -28,14 +24,11 @@ from pilot.openapi.base import validation_exception_handler
|
||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
from pilot.model.worker.manager import initialize_worker_manager_in_client
|
||||
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
from pilot.utils.utils import setup_logging
|
||||
|
||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||
|
||||
|
||||
CFG = Config()
|
||||
# logger = build_logger("webserver", LOGDIR + "webserver.log")
|
||||
|
||||
|
||||
def signal_handler():
|
||||
@ -102,7 +95,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--port", type=int, default=5000)
|
||||
parser.add_argument("--concurrency-count", type=int, default=10)
|
||||
parser.add_argument("--share", default=False, action="store_true")
|
||||
parser.add_argument("--log-level", type=str, default="info")
|
||||
parser.add_argument("--log-level", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"-light",
|
||||
"--light",
|
||||
@ -113,6 +106,7 @@ if __name__ == "__main__":
|
||||
|
||||
# init server config
|
||||
args = parser.parse_args()
|
||||
setup_logging(logging_level=args.log_level)
|
||||
server_init(args)
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
@ -137,6 +131,5 @@ if __name__ == "__main__":
|
||||
mount_static_files(app)
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level)
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info")
|
||||
signal.signal(signal.SIGINT, signal_handler())
|
||||
|
@ -1,7 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile, Form
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
@ -2,12 +2,10 @@ import json
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.logs import logger
|
||||
from pilot.server.knowledge.chunk_db import (
|
||||
DocumentChunkEntity,
|
||||
@ -152,6 +150,14 @@ class KnowledgeService:
|
||||
"""sync knowledge document chunk into vector store"""
|
||||
|
||||
def sync_knowledge_document(self, space_name, doc_ids):
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from langchain.text_splitter import (
|
||||
RecursiveCharacterTextSplitter,
|
||||
SpacyTextSplitter,
|
||||
)
|
||||
|
||||
# NOTE: importing langchain is very slow, so it is imported lazily inside this method.
|
||||
|
||||
for doc_id in doc_ids:
|
||||
query = KnowledgeDocumentEntity(
|
||||
id=doc_id,
|
||||
|
@ -1,8 +1,7 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||
|
@ -4,9 +4,6 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
global_counter = 0
|
||||
model_semaphore = None
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
@ -17,15 +14,6 @@ from pilot.model.worker.manager import run_worker_manager
|
||||
CFG = Config()
|
||||
|
||||
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
|
||||
# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
|
||||
|
||||
# @app.post("/embedding")
|
||||
# def embeddings(prompt_request: EmbeddingRequest):
|
||||
# params = {"prompt": prompt_request.prompt}
|
||||
# print("Received prompt: ", params["prompt"])
|
||||
# output = worker.get_embeddings(params["prompt"])
|
||||
# return {"response": [float(x) for x in output]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_worker_manager(
|
||||
|
@ -1,3 +0,0 @@
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
__all__ = ["say_text"]
|
@ -1,21 +1,19 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings, logger
|
||||
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
)
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
||||
from pilot.utils import build_logger
|
||||
|
||||
|
||||
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
|
||||
|
||||
|
||||
@ -33,6 +31,8 @@ class DBSummaryClient:
|
||||
|
||||
def db_summary_embedding(self, dbname, db_type):
|
||||
"""put db profile and table profile summary into vector store"""
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||
|
||||
db_summary_client = RdbmsSummary(dbname, db_type)
|
||||
embeddings = HuggingFaceEmbeddings(
|
||||
@ -82,6 +82,8 @@ class DBSummaryClient:
|
||||
logger.info("db summary embedding success")
|
||||
|
||||
def get_db_summary(self, dbname, query, topk):
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
|
||||
vector_store_config = {
|
||||
"vector_store_name": dbname + "_profile",
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
@ -97,6 +99,8 @@ class DBSummaryClient:
|
||||
|
||||
def get_similar_tables(self, dbname, query, topk):
|
||||
"""get user query related tables info"""
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
|
||||
vector_store_config = {
|
||||
"vector_store_name": dbname + "_summary",
|
||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
@ -149,6 +153,8 @@ class DBSummaryClient:
|
||||
)
|
||||
|
||||
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||
|
||||
profile_store_config = {
|
||||
"vector_store_name": dbname + "_profile",
|
||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
|
@ -1,6 +1,4 @@
|
||||
import httpx
|
||||
from inspect import signature
|
||||
import typing_inspect
|
||||
import logging
|
||||
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
|
||||
from dataclasses import is_dataclass, asdict
|
||||
@ -9,6 +7,8 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
import typing_inspect
|
||||
|
||||
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
|
||||
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
|
||||
return typing_inspect.get_args(type_hint)[0]
|
||||
@ -30,6 +30,8 @@ def _api_remote(path, method="GET"):
|
||||
sig = signature(func)
|
||||
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
import httpx
|
||||
|
||||
base_url = self.base_url # Get base_url from class instance
|
||||
|
||||
bound = sig.bind(self, *args, **kwargs)
|
||||
|
@ -16,6 +16,22 @@ server_error_msg = (
|
||||
handler = None
|
||||
|
||||
|
||||
def _get_logging_level() -> str:
|
||||
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||
|
||||
|
||||
def setup_logging(logging_level=None, logger_name: str = None):
    """Set the logging level of the root logger or of one named logger.

    Args:
        logging_level: A level name (case-insensitive, e.g. ``"info"``) or a
            numeric ``logging`` level. When falsy, the value returned by
            ``_get_logging_level()`` is used instead.
        logger_name: When given, only the level of that logger is set.
            Otherwise ``logging.basicConfig`` is called with the resolved
            level and UTF-8 encoding.

    Raises:
        ValueError: If ``logging_level`` is a string that does not name a
            registered logging level.
    """
    if not logging_level:
        logging_level = _get_logging_level()
    if isinstance(logging_level, str):
        requested = logging_level
        # logging.getLevelName() maps a known name to its number but returns
        # the string "Level <name>" for an unknown one; reject that here so
        # the error names the value the caller actually supplied.
        logging_level = logging.getLevelName(requested.strip().upper())
        if not isinstance(logging_level, int):
            raise ValueError(f"Unknown logging level: {requested!r}")
    if logger_name:
        logging.getLogger(logger_name).setLevel(logging_level)
    else:
        logging.basicConfig(level=logging_level, encoding="utf-8")
|
||||
|
||||
|
||||
def get_gpu_memory(max_gpus=None):
|
||||
import torch
|
||||
|
||||
@ -47,7 +63,7 @@ def build_logger(logger_name, logger_filename):
|
||||
|
||||
# Set the format of root handlers
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
setup_logging()
|
||||
logging.getLogger().handlers[0].setFormatter(formatter)
|
||||
|
||||
# Redirect stdout and stderr to loggers
|
||||
@ -73,11 +89,11 @@ def build_logger(logger_name, logger_filename):
|
||||
for name, item in logging.root.manager.loggerDict.items():
|
||||
if isinstance(item, logging.Logger):
|
||||
item.addHandler(handler)
|
||||
logging.basicConfig(level=logging.INFO, encoding="utf-8")
|
||||
setup_logging()
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging.INFO)
|
||||
setup_logging(logger_name=logger_name)
|
||||
|
||||
return logger
|
||||
|
||||
|
@ -2,7 +2,6 @@ import os
|
||||
from typing import Any
|
||||
|
||||
from chromadb.config import Settings
|
||||
from langchain.vectorstores import Chroma
|
||||
from pilot.logs import logger
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
@ -11,6 +10,8 @@ class ChromaStore(VectorStoreBase):
|
||||
"""chroma database"""
|
||||
|
||||
def __init__(self, ctx: {}) -> None:
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
self.ctx = ctx
|
||||
self.embeddings = ctx.get("embeddings", None)
|
||||
self.persist_dir = os.path.join(
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from pymilvus import Collection, DataType, connections, utility
|
||||
|
||||
from pilot.logs import logger
|
||||
@ -279,7 +280,9 @@ class MilvusStore(VectorStoreBase):
|
||||
round_decimal: int = -1,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
|
||||
):
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
self.col.load()
|
||||
# use default index params.
|
||||
if param is None:
|
||||
|
Loading…
Reference in New Issue
Block a user