feat: Optimize code import time

This commit is contained in:
FangYin Cheng 2023-09-01 10:40:18 +08:00
parent 0bc5134a07
commit f19551a7cd
83 changed files with 244 additions and 394 deletions

View File

@ -143,3 +143,9 @@ SUMMARY_CONFIG=FAST
# CUDA_VISIBLE_DEVICES=0
## You can configure the maximum memory used by each GPU.
# MAX_GPU_MEMORY=16Gib
#*******************************************************************#
#** LOG **#
#*******************************************************************#
# FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET
DBGPT_LOG_LEVEL=INFO

View File

@ -1,4 +1,12 @@
from pilot.embedding_engine import SourceEmbedding, register
from pilot.embedding_engine import EmbeddingEngine, KnowledgeType
# Old packages
# __all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
__all__ = ["embedding_engine"]
def __getattr__(name: str):
import importlib
if name in ["embedding_engine"]:
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -12,7 +12,7 @@ from pilot.json_utils.json_fix_general import (
fix_invalid_escape,
)
from pilot.logs import logger
from pilot.speech import say_text
CFG = Config()
@ -87,6 +87,8 @@ def correct_json(json_to_load: str) -> str:
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
from pilot.speech.say import say_text
if CFG.speak_mode and CFG.debug_mode:
say_text(
"I have received an invalid JSON response from the OpenAI API. "

View File

@ -7,7 +7,6 @@ from typing import Dict
from pilot.commands.exception_not_commands import NotCommands
from pilot.configs.config import Config
from pilot.prompts.generator import PluginPromptGenerator
from pilot.speech import say_text
def _resolve_pathlike_command_args(command_args):
@ -37,6 +36,8 @@ def execute_ai_response_json(
Returns:
"""
from pilot.speech.say import say_text
cfg = Config()
command_name, arguments = get_command(ai_response)

View File

@ -1,8 +1,9 @@
import markdown2
import pandas as pd
def datas_to_table_html(data):
import pandas as pd
df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style>
table{border-collapse:collapse;width:60%;height:80%;margin:0 auto;float:right;border: 1px solid #007bff; background-color:#CFE299}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#7C9F4A}tr:hover{background-color:#333}

View File

@ -1,5 +1,4 @@
from enum import auto, Enum
from typing import List, Any
import os

View File

@ -3,8 +3,6 @@ import sqlparse
import regex as re
import warnings
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
import sqlalchemy
from sqlalchemy import (
MetaData,
@ -14,7 +12,7 @@ from sqlalchemy import (
select,
text,
)
from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session

View File

@ -4,12 +4,7 @@
import os
from typing import List
import nltk
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.prompts.prompt_registry import PromptTemplateRegistry
class Config(metaclass=Singleton):
@ -78,6 +73,8 @@ class Config(metaclass=Singleton):
)
self.speak_mode = False
from pilot.prompts.prompt_registry import PromptTemplateRegistry
self.prompt_template_registry = PromptTemplateRegistry()
### Related configuration of built-in commands
self.command_registry = []
@ -98,6 +95,8 @@ class Config(metaclass=Singleton):
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
from auto_gpt_plugin_template import AutoGPTPluginTemplate
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
@ -183,6 +182,9 @@ class Config(metaclass=Singleton):
self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value

View File

@ -3,8 +3,7 @@
import os
import nltk
import torch
# import nltk
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
@ -13,7 +12,7 @@ VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs")
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins")
FONT_DIR = os.path.join(PILOT_PATH, "fonts")
@ -22,13 +21,19 @@ current_directory = os.getcwd()
new_directory = PILOT_PATH
os.chdir(new_directory)
DEVICE = (
def get_device() -> str:
import torch
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
)
LLM_MODEL_CONFIG = {
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),

View File

@ -3,7 +3,6 @@
"""We need to design a base class. That other connector can Write with this"""
from abc import ABC, abstractmethod
from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional

View File

@ -1,6 +1,5 @@
import os
import duckdb
from typing import List
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")

View File

@ -2,7 +2,6 @@
# -*- coding:utf-8 -*-
import dataclasses
import uuid
from enum import auto, Enum
from typing import List, Any
from pilot.language.translation_handler import get_lang_text

View File

@ -12,9 +12,6 @@ class JsonFileHandler(logging.FileHandler):
json.dump(json_data, f, ensure_ascii=False, indent=4)
import logging
class JsonFormatter(logging.Formatter):
def format(self, record):
return record.msg

View File

@ -8,9 +8,7 @@ from typing import Any
from colorama import Fore, Style
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
from pilot.singleton import Singleton
from pilot.speech import say_text
class Logger(metaclass=Singleton):
@ -86,6 +84,8 @@ class Logger(metaclass=Singleton):
def typewriter_log(
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
):
from pilot.speech.say import say_text
if speak_text and self.speak_mode:
say_text(f"{title}. {content}")
@ -159,6 +159,8 @@ class Logger(metaclass=Singleton):
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
def log_json(self, data: Any, file_name: str) -> None:
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
# Define log directory
this_files_dir_path = os.path.dirname(__file__)
log_dir = os.path.join(this_files_dir_path, "../logs")
@ -255,6 +257,8 @@ def print_assistant_thoughts(
assistant_reply_json_valid: object,
speak_mode: bool = False,
) -> None:
from pilot.speech.say import say_text
assistant_thoughts_reasoning = None
assistant_thoughts_plan = None
assistant_thoughts_speak = None

View File

@ -1,18 +1,7 @@
from __future__ import annotations
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from typing import List
from pilot.scene.message import OnceConversation

View File

@ -7,9 +7,7 @@ from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
_conversation_to_dic,
conversations_to_dict,
)
from pilot.common.formatting import MyEncoder

View File

@ -1,17 +1,9 @@
from typing import List
import json
import os
import datetime
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pathlib import Path
from pilot.configs.config import Config
from pilot.scene.message import (
OnceConversation,
conversation_from_dict,
conversations_to_dict,
)
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
from pilot.scene.message import OnceConversation
from pilot.common.custom_data_structure import FixedSizeDict
CFG = Config()

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import os
import re
from pathlib import Path
@ -14,7 +13,7 @@ from transformers import (
LlamaTokenizer,
)
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.configs.model_config import DEVICE
from pilot.configs.model_config import get_device
from pilot.configs.config import Config
from pilot.logs import logger
@ -147,9 +146,11 @@ class ChatGLMAdapater(BaseLLMAdaper):
return "chatglm" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict):
import torch
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if DEVICE != "cuda":
if get_device() != "cuda":
model = AutoModel.from_pretrained(
model_path, trust_remote_code=True, **from_pretrained_kwargs
).float()

View File

@ -1,6 +1,3 @@
import json
import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod

View File

@ -2,11 +2,7 @@ import click
import functools
from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
RemoteWorkerManager,
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
ModelControllerParameters,
ModelWorkerParameters,
@ -15,12 +11,14 @@ from pilot.model.parameter import (
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
@click.group("model")
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
default=MODEL_CONTROLLER_ADDRESS,
required=False,
show_default=True,
help=(
@ -28,24 +26,25 @@ from pilot.utils.parameter_utils import EnvArgumentParser
"Just support light deploy model"
),
)
def model_cli_group():
def model_cli_group(address: str):
"""Clients that manage model serving"""
pass
global MODEL_CONTROLLER_ADDRESS
MODEL_CONTROLLER_ADDRESS = address
@model_cli_group.command()
@click.option(
"--model-name", type=str, default=None, required=False, help=("The name of model")
"--model_name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
"--model-type", type=str, default="llm", required=False, help=("The type of model")
"--model_type", type=str, default="llm", required=False, help=("The type of model")
)
def list(address: str, model_name: str, model_type: str):
def list(model_name: str, model_type: str):
"""List model instances"""
from prettytable import PrettyTable
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
registry = ModelRegistryClient(MODEL_CONTROLLER_ADDRESS)
if not model_name:
instances = loop.run_until_complete(registry.get_all_model_instances())
@ -88,14 +87,14 @@ def list(address: str, model_name: str, model_type: str):
def add_model_options(func):
@click.option(
"--model-name",
"--model_name",
type=str,
default=None,
required=True,
help=("The name of model"),
)
@click.option(
"--model-type",
"--model_type",
type=str,
default="llm",
required=False,
@ -110,23 +109,27 @@ def add_model_options(func):
@model_cli_group.command()
@add_model_options
def stop(address: str, model_name: str, model_type: str):
def stop(model_name: str, model_type: str):
"""Stop model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)
worker_apply(MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.STOP)
@model_cli_group.command()
@add_model_options
def start(address: str, model_name: str, model_type: str):
def start(model_name: str, model_type: str):
"""Start model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.START)
worker_apply(
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.START
)
@model_cli_group.command()
@add_model_options
def restart(address: str, model_name: str, model_type: str):
def restart(model_name: str, model_type: str):
"""Restart model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)
worker_apply(
MODEL_CONTROLLER_ADDRESS, model_name, model_type, WorkerApplyType.RESTART
)
# @model_cli_group.command()
@ -139,6 +142,8 @@ def restart(address: str, model_name: str, model_type: str):
def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
from pilot.model.worker.manager import RemoteWorkerManager, WorkerApplyRequest
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)

View File

@ -6,7 +6,7 @@ Conversation prompt templates.
import dataclasses
from enum import auto, IntEnum
from typing import List, Any, Dict, Callable
from typing import List, Dict, Callable
class SeparatorStyle(IntEnum):

View File

@ -9,8 +9,6 @@ from typing import Iterable, Dict
import torch
import torch
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,

View File

@ -2,7 +2,7 @@
Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
"""
import re
from typing import Dict, Any
from typing import Dict
import torch
import llama_cpp

View File

@ -7,13 +7,7 @@ import time
from typing import Optional
from pilot.configs.config import Config
from pilot.conversation import (
Conversation,
auto_dbgpt_one_shot,
conv_one_shot,
conv_templates,
)
from pilot.model.llm.base import Message
from pilot.conversation import Conversation
# TODO Rewrite this

View File

@ -3,11 +3,9 @@
from typing import List
import re
import copy
import torch
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
from pilot.scene.base_message import ModelMessage, _parse_model_messages
# TODO move sep to scene prompt of model

View File

@ -1,5 +1,4 @@
import torch
import copy
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria

View File

@ -2,16 +2,13 @@
# -*- coding: utf-8 -*-
from typing import Optional, Dict
import torch
from pilot.configs.model_config import DEVICE
from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.compression import compress_module
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
)
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
from pilot.logs import logger
@ -67,7 +64,7 @@ class ModelLoader:
"""
def __init__(self, model_path: str, model_name: str = None) -> None:
self.device = DEVICE
self.device = get_device()
self.model_path = model_path
self.model_name = model_name
self.prompt_template: str = None
@ -127,6 +124,9 @@ class ModelLoader:
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
import torch
from pilot.model.compression import compress_module
device = model_params.device
max_memory = None
@ -156,6 +156,10 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
elif device == "mps":
kwargs = {"torch_dtype": torch.float16}
from pilot.model.llm.monkey_patch import (
replace_llama_attn_with_non_inplace_operations,
)
replace_llama_attn_with_non_inplace_operations()
else:
raise ValueError(f"Invalid device: {device}")
@ -200,6 +204,8 @@ def load_huggingface_quantization_model(
kwargs: Dict,
max_memory: Dict[int, str],
):
import torch
try:
from accelerate import init_empty_weights
from accelerate.utils import infer_auto_device_map

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dataclasses import dataclass, field, fields, MISSING
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
from typing import Dict, Optional
from pilot.model.conversation import conv_templates
from pilot.utils.parameter_utils import BaseParameters

View File

@ -2,8 +2,7 @@ import logging
import platform
from typing import Dict, Iterator, List
import torch
from pilot.configs.model_config import DEVICE
from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
@ -63,7 +62,7 @@ class DefaultModelWorker(ModelWorker):
model_type=model_type,
)
if not model_params.device:
model_params.device = DEVICE
model_params.device = get_device()
logger.info(
f"[DefaultModelWorker] Parameters of device is None, use {model_params.device}"
)
@ -88,6 +87,8 @@ class DefaultModelWorker(ModelWorker):
_clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
import torch
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(
@ -95,7 +96,7 @@ class DefaultModelWorker(ModelWorker):
)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, self.context_len
self.model, self.tokenizer, params, get_device(), self.context_len
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,

View File

@ -1,7 +1,7 @@
import logging
from typing import Dict, List, Type
from pilot.configs.model_config import DEVICE
from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import (
EmbeddingModelParameters,
@ -55,7 +55,7 @@ class EmbeddingsModelWorker(ModelWorker):
model_path=self.model_path,
)
if not model_params.device:
model_params.device = DEVICE
model_params.device = get_device()
logger.info(
f"[EmbeddingsModelWorker] Parameters of device is None, use {model_params.device}"
)

View File

@ -1,5 +1,4 @@
import asyncio
import httpx
import itertools
import json
import os
@ -7,26 +6,21 @@ import random
import time
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
import uvicorn
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import StreamingResponse
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
ModelInstance,
ModelOutput,
WorkerApplyType,
WorkerApplyOutput,
WorkerApplyType,
)
from pilot.model.controller.registry import ModelRegistry
from pilot.model.parameter import (
ModelParameters,
ModelWorkerParameters,
WorkerType,
)
from pilot.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from pilot.model.worker.base import ModelWorker
from pilot.scene.base_message import ModelMessage
from pilot.utils import build_logger
@ -431,6 +425,8 @@ class RemoteWorkerManager(LocalWorkerManager):
return worker_instances
async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
import httpx
async def _remote_apply_func(worker_run_data: WorkerRunData):
worker_addr = worker_run_data.worker.worker_addr
async with httpx.AsyncClient() as client:
@ -700,6 +696,8 @@ def run_worker_manager(
app.include_router(router, prefix="/api")
if not embedded_mod:
import uvicorn
uvicorn.run(
app, host=worker_params.host, port=worker_params.port, log_level="info"
)

View File

@ -1,7 +1,6 @@
import json
from typing import Dict, Iterator, List
import httpx
import logging
from pilot.model.base import ModelOutput
from pilot.model.parameter import ModelParameters
from pilot.model.worker.base import ModelWorker
@ -10,7 +9,8 @@ from pilot.model.worker.base import ModelWorker
class RemoteModelWorker(ModelWorker):
def __init__(self) -> None:
self.headers = {}
self.timeout = 60
# TODO Configured by ModelParameters
self.timeout = 180
self.host = None
self.port = None
@ -44,7 +44,9 @@ class RemoteModelWorker(ModelWorker):
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
"""Asynchronous generate stream"""
print(f"Send async_generate_stream, params: {params}")
import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
delimiter = b"\0"
buffer = b""
@ -71,8 +73,9 @@ class RemoteModelWorker(ModelWorker):
async def async_generate(self, params: Dict) -> ModelOutput:
"""Asynchronous generate non stream"""
print(f"Send async_generate_stream, params: {params}")
import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
response = await client.post(
self.worker_addr + "/generate",
@ -88,6 +91,8 @@ class RemoteModelWorker(ModelWorker):
async def async_embeddings(self, params: Dict) -> List[List[float]]:
"""Asynchronous get embeddings for input"""
import httpx
async with httpx.AsyncClient() as client:
response = await client.post(
self.worker_addr + "/embeddings",

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from typing import TypeVar, Generic, Any
T = TypeVar("T")

View File

@ -1,24 +1,6 @@
from fastapi import (
APIRouter,
Request,
Body,
status,
HTTPException,
Response,
BackgroundTasks,
)
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.encoders import jsonable_encoder
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from pilot.openapi.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo,
)
from pilot.openapi.api_view_model import Result
async def validation_exception_handler(request: Request, exc: RequestValidationError):

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from typing import List, Any
class DbField(BaseModel):

View File

@ -1,8 +1,7 @@
from __future__ import annotations
import json
import re
from abc import ABC, abstractmethod
from abc import ABC
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union

View File

@ -1,10 +1,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from typing import List
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import BaseModel
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from typing import List
from pilot.common.schema import ExampleType

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pydantic import BaseModel, Extra, Field, root_validator
from abc import ABC
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel
from pilot.common.formatting import formatter, no_strict_formatter

View File

@ -3,7 +3,6 @@
from collections import defaultdict
from typing import Dict, List
import json
_DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__"
_DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__"

View File

@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from pilot.out_parser.base import BaseOutputParser
from pilot.prompts.base import PromptValue
from pilot.scene.base_message import HumanMessage, AIMessage, SystemMessage, BaseMessage
from pilot.scene.base_message import HumanMessage, BaseMessage
from pilot.common.formatting import formatter

View File

@ -1,44 +1,20 @@
import time
from abc import ABC, abstractmethod
import datetime
import traceback
import warnings
import json
from pydantic import BaseModel, Field, root_validator, validator, Extra
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
import requests
from urllib.parse import urljoin
from abc import ABC, abstractmethod
from typing import Any, List
import pilot.configs.config
from pilot.scene.message import OnceConversation
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.memory.chat_history.file_history import FileHistoryMemory
from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop
from pilot.scene.base_message import (
BaseMessage,
SystemMessage,
HumanMessage,
AIMessage,
ViewMessage,
ModelMessage,
ModelMessageRoleType,
)
from pilot.configs.config import Config
from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, get_or_create_event_loop
from pydantic import Extra
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}

View File

@ -1,20 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
Tuple,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from typing import Any, Dict, List, Tuple, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import BaseModel, Field, root_validator
class PromptValue(BaseModel, ABC):

View File

@ -3,7 +3,7 @@ import os
import uuid
from typing import List
from pilot.scene.base_chat import BaseChat, logger
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.scene.chat_dashboard.data_preparation.report_schma import (

View File

@ -1,7 +1,5 @@
import json
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
from pydantic import BaseModel
from typing import List, Any
class ValueItem(BaseModel):

View File

@ -1,9 +1,5 @@
import json
import re
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from typing import NamedTuple, List
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR

View File

@ -3,17 +3,11 @@ import os
from typing import List, Any, Dict
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat, logger
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning

View File

@ -1,16 +1,7 @@
import json
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
CFG = Config()

View File

@ -1,8 +1,5 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
@ -36,6 +33,8 @@ class DbChatOutputParser(BaseOutputParser):
return SqlAction(sql, thoughts)
def parse_view_response(self, speak, data) -> str:
import pandas as pd
### tool out data to table view
data_loader = DbDataLoader()
if len(data) <= 1:

View File

@ -2,7 +2,7 @@ import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example

View File

@ -5,7 +5,7 @@ import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example

View File

@ -1,8 +1,7 @@
import pandas as pd
class DbDataLoader:
def get_table_view_by_conn(self, data, speak):
import pandas as pd
### tool out data to table view
if len(data) <= 1:
data.insert(0, ["result"])

View File

@ -1,14 +1,7 @@
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.professional_qa.prompt import prompt
CFG = Config()

View File

@ -1,12 +1,6 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.out_parser.base import BaseOutputParser, T
from pilot.utils import build_logger
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")

View File

@ -1,5 +1,3 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene

View File

@ -1,11 +1,6 @@
import requests
import datetime
from urllib.parse import urljoin
from typing import List
import traceback
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.message import OnceConversation
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.commands.command import execute_command

View File

@ -1,8 +1,5 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR

View File

@ -1,20 +1,20 @@
from pilot.scene.base_chat import BaseChat
from pilot.singleton import Singleton
import inspect
import importlib
from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
class ChatFactory(metaclass=Singleton):
@staticmethod
def get_implementation(chat_mode, **kwargs):
# Lazy loading
from pilot.scene.chat_execution.chat import ChatWithPlugin
from pilot.scene.chat_normal.chat import ChatNormal
from pilot.scene.chat_db.professional_qa.chat import ChatWithDbQA
from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
from pilot.scene.chat_dashboard.chat import ChatDashboard
from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
chat_classes = BaseChat.__subclasses__()
implementation = None
for cls in chat_classes:

View File

@ -1,8 +1,3 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR

View File

@ -1,5 +1,3 @@
import builtins
import importlib
import json
from pilot.prompts.prompt_new import PromptTemplate

View File

@ -1,25 +1,15 @@
from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.server.knowledge.service import KnowledgeService
CFG = Config()
@ -32,6 +22,8 @@ class ChatKnowledge(BaseChat):
def __init__(self, chat_session_id, user_input, select_param: str = None):
""" """
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
self.knowledge_space = select_param
super().__init__(
chat_mode=ChatScene.ChatKnowledge,

View File

@ -1,8 +1,3 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR

View File

@ -1,6 +1,3 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene

View File

@ -1,6 +1,3 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene

View File

@ -1,13 +1,7 @@
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_markdown_table,
generate_htm_table,
datas_to_table_html,
)
from pilot.scene.chat_normal.prompt import prompt
CFG = Config()

View File

@ -1,8 +1,3 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR

View File

@ -1,6 +1,3 @@
import builtins
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene

View File

@ -1,13 +1,6 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pydantic import BaseModel, Field, root_validator, validator
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
)
from datetime import datetime
from typing import List
from pilot.scene.base_message import (
BaseMessage,

View File

@ -1,28 +1,18 @@
import signal
import os
import threading
import traceback
import sys
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
# from pilot.configs.model_config import (
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,
# LLM_MODEL_CONFIG,
# LOGDIR,
# )
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger
from pilot.common.plugins import scan_plugins
from pilot.connections.manages.connection_manager import ConnectManager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")

View File

@ -3,7 +3,6 @@
from functools import cache
from typing import List, Dict, Tuple
from pilot.model.llm_out.vicuna_base_llm import generate_stream
from pilot.model.conversation import Conversation, get_conv_template
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@ -131,6 +130,8 @@ class VicunaChatAdapter(BaseChatAdpter):
return None
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.vicuna_base_llm import generate_stream
if self._is_llama2_based(model_path):
return super().get_generate_stream_func(model_path)
return generate_stream

View File

@ -1,7 +1,4 @@
import atexit
import traceback
import os
import shutil
import argparse
import sys
import logging
@ -11,7 +8,6 @@ sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.utils import build_logger
from pilot.server.base import server_init
@ -28,14 +24,11 @@ from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client
logging.basicConfig(level=logging.INFO, encoding="utf-8")
from pilot.utils.utils import setup_logging
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler():
@ -102,7 +95,7 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
parser.add_argument("--log-level", type=str, default="info")
parser.add_argument("--log-level", type=str, default=None)
parser.add_argument(
"-light",
"--light",
@ -113,6 +106,7 @@ if __name__ == "__main__":
# init server config
args = parser.parse_args()
setup_logging(logging_level=args.log_level)
server_init(args)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
@ -137,6 +131,5 @@ if __name__ == "__main__":
mount_static_files(app)
import uvicorn
logging.basicConfig(level=logging.INFO, encoding="utf-8")
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=args.log_level)
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info")
signal.signal(signal.SIGINT, signal_handler())

View File

@ -1,7 +1,6 @@
import os
import shutil
import tempfile
from tempfile import NamedTemporaryFile
from fastapi import APIRouter, File, UploadFile, Form

View File

@ -1,8 +1,8 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao

View File

@ -1,7 +1,7 @@
from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao

View File

@ -2,12 +2,10 @@ import json
import threading
from datetime import datetime
from langchain.text_splitter import RecursiveCharacterTextSplitter, SpacyTextSplitter
from pilot.vector_store.connector import VectorStoreConnector
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
@ -152,6 +150,14 @@ class KnowledgeService:
"""sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
# import langchain is very very slow!!!
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
id=doc_id,

View File

@ -1,8 +1,7 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime, create_engine
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest

View File

@ -4,9 +4,6 @@
import os
import sys
global_counter = 0
model_semaphore = None
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@ -17,15 +14,6 @@ from pilot.model.worker.manager import run_worker_manager
CFG = Config()
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
# worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
# @app.post("/embedding")
# def embeddings(prompt_request: EmbeddingRequest):
# params = {"prompt": prompt_request.prompt}
# print("Received prompt: ", params["prompt"])
# output = worker.get_embeddings(params["prompt"])
# return {"response": [float(x) for x in output]}
if __name__ == "__main__":
run_worker_manager(

View File

@ -1,3 +0,0 @@
from pilot.speech.say import say_text
__all__ = ["say_text"]

View File

@ -1,21 +1,19 @@
import json
import uuid
from langchain.embeddings import HuggingFaceEmbeddings, logger
from pilot.common.schema import DBType
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.scene.chat_factory import ChatFactory
from pilot.common.schema import DBType
from pilot.configs.model_config import LOGDIR
from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
@ -33,6 +31,8 @@ class DBSummaryClient:
def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.embedding_engine.string_embedding import StringEmbedding
db_summary_client = RdbmsSummary(dbname, db_type)
embeddings = HuggingFaceEmbeddings(
@ -82,6 +82,8 @@ class DBSummaryClient:
logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
vector_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_type": CFG.VECTOR_STORE_TYPE,
@ -97,6 +99,8 @@ class DBSummaryClient:
def get_similar_tables(self, dbname, query, topk):
"""get user query related tables info"""
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
vector_store_config = {
"vector_store_name": dbname + "_summary",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -149,6 +153,8 @@ class DBSummaryClient:
)
def init_db_profile(self, db_summary_client, dbname, embeddings):
from pilot.embedding_engine.string_embedding import StringEmbedding
profile_store_config = {
"vector_store_name": dbname + "_profile",
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -1,6 +1,4 @@
import httpx
from inspect import signature
import typing_inspect
import logging
from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple
from dataclasses import is_dataclass, asdict
@ -9,6 +7,8 @@ T = TypeVar("T")
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
import typing_inspect
"""Extract actual dataclass from generic type hints like List[dataclass], Optional[dataclass], etc."""
if typing_inspect.is_generic_type(type_hint) and typing_inspect.get_args(type_hint):
return typing_inspect.get_args(type_hint)[0]
@ -30,6 +30,8 @@ def _api_remote(path, method="GET"):
sig = signature(func)
async def wrapper(self, *args, **kwargs):
import httpx
base_url = self.base_url # Get base_url from class instance
bound = sig.bind(self, *args, **kwargs)

View File

@ -16,6 +16,22 @@ server_error_msg = (
handler = None
def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging(logging_level=None, logger_name: str = None):
if not logging_level:
logging_level = _get_logging_level()
if type(logging_level) is str:
logging_level = logging.getLevelName(logging_level.upper())
if logger_name:
logger = logging.getLogger(logger_name)
logger.setLevel(logging_level)
else:
logging.basicConfig(level=logging_level, encoding="utf-8")
def get_gpu_memory(max_gpus=None):
import torch
@ -47,7 +63,7 @@ def build_logger(logger_name, logger_filename):
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO, encoding="utf-8")
setup_logging()
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
@ -73,11 +89,11 @@ def build_logger(logger_name, logger_filename):
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
logging.basicConfig(level=logging.INFO, encoding="utf-8")
setup_logging()
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
setup_logging(logger_name=logger_name)
return logger

View File

@ -2,7 +2,6 @@ import os
from typing import Any
from chromadb.config import Settings
from langchain.vectorstores import Chroma
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase
@ -11,6 +10,8 @@ class ChromaStore(VectorStoreBase):
"""chroma database"""
def __init__(self, ctx: {}) -> None:
from langchain.vectorstores import Chroma
self.ctx = ctx
self.embeddings = ctx.get("embeddings", None)
self.persist_dir = os.path.join(

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections, utility
from pilot.logs import logger
@ -279,7 +280,9 @@ class MilvusStore(VectorStoreBase):
round_decimal: int = -1,
timeout: Optional[int] = None,
**kwargs: Any,
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
):
from langchain.docstore.document import Document
self.col.load()
# use default index params.
if param is None: