feat: Optimize code import time

This commit is contained in:
FangYin Cheng 2023-09-01 10:40:18 +08:00
parent 0bc5134a07
commit f19551a7cd
83 changed files with 244 additions and 394 deletions

View File

@ -143,3 +143,9 @@ SUMMARY_CONFIG=FAST
# CUDA_VISIBLE_DEVICES=0 # CUDA_VISIBLE_DEVICES=0
## You can configure the maximum memory used by each GPU. ## You can configure the maximum memory used by each GPU.
# MAX_GPU_MEMORY=16Gib # MAX_GPU_MEMORY=16Gib
#*******************************************************************#
#** LOG **#
#*******************************************************************#
# FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET
DBGPT_LOG_LEVEL=INFO

View File

@ -1,4 +1,12 @@
from pilot.embedding_engine import SourceEmbedding, register # Old packages
from pilot.embedding_engine import EmbeddingEngine, KnowledgeType # __all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"] __all__ = ["embedding_engine"]
def __getattr__(name: str):
import importlib
if name in ["embedding_engine"]:
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -12,7 +12,7 @@ from pilot.json_utils.json_fix_general import (
fix_invalid_escape, fix_invalid_escape,
) )
from pilot.logs import logger from pilot.logs import logger
from pilot.speech import say_text
CFG = Config() CFG = Config()
@ -87,6 +87,8 @@ def correct_json(json_to_load: str) -> str:
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str): def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
from pilot.speech.say import say_text
if CFG.speak_mode and CFG.debug_mode: if CFG.speak_mode and CFG.debug_mode:
say_text( say_text(
"I have received an invalid JSON response from the OpenAI API. " "I have received an invalid JSON response from the OpenAI API. "

View File

@ -7,7 +7,6 @@ from typing import Dict
from pilot.commands.exception_not_commands import NotCommands from pilot.commands.exception_not_commands import NotCommands
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.prompts.generator import PluginPromptGenerator from pilot.prompts.generator import PluginPromptGenerator
from pilot.speech import say_text
def _resolve_pathlike_command_args(command_args): def _resolve_pathlike_command_args(command_args):
@ -37,6 +36,8 @@ def execute_ai_response_json(
Returns: Returns:
""" """
from pilot.speech.say import say_text
cfg = Config() cfg = Config()
command_name, arguments = get_command(ai_response) command_name, arguments = get_command(ai_response)

View File

@ -1,8 +1,9 @@
import markdown2 import markdown2
import pandas as pd
def datas_to_table_html(data): def datas_to_table_html(data):
import pandas as pd
df = pd.DataFrame(data[1:], columns=data[0]) df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style> table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333} table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}

View File

@ -1,5 +1,4 @@
from enum import auto, Enum from enum import auto, Enum
from typing import List, Any
import os import os

View File

@ -3,8 +3,6 @@ import sqlparse
import regex as re import regex as re
import warnings import warnings
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
import sqlalchemy import sqlalchemy
from sqlalchemy import ( from sqlalchemy import (
MetaData, MetaData,
@ -14,7 +12,7 @@ from sqlalchemy import (
select, select,
text, text,
) )
from sqlalchemy.engine import CursorResult, Engine from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session

View File

@ -4,12 +4,7 @@
import os import os
from typing import List from typing import List
import nltk
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.singleton import Singleton from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.prompts.prompt_registry import PromptTemplateRegistry
class Config(metaclass=Singleton): class Config(metaclass=Singleton):
@ -78,6 +73,8 @@ class Config(metaclass=Singleton):
) )
self.speak_mode = False self.speak_mode = False
from pilot.prompts.prompt_registry import PromptTemplateRegistry
self.prompt_template_registry = PromptTemplateRegistry() self.prompt_template_registry = PromptTemplateRegistry()
### Related configuration of built-in commands ### Related configuration of built-in commands
self.command_registry = [] self.command_registry = []
@ -98,6 +95,8 @@ class Config(metaclass=Singleton):
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message") self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
### The associated configuration parameters of the plug-in control the loading and use of the plug-in ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
from auto_gpt_plugin_template import AutoGPTPluginTemplate
self.plugins: List[AutoGPTPluginTemplate] = [] self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = [] self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
@ -183,6 +182,9 @@ class Config(metaclass=Singleton):
self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None) self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
def set_debug_mode(self, value: bool) -> None: def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value""" """Set the debug mode value"""
self.debug_mode = value self.debug_mode = value

View File

@ -3,8 +3,7 @@
import os import os
import nltk # import nltk
import torch
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models") MODEL_PATH = os.path.join(ROOT_PATH, "models")
@ -13,7 +12,7 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs") LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data") DATA_DIR = os.path.join(PILOT_PATH, "data")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path # nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts") FONT_DIR = os.path.join(PILOT_PATH, "fonts")
@ -22,13 +21,19 @@ current_directory = os.getcwd()
new_directory = PILOT_PATH new_directory = PILOT_PATH
os.chdir(new_directory) os.chdir(new_directory)
DEVICE = (
"cuda" def get_device() -> str:
if torch.cuda.is_available() import torch
else "mps"
if torch.backends.mps.is_available() return (
else "cpu" "cuda"
) if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
LLM_MODEL_CONFIG = { LLM_MODEL_CONFIG = {
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"), "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"), "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),

View File

@ -3,7 +3,6 @@
"""We need to design a base class. That other connector can Write with this""" """We need to design a base class. That other connector can Write with this"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional

View File

@ -1,6 +1,5 @@
import os import os
import duckdb import duckdb
from typing import List
default_db_path = os.path.join(os.getcwd(), "message") default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")

View File

@ -2,7 +2,6 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import dataclasses import dataclasses
import uuid
from enum import auto, Enum from enum import auto, Enum
from typing import List, Any from typing import List, Any
from pilot.language.translation_handler import get_lang_text from pilot.language.translation_handler import get_lang_text

View File

@ -12,9 +12,6 @@ class JsonFileHandler(logging.FileHandler):
json.dump(json_data, f, ensure_ascii=False, indent=4) json.dump(json_data, f, ensure_ascii=False, indent=4)
import logging
class JsonFormatter(logging.Formatter): class JsonFormatter(logging.Formatter):
def format(self, record): def format(self, record):
return record.msg return record.msg

View File

@ -8,9 +8,7 @@ from typing import Any
from colorama import Fore, Style from colorama import Fore, Style
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
from pilot.singleton import Singleton from pilot.singleton import Singleton
from pilot.speech import say_text
class Logger(metaclass=Singleton): class Logger(metaclass=Singleton):
@ -86,6 +84,8 @@ class Logger(metaclass=Singleton):
def typewriter_log( def typewriter_log(
self, title="", title_color="", content="", speak_text=False, level=logging.INFO self, title="", title_color="", content="", speak_text=False, level=logging.INFO
): ):
from pilot.speech.say import say_text
if speak_text and self.speak_mode: if speak_text and self.speak_mode:
say_text(f"{title}. {content}") say_text(f"{title}. {content}")
@ -159,6 +159,8 @@ class Logger(metaclass=Singleton):
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText) self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
def log_json(self, data: Any, file_name: str) -> None: def log_json(self, data: Any, file_name: str) -> None:
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
# Define log directory # Define log directory
this_files_dir_path = os.path.dirname(__file__) this_files_dir_path = os.path.dirname(__file__)
log_dir = os.path.join(this_files_dir_path, "../logs") log_dir = os.path.join(this_files_dir_path, "../logs")
@ -255,6 +257,8 @@ def print_assistant_thoughts(
assistant_reply_json_valid: object, assistant_reply_json_valid: object,
speak_mode: bool = False, speak_mode: bool = False,
) -> None: ) -> None:
from pilot.speech.say import say_text
assistant_thoughts_reasoning = None assistant_thoughts_reasoning = None
assistant_thoughts_plan = None assistant_thoughts_plan = None
assistant_thoughts_speak = None assistant_thoughts_speak = None

View File

@ -1,18 +1,7 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import ( from typing import List
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation

View File

@ -7,9 +7,7 @@ from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import ( from pilot.scene.message import (
OnceConversation, OnceConversation,
conversation_from_dict,
_conversation_to_dic, _conversation_to_dic,
conversations_to_dict,
) )
from pilot.common.formatting import MyEncoder from pilot.common.formatting import MyEncoder

View File

@ -1,17 +1,9 @@
from typing import List from typing import List
import json
import os
import datetime
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pathlib import Path
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.message import ( from pilot.scene.message import OnceConversation
OnceConversation, from pilot.common.custom_data_structure import FixedSizeDict
conversation_from_dict,
conversations_to_dict,
)
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
CFG = Config() CFG = Config()

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import torch
import os import os
import re import re
from pathlib import Path from pathlib import Path
@ -14,7 +13,7 @@ from transformers import (
LlamaTokenizer, LlamaTokenizer,
) )
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.configs.model_config import DEVICE from pilot.configs.model_config import get_device
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.logs import logger from pilot.logs import logger
@ -147,9 +146,11 @@ class ChatGLMAdapater(BaseLLMAdaper):
return "chatglm" in model_path return "chatglm" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict): def loader(self, model_path: str, from_pretrained_kwargs: dict):
import torch
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if DEVICE != "cuda": if get_device() != "cuda":
model = AutoModel.from_pretrained( model = AutoModel.from_pretrained(
model_path, trust_remote_code=True, **from_pretrained_kwargs model_path, trust_remote_code=True, **from_pretrained_kwargs
).float() ).float()

View File

@ -1,6 +1,3 @@
import json
import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@ -2,11 +2,7 @@ import click
import functools import functools
from pilot.model.controller.registry import ModelRegistryClient from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import ( from pilot.model.base import WorkerApplyType
RemoteWorkerManager,
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.model.parameter import ( from pilot.model.parameter import (
ModelControllerParameters, ModelControllerParameters,
ModelWorkerParameters, ModelWorkerParameters,
@ -15,12 +11,14 @@ from pilot.model.parameter import (
from pilot.utils import get_or_create_event_loop from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser from pilot.utils.parameter_utils import EnvArgumentParser
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
@click.group("model") @click.group("model")
@click.option( @click.option(
"--address", "--address",
type=str, type=str,
default="http://127.0.0.1:8000", default=MODEL_CONTROLLER_ADDRESS,
required=False, required=False,
show_default=True, show_default=True,
help=( help=(
@ -28,24 +26,25 @@ from pilot.utils.parameter_utils import EnvArgumentParser
"Just support light deploy model" "Just support light deploy model"
), ),
) )
def model_cli_group(): def model_cli_group(address: str):
"""Clients that manage model serving""" """Clients that manage model serving"""
pass global MODEL_CONTROLLER_ADDRESS
MODEL_CONTROLLER_ADDRESS = address
@model_cli_group.command() @model_cli_group.command()
@click.option( @click.option(
"--model-name", type=str, default=None, required=False, help=("The name of model") "--model_name", type=str, default=None, required=False, help=("The name of model")
) )
@click.option( @click.option(
"--model-type", type=str, default="llm", required=False, help=("The type of model") "--model_type", type=str, default="llm", required=False, help=("The type of model")
) )
def list(address: str, model_name: str, model_type: str): def list(model_name: str, model_type: str):
"""List model instances""" """List model instances"""
from prettytable import PrettyTable from prettytable import PrettyTable
loop = get_or_create_event_loop() loop = get_or_create_event_loop()
registry = ModelRegistryClient(address) registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
if not model_name: if not model_name:
instances = loop.run_until_complete(registry.get_all_model_instances()) instances = loop.run_until_complete(registry.get_all_model_instances())
@ -88,14 +87,14 @@ def list(address: str, model_name: str, model_type: str):
def add_model_options(func): def add_model_options(func):
@click.option( @click.option(
"--model-name", "--model_name",
type=str, type=str,
default=None, default=None,
required=True, required=True,
help=("The name of model"), help=("The name of model"),
) )
@click.option( @click.option(
"--model-type", "--model_type",
type=str, type=str,
default="llm", default="llm",
required=False, required=False,
@ -110,23 +109,27 @@ def add_model_options(func):
@model_cli_group.command() @model_cli_group.command()
@add_model_options @add_model_options
def stop(address: str, model_name: str, model_type: str): def stop(model_name: str, model_type: str):
"""Stop model instances""" """Stop model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.STOP) worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
@model_cli_group.command() @model_cli_group.command()
@add_model_options @add_model_options
def start(address: str, model_name: str, model_type: str): def start(model_name: str, model_type: str):
"""Start model instances""" """Start model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.START) worker_apply(
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
)
@model_cli_group.command() @model_cli_group.command()
@add_model_options @add_model_options
def restart(address: str, model_name: str, model_type: str): def restart(model_name: str, model_type: str):
"""Restart model instances""" """Restart model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART) worker_apply(
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.RESTART
)
# @model_cli_group.command() # @model_cli_group.command()
@ -139,6 +142,8 @@ def restart(address: str, model_name: str, model_type: str):
def worker_apply( def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
): ):
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
loop = get_or_create_event_loop() loop = get_or_create_event_loop()
registry = ModelRegistryClient(address) registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry) worker_manager = RemoteWorkerManager(registry)

View File

@ -6,7 +6,7 @@ Conversation prompt templates.
import dataclasses import dataclasses
from enum import auto, IntEnum from enum import auto, IntEnum
from typing import List, Any, Dict, Callable from typing import List, Dict, Callable
class SeparatorStyle(IntEnum): class SeparatorStyle(IntEnum):

View File

@ -9,8 +9,6 @@ from typing import Iterable, Dict
import torch import torch
import torch
from transformers.generation.logits_process import ( from transformers.generation.logits_process import (
LogitsProcessorList, LogitsProcessorList,
RepetitionPenaltyLogitsProcessor, RepetitionPenaltyLogitsProcessor,

View File

@ -2,7 +2,7 @@
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
""" """
import re import re
from typing import Dict, Any from typing import Dict
import torch import torch
import llama_cpp import llama_cpp

View File

@ -7,13 +7,7 @@ import time
from typing import Optional from typing import Optional
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.conversation import ( from pilot.conversation import Conversation
Conversation,
auto_dbgpt_one_shot,
conv_one_shot,
conv_templates,
)
from pilot.model.llm.base import Message
# TODO Rewrite this # TODO Rewrite this

View File

@ -3,11 +3,9 @@
from typing import List from typing import List
import re import re
import copy
import torch import torch
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.scene.base_message import ModelMessage, _parse_model_messages from pilot.scene.base_message import ModelMessage, _parse_model_messages
# TODO move sep to scene prompt of model # TODO move sep to scene prompt of model

View File

@ -1,5 +1,4 @@
import torch import torch
import copy
from threading import Thread from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria

View File

@ -2,16 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import Optional, Dict from typing import Optional, Dict
import torch
from pilot.configs.model_config import DEVICE from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.compression import compress_module
from pilot.model.parameter import ( from pilot.model.parameter import (
ModelParameters, ModelParameters,
LlamaCppModelParameters, LlamaCppModelParameters,
) )
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.utils import get_gpu_memory from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
from pilot.logs import logger from pilot.logs import logger
@ -67,7 +64,7 @@ class ModelLoader:
""" """
def __init__(self, model_path: str, model_name: str = None) -> None: def __init__(self, model_path: str, model_name: str = None) -> None:
self.device = DEVICE self.device = get_device()
self.model_path = model_path self.model_path = model_path
self.model_name = model_name self.model_name = model_name
self.prompt_template: str = None self.prompt_template: str = None
@ -127,6 +124,9 @@ class ModelLoader:
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters): def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
import torch
from pilot.model.compression import compress_module
device = model_params.device device = model_params.device
max_memory = None max_memory = None
@ -156,6 +156,10 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
elif device == "mps": elif device == "mps":
kwargs = {"torch_dtype": torch.float16} kwargs = {"torch_dtype": torch.float16}
from pilot.model.llm.monkey_patch import (
replace_llama_attn_with_non_inplace_operations,
)
replace_llama_attn_with_non_inplace_operations() replace_llama_attn_with_non_inplace_operations()
else: else:
raise ValueError(f"Invalid device: {device}") raise ValueError(f"Invalid device: {device}")
@ -200,6 +204,8 @@ def load_huggingface_quantization_model(
kwargs: Dict, kwargs: Dict,
max_memory: Dict[int, str], max_memory: Dict[int, str],
): ):
import torch
try: try:
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import infer_auto_device_map from accelerate.utils import infer_auto_device_map

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from dataclasses import dataclass, field, fields, MISSING from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Dict, Optional
from pilot.model.conversation import conv_templates from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters from pilot.utils.parameter_utils import BaseParameters

View File

@ -2,8 +2,7 @@ import logging
import platform import platform
from typing import Dict, Iterator, List from typing import Dict, Iterator, List
import torch from pilot.configs.model_config import get_device
from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path from pilot.model.loader import ModelLoader, _get_model_real_path
@ -63,7 +62,7 @@ class DefaultModelWorker(ModelWorker):
model_type=model_type, model_type=model_type,
) )
if not model_params.device: if not model_params.device:
model_params.device = DEVICE model_params.device = get_device()
logger.info( logger.info(
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}" f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
) )
@ -88,6 +87,8 @@ class DefaultModelWorker(ModelWorker):
_clear_torch_cache(self._model_params.device) _clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
import torch
try: try:
# params adaptation # params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation( params, model_context = self.llm_chat_adapter.model_adaptation(
@ -95,7 +96,7 @@ class DefaultModelWorker(ModelWorker):
) )
for output in self.generate_stream_func( for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, self.context_len self.model, self.tokenizer, params, get_device(), self.context_len
): ):
# Please do not open the output in production! # Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process, # The gpt4all thread shares stdout with the parent process,

View File

@ -1,7 +1,7 @@
import logging import logging
from typing import Dict, List, Type from typing import Dict, List, Type
from pilot.configs.model_config import DEVICE from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import ( from pilot.model.parameter import (
EmbeddingModelParameters, EmbeddingModelParameters,
@ -55,7 +55,7 @@ class EmbeddingsModelWorker(ModelWorker):
model_path=self.model_path, model_path=self.model_path,
) )
if not model_params.device: if not model_params.device:
model_params.device = DEVICE model_params.device = get_device()
logger.info( logger.info(
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}" f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
) )

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
import httpx
import itertools import itertools
import json import json
import os import os
@ -7,26 +6,21 @@ import random
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass
from datetime import datetime from datetime import datetime
from typing import Awaitable, Callable, Dict, Iterator, List, Optional from typing import Awaitable, Callable, Dict, Iterator, List, Optional
import uvicorn
from fastapi import APIRouter, FastAPI, Request from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.model.base import ( from pilot.model.base import (
ModelInstance, ModelInstance,
ModelOutput, ModelOutput,
WorkerApplyType,
WorkerApplyOutput, WorkerApplyOutput,
WorkerApplyType,
) )
from pilot.model.controller.registry import ModelRegistry from pilot.model.controller.registry import ModelRegistry
from pilot.model.parameter import ( from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
ModelParameters,
ModelWorkerParameters,
WorkerType,
)
from pilot.model.worker.base import ModelWorker from pilot.model.worker.base import ModelWorker
from pilot.scene.base_message import ModelMessage from pilot.scene.base_message import ModelMessage
from pilot.utils import build_logger from pilot.utils import build_logger
@ -431,6 +425,8 @@ class RemoteWorkerManager(LocalWorkerManager):
return worker_instances return worker_instances
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
import httpx
async def _remote_apply_func(worker_run_data: WorkerRunData): async def _remote_apply_func(worker_run_data: WorkerRunData):
worker_addr = worker_run_data.worker.worker_addr worker_addr = worker_run_data.worker.worker_addr
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@ -700,6 +696,8 @@ def run_worker_manager(
app.include_router(router, prefix="/api") app.include_router(router, prefix="/api")
if not embedded_mod: if not embedded_mod:
import uvicorn
uvicorn.run( uvicorn.run(
app, host=worker_params.host, port=worker_params.port, log_level="info" app, host=worker_params.host, port=worker_params.port, log_level="info"
) )

View File

@ -1,7 +1,6 @@
import json import json
from typing import Dict, Iterator, List from typing import Dict, Iterator, List
import logging
import httpx
from pilot.model.base import ModelOutput from pilot.model.base import ModelOutput
from pilot.model.parameter import ModelParameters from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker from pilot.model.worker.base import ModelWorker
@ -10,7 +9,8 @@ from pilot.model.worker.base import ModelWorker
class RemoteModelWorker(ModelWorker): class RemoteModelWorker(ModelWorker):
def __init__(self) -> None: def __init__(self) -> None:
self.headers = {} self.headers = {}
self.timeout = 60 # TODO Configured by ModelParameters
self.timeout = 180
self.host = None self.host = None
self.port = None self.port = None
@ -44,7 +44,9 @@ class RemoteModelWorker(ModelWorker):
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]: async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
"""Asynchronous generate stream""" """Asynchronous generate stream"""
print(f"Send async_generate_stream, params: {params}") import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
delimiter = b"\0" delimiter = b"\0"
buffer = b"" buffer = b""
@ -71,8 +73,9 @@ class RemoteModelWorker(ModelWorker):
async def async_generate(self, params: Dict) -> ModelOutput: async def async_generate(self, params: Dict) -> ModelOutput:
"""Asynchronous generate non stream""" """Asynchronous generate non stream"""
print(f"Send async_generate_stream, params: {params}") import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
self.worker_addr + "/generate", self.worker_addr + "/generate",
@ -88,6 +91,8 @@ class RemoteModelWorker(ModelWorker):
async def async_embeddings(self, params: Dict) -> List[List[float]]: async def async_embeddings(self, params: Dict) -> List[List[float]]:
"""Asynchronous get embeddings for input""" """Asynchronous get embeddings for input"""
import httpx
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
self.worker_addr + "/embeddings", self.worker_addr + "/embeddings",

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any from typing import TypeVar, Generic, Any
T = TypeVar("T") T = TypeVar("T")

View File

@ -1,24 +1,6 @@
from fastapi import ( from fastapi import Request
APIRouter,
Request,
Body,
status,
HTTPException,
Response,
BackgroundTasks,
)
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from pilot.openapi.api_view_model import Result
from pilot.openapi.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo,
)
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError):

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any from typing import List, Any
class DbField(BaseModel): class DbField(BaseModel):

View File

@ -1,8 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union from typing import Any, Dict, TypeVar, Union

View File

@ -1,10 +1,7 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from typing import List
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml from pydantic import BaseModel
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union from typing import List
from pilot.common.schema import ExampleType from pilot.common.schema import ExampleType

View File

@ -1,7 +1,7 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel
from pilot.common.formatting import formatter, no_strict_formatter from pilot.common.formatting import formatter, no_strict_formatter

View File

@ -3,7 +3,6 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict, List from typing import Dict, List
import json
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__" _DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__" _DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"

View File

@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from pilot.out_parser.base import BaseOutputParser from pilot.out_parser.base import BaseOutputParser
from pilot.prompts.base import PromptValue from pilot.prompts.base import PromptValue
from pilot.scene.base_message import HumanMessage, AIMessage, SystemMessage, BaseMessage from pilot.scene.base_message import HumanMessage, BaseMessage
from pilot.common.formatting import formatter from pilot.common.formatting import formatter

View File

@ -1,44 +1,20 @@
import time
from abc import ABC, abstractmethod
import datetime import datetime
import traceback import traceback
import warnings import warnings
import json from abc import ABC, abstractmethod
from pydantic import BaseModel, Field, root_validator, validator, Extra from typing import Any, List
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
import requests
from urllib.parse import urljoin
import pilot.configs.config from pilot.configs.config import Config
from pilot.scene.message import OnceConversation from pilot.configs.model_config import LOGDIR
from pilot.prompts.prompt_new import PromptTemplate
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.configs.model_config import LOGDIR, DATASETS_DIR from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop from pilot.utils import build_logger, get_or_create_event_loop
from pilot.scene.base_message import ( from pydantic import Extra
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
ModelMessage,
ModelMessageRoleType,
)
from pilot.configs.config import Config
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"} headers = {"User-Agent": "dbgpt Client"}

View File

@ -1,20 +1,9 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import ( from typing import Any, Dict, List, Tuple, Optional
Any,
Dict,
Generic,
List,
Tuple,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Field, root_validator
class PromptValue(BaseModel, ABC): class PromptValue(BaseModel, ABC):

View File

@ -3,7 +3,7 @@ import os
import uuid import uuid
from typing import List from typing import List
from pilot.scene.base_chat import BaseChat, logger from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_dashboard.data_preparation.report_schma import ( from pilot.scene.chat_dashboard.data_preparation.report_schma import (

View File

@ -1,7 +1,5 @@
import json from pydantic import BaseModel
from pydantic import BaseModel, Field from typing import List, Any
from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ValueItem(BaseModel): class ValueItem(BaseModel):

View File

@ -1,9 +1,5 @@
import json import json
import re from typing import NamedTuple, List
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR

View File

@ -3,17 +3,11 @@ import os
from typing import List, Any, Dict from typing import List, Any, Dict
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat, logger from pilot.scene.base_chat import BaseChat, logger
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning

View File

@ -1,16 +1,7 @@
import json
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
CFG = Config() CFG = Config()

View File

@ -1,8 +1,5 @@
import json import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
@ -36,6 +33,8 @@ class DbChatOutputParser(BaseOutputParser):
return SqlAction(sql, thoughts) return SqlAction(sql, thoughts)
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
import pandas as pd
### tool out data to table view ### tool out data to table view
data_loader = DbDataLoader() data_loader = DbDataLoader()
if len(data) <= 1: if len(data) <= 1:

View File

@ -2,7 +2,7 @@ import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example from pilot.scene.chat_db.auto_execute.example import sql_data_example

View File

@ -5,7 +5,7 @@ import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example from pilot.scene.chat_db.auto_execute.example import sql_data_example

View File

@ -1,8 +1,7 @@
import pandas as pd
class DbDataLoader: class DbDataLoader:
def get_table_view_by_conn(self, data, speak): def get_table_view_by_conn(self, data, speak):
import pandas as pd
### tool out data to table view ### tool out data to table view
if len(data) <= 1: if len(data) <= 1:
data.insert(0, ["result"]) data.insert(0, ["result"])

View File

@ -1,14 +1,7 @@
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.professional_qa.prompt import prompt from pilot.scene.chat_db.professional_qa.prompt import prompt
CFG = Config() CFG = Config()

View File

@ -1,12 +1,6 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.out_parser.base import BaseOutputParser, T
from pilot.utils import build_logger
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

View File

@ -1,5 +1,3 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene

View File

@ -1,11 +1,6 @@
import requests
import datetime
from urllib.parse import urljoin
from typing import List from typing import List
import traceback
from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base_chat import BaseChat
from pilot.scene.message import OnceConversation
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.commands.command import execute_command from pilot.commands.command import execute_command

View File

@ -1,8 +1,5 @@
import json import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR

View File

@ -1,20 +1,20 @@
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton from pilot.singleton import Singleton
import inspect
import importlib
from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
class ChatFactory(metaclass=Singleton): class ChatFactory(metaclass=Singleton):
@staticmethod @staticmethod
def get_implementation(chat_mode, **kwargs): def get_implementation(chat_mode, **kwargs):
# Lazy loading
from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
chat_classes = BaseChat.__subclasses__() chat_classes = BaseChat.__subclasses__()
implementation = None implementation = None
for cls in chat_classes: for cls in chat_classes:

View File

@ -1,8 +1,3 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR

View File

@ -1,5 +1,3 @@
import builtins
import importlib
import json import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate

View File

@ -1,25 +1,15 @@
from chromadb.errors import NoIndexException from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import ( from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH, KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG, LLM_MODEL_CONFIG,
LOGDIR,
) )
from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.service import KnowledgeService
CFG = Config() CFG = Config()
@ -32,6 +22,8 @@ class ChatKnowledge(BaseChat):
def __init__(self, chat_session_id, user_input, select_param: str = None): def __init__(self, chat_session_id, user_input, select_param: str = None):
""" """ """ """
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
self.knowledge_space = select_param self.knowledge_space = select_param
super().__init__( super().__init__(
chat_mode=ChatScene.ChatKnowledge, chat_mode=ChatScene.ChatKnowledge,

View File

@ -1,8 +1,3 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR

View File

@ -1,6 +1,3 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene

View File

@ -1,6 +1,3 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene

View File

@ -1,13 +1,7 @@
from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_normal.prompt import prompt from pilot.scene.chat_normal.prompt import prompt
CFG = Config() CFG = Config()

View File

@ -1,8 +1,3 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR

View File

@ -1,6 +1,3 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene

View File

@ -1,13 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime
from pydantic import BaseModel, Field, root_validator, validator from typing import List
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
)
from pilot.scene.base_message import ( from pilot.scene.base_message import (
BaseMessage, BaseMessage,

View File

@ -1,28 +1,18 @@
import signal import signal
import os import os
import threading import threading
import traceback
import sys import sys
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config from pilot.configs.config import Config
# from pilot.configs.model_config import ( from pilot.common.plugins import scan_plugins
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,
# LLM_MODEL_CONFIG,
# LOGDIR,
# )
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger
from pilot.connections.manages.connection_manager import ConnectManager from pilot.connections.manages.connection_manager import ConnectManager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(sig, frame): def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem") print("in order to avoid chroma db atexit problem")

View File

@ -3,7 +3,6 @@
from functools import cache from functools import cache
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
from pilot.model.llm_out.vicuna_base_llm import generate_stream
from pilot.model.conversation import Conversation, get_conv_template from pilot.model.conversation import Conversation, get_conv_template
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@ -131,6 +130,8 @@ class VicunaChatAdapter(BaseChatAdpter):
return None return None
def get_generate_stream_func(self, model_path: str): def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.vicuna_base_llm import generate_stream
if self._is_llama2_based(model_path): if self._is_llama2_based(model_path):
return super().get_generate_stream_func(model_path) return super().get_generate_stream_func(model_path)
return generate_stream return generate_stream

View File

@ -1,7 +1,4 @@
import atexit
import traceback
import os import os
import shutil
import argparse import argparse
import sys import sys
import logging import logging
@ -11,7 +8,6 @@ sys.path.append(ROOT_PATH)
import signal import signal
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.utils import build_logger
from pilot.server.base import server_init from pilot.server.base import server_init
@ -28,14 +24,11 @@ from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client from pilot.model.worker.manager import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging
logging.basicConfig(level=logging.INFO, encoding="utf-8")
static_file_path = os.path.join(os.getcwd(), "server/static") static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config() CFG = Config()
# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(): def signal_handler():
@ -102,7 +95,7 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=5000) parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true") parser.add_argument("--share", default=False, action="store_true")
parser.add_argument("--log-level", type=str, default="info") parser.add_argument("--log-level", type=str, default=None)
parser.add_argument( parser.add_argument(
"-light", "-light",
"--light", "--light",
@ -113,6 +106,7 @@ if __name__ == "__main__":
# init server config # init server config
args = parser.parse_args() args = parser.parse_args()
setup_logging(logging_level=args.log_level)
server_init(args) server_init(args)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
@ -137,6 +131,5 @@ if __name__ == "__main__":
mount_static_files(app) mount_static_files(app)
import uvicorn import uvicorn
logging.basicConfig(level=logging.INFO, encoding="utf-8") uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info")
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level)
signal.signal(signal.SIGINT, signal_handler()) signal.signal(signal.SIGINT, signal_handler())

View File

@ -1,7 +1,6 @@
import os import os
import shutil import shutil
import tempfile import tempfile
from tempfile import NamedTemporaryFile
from fastapi import APIRouter, File, UploadFile, Form from fastapi import APIRouter, File, UploadFile, Form

View File

@ -1,8 +1,8 @@
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao from pilot.connections.rdbms.base_dao import BaseDao

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao from pilot.connections.rdbms.base_dao import BaseDao

View File

@ -2,12 +2,10 @@ import json
import threading import threading
from datetime import datetime from datetime import datetime
from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.logs import logger from pilot.logs import logger
from pilot.server.knowledge.chunk_db import ( from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity, DocumentChunkEntity,
@ -152,6 +150,14 @@ class KnowledgeService:
"""sync knowledge document chunk into vector store""" """sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids): def sync_knowledge_document(self, space_name, doc_ids):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
# import langchain is very very slow!!!
for doc_id in doc_ids: for doc_id in doc_ids:
query = KnowledgeDocumentEntity( query = KnowledgeDocumentEntity(
id=doc_id, id=doc_id,

View File

@ -1,8 +1,7 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.server.knowledge.request.request import KnowledgeSpaceRequest

View File

@ -4,9 +4,6 @@
import os import os
import sys import sys
global_counter = 0
model_semaphore = None
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
@ -17,15 +14,6 @@ from pilot.model.worker.manager import run_worker_manager
CFG = Config() CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
# @app.post("/embedding")
# def embeddings(prompt_request: EmbeddingRequest):
# params = {"prompt": prompt_request.prompt}
# print("Received prompt: ", params["prompt"])
# output = worker.get_embeddings(params["prompt"])
# return {"response": [float(x) for x in output]}
if __name__ == "__main__": if __name__ == "__main__":
run_worker_manager( run_worker_manager(

View File

@ -1,3 +0,0 @@
from pilot.speech.say import say_text
__all__ = ["say_text"]

View File

@ -1,21 +1,19 @@
import json import json
import uuid import uuid
from langchain.embeddings import HuggingFaceEmbeddings, logger from pilot.common.schema import DBType
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.common.schema import DBType from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log") logger = build_logger("db_summary", LOGDIR + "db_summary.log")
@ -33,6 +31,8 @@ class DBSummaryClient:
def db_summary_embedding(self, dbname, db_type): def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store""" """put db profile and table profile summary into vector store"""
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.embedding_engine.string_embedding import StringEmbedding
db_summary_client = RdbmsSummary(dbname, db_type) db_summary_client = RdbmsSummary(dbname, db_type)
embeddings = HuggingFaceEmbeddings( embeddings = HuggingFaceEmbeddings(
@ -82,6 +82,8 @@ class DBSummaryClient:
logger.info("db summary embedding success") logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk): def get_db_summary(self, dbname, query, topk):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_profile",
"vector_store_type": CFG.VECTOR_STORE_TYPE, "vector_store_type": CFG.VECTOR_STORE_TYPE,
@ -97,6 +99,8 @@ class DBSummaryClient:
def get_similar_tables(self, dbname, query, topk): def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info""" """get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
vector_store_config = { vector_store_config = {
"vector_store_name": dbname + "_summary", "vector_store_name": dbname + "_summary",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -149,6 +153,8 @@ class DBSummaryClient:
) )
def init_db_profile(self, db_summary_client, dbname, embeddings): def init_db_profile(self, db_summary_client, dbname, embeddings):
from pilot.embedding_engine.string_embedding import StringEmbedding
profile_store_config = { profile_store_config = {
"vector_store_name": dbname + "_profile", "vector_store_name": dbname + "_profile",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH, "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -1,6 +1,4 @@
import httpx
from inspect import signature from inspect import signature
import typing_inspect
import logging import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict from dataclasses import is_dataclass, asdict
@ -9,6 +7,8 @@ T = TypeVar("T")
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]: def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
import typing_inspect
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc.""" """Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint): if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
return typing_inspect.get_args(type_hint)[0] return typing_inspect.get_args(type_hint)[0]
@ -30,6 +30,8 @@ def _api_remote(path, method="GET"):
sig = signature(func) sig = signature(func)
async def wrapper(self, *args, **kwargs): async def wrapper(self, *args, **kwargs):
import httpx
base_url = self.base_url # Get base_url from class instance base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs) bound = sig.bind(self, *args, **kwargs)

View File

@ -16,6 +16,22 @@ server_error_msg = (
handler = None handler = None
def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging(logging_level=None, logger_name: str = None):
if not logging_level:
logging_level = _get_logging_level()
if type(logging_level) is str:
logging_level = logging.getLevelName(logging_level.upper())
if logger_name:
logger = logging.getLogger(logger_name)
logger.setLevel(logging_level)
else:
logging.basicConfig(level=logging_level, encoding="utf-8")
def get_gpu_memory(max_gpus=None): def get_gpu_memory(max_gpus=None):
import torch import torch
@ -47,7 +63,7 @@ def build_logger(logger_name, logger_filename):
# Set the format of root handlers # Set the format of root handlers
if not logging.getLogger().handlers: if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO, encoding="utf-8") setup_logging()
logging.getLogger().handlers[0].setFormatter(formatter) logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers # Redirect stdout and stderr to loggers
@ -73,11 +89,11 @@ def build_logger(logger_name, logger_filename):
for name, item in logging.root.manager.loggerDict.items(): for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger): if isinstance(item, logging.Logger):
item.addHandler(handler) item.addHandler(handler)
logging.basicConfig(level=logging.INFO, encoding="utf-8") setup_logging()
# Get logger # Get logger
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO) setup_logging(logger_name=logger_name)
return logger return logger

View File

@ -2,7 +2,6 @@ import os
from typing import Any from typing import Any
from chromadb.config import Settings from chromadb.config import Settings
from langchain.vectorstores import Chroma
from pilot.logs import logger from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
@ -11,6 +10,8 @@ class ChromaStore(VectorStoreBase):
"""chroma database""" """chroma database"""
def __init__(self, ctx: {}) -> None: def __init__(self, ctx: {}) -> None:
from langchain.vectorstores import Chroma
self.ctx = ctx self.ctx = ctx
self.embeddings = ctx.get("embeddings", None) self.embeddings = ctx.get("embeddings", None)
self.persist_dir = os.path.join( self.persist_dir = os.path.join(

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections, utility from pymilvus import Collection, DataType, connections, utility
from pilot.logs import logger from pilot.logs import logger
@ -279,7 +280,9 @@ class MilvusStore(VectorStoreBase):
round_decimal: int = -1, round_decimal: int = -1,
timeout: Optional[int] = None, timeout: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]: ):
from langchain.docstore.document import Document
self.col.load() self.col.load()
# use default index params. # use default index params.
if param is None: if param is None: