refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
Author: FangYin Cheng
Date: 2023-12-08 14:45:59 +08:00
Committed by: GitHub
Parent: e7e4aff667
Commit: cd725db1fb
573 changed files with 2094 additions and 3571 deletions

.gitignore (vendored, 4 changed lines)

@@ -26,9 +26,9 @@ sdist/
var/
wheels/
-models/
+/models/
# Soft link
-models
+/models
plugins/
pip-wheel-metadata/

dbgpt/__init__.py (new file, 17 lines)

@@ -0,0 +1,17 @@
from dbgpt.component import SystemApp, BaseComponent

__ALL__ = ["SystemApp", "BaseComponent"]

_CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
_SERVE_LIBS = ["serve"]
_LIBS = _CORE_LIBS + _SERVE_LIBS


def __getattr__(name: str):
    # Lazy load
    import importlib

    if name in _LIBS:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
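Note: the module-level __getattr__ above uses PEP 562 lazy loading: submodules listed in _LIBS are imported only on first attribute access, while SystemApp and BaseComponent are re-exported eagerly. A minimal usage sketch (illustrative, not part of this commit):

import dbgpt

# Re-exported eagerly at import time.
print(dbgpt.SystemApp, dbgpt.BaseComponent)

# First access imports the submodule via importlib.import_module("dbgpt.core").
core = dbgpt.core

# Names outside _LIBS raise AttributeError as usual.
try:
    dbgpt.unknown_submodule
except AttributeError as exc:
    print(exc)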

View File

@@ -0,0 +1,4 @@
"""This is a private module.
You should not import anything from this module.
"""

View File

@@ -1,10 +1,8 @@
import asyncio
from typing import Coroutine, List, Any
-from starlette.responses import StreamingResponse
-from pilot.scene.base_chat import BaseChat
-from pilot.scene.chat_factory import ChatFactory
+from dbgpt.app.scene import BaseChat, ChatFactory
chat_factory = ChatFactory()

View File

@@ -5,11 +5,11 @@ from __future__ import annotations
import os
from typing import List, Optional, TYPE_CHECKING
-from pilot.singleton import Singleton
+from dbgpt.util.singleton import Singleton
if TYPE_CHECKING:
from auto_gpt_plugin_template import AutoGPTPluginTemplate
-from pilot.component import SystemApp
+from dbgpt.component import SystemApp
class Config(metaclass=Singleton):
@@ -120,7 +120,7 @@ class Config(metaclass=Singleton):
)
self.speak_mode = False
-from pilot.prompts.prompt_registry import PromptTemplateRegistry
+from dbgpt.core._private.prompt_registry import PromptTemplateRegistry
self.prompt_template_registry = PromptTemplateRegistry()
### Related configuration of built-in commands

View File

@@ -1,4 +1,4 @@
-from pydantic import Field, BaseModel
+from dbgpt._private.pydantic import Field, BaseModel
DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256

View File

@@ -0,0 +1,33 @@
import pydantic

if pydantic.VERSION.startswith("1."):
    PYDANTIC_VERSION = 1
    from pydantic import (
        BaseModel,
        Extra,
        Field,
        NonNegativeFloat,
        NonNegativeInt,
        PositiveFloat,
        PositiveInt,
        ValidationError,
        root_validator,
        validator,
        PrivateAttr,
    )
else:
    PYDANTIC_VERSION = 2
    # pydantic 2.x
    from pydantic.v1 import (
        BaseModel,
        Extra,
        Field,
        NonNegativeFloat,
        NonNegativeInt,
        PositiveFloat,
        PositiveInt,
        ValidationError,
        root_validator,
        validator,
        PrivateAttr,
    )
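Note: this shim pins the codebase to the pydantic v1 API: on pydantic 1.x the names are re-exported directly, and on 2.x they come from the bundled pydantic.v1 compatibility layer, with PYDANTIC_VERSION available for version-specific branches. A minimal usage sketch (illustrative; the UserRequest model is not part of this commit):

from dbgpt._private.pydantic import BaseModel, Field, validator


class UserRequest(BaseModel):
    # Behaves the same whether pydantic 1.x or 2.x is installed,
    # because both branches of the shim expose the v1-style API.
    prompt: str = Field(..., description="User prompt text")
    temperature: float = Field(0.7, ge=0.0, le=2.0)

    @validator("prompt")
    def prompt_not_empty(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("prompt must not be empty")
        return v


print(UserRequest(prompt="hello").dict())  # v1-style .dict() serialization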

View File

@@ -3,8 +3,8 @@ import json
import requests
-from pilot.base_modules.agent.commands.command_mange import command
+from dbgpt.agent.commands.command_mange import command
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
CFG = Config()

View File

@@ -7,8 +7,8 @@ import logging
import requests
from PIL import Image
-from pilot.base_modules.agent.commands.command_mange import command
+from dbgpt.agent.commands.command_mange import command
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
logger = logging.getLogger(__name__)
CFG = Config()

View File

@@ -7,7 +7,7 @@ from typing import Dict
from .exception_not_commands import NotCommands
from .generator import PluginPromptGenerator
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
def _resolve_pathlike_command_args(command_args):
@@ -36,7 +36,7 @@ def execute_ai_response_json(
Returns:
"""
-from pilot.speech.say import say_text
+from dbgpt.util.speech.say import say_text
cfg = Config()

View File

@@ -1,20 +1,17 @@
import functools
import importlib
import inspect
-import time
import json
import logging
import xml.etree.ElementTree as ET
-import pandas as pd
-from pilot.common.json_utils import serialize
+from dbgpt.util.json_utils import serialize
from datetime import datetime
from typing import Any, Callable, Optional, List
-from pydantic import BaseModel
+from dbgpt._private.pydantic import BaseModel
-from pilot.base_modules.agent.common.schema import Status, ApiTagType
+from dbgpt.agent.common.schema import Status
-from pilot.base_modules.agent.commands.command import execute_command
+from dbgpt.agent.commands.command import execute_command
-from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
-from pilot.common.string_utils import extract_content_open_ending, extract_content
+from dbgpt.util.string_utils import extract_content_open_ending, extract_content
# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"

View File

@@ -1,6 +1,6 @@
from pandas import DataFrame
-from pilot.base_modules.agent.commands.command_mange import command
+from dbgpt.agent.commands.command_mange import command
import pandas as pd
import uuid
import os
@@ -11,14 +11,15 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
-from pilot.common.string_utils import is_scientific_notation
+from dbgpt.util.string_utils import is_scientific_notation
+from dbgpt.configs.model_config import PILOT_PATH
import logging
logger = logging.getLogger(__name__)
-static_message_img_path = os.path.join(os.getcwd(), "message/img")
+static_message_img_path = os.path.join(PILOT_PATH, "message/img")
def data_pre_classification(df: DataFrame):

View File

@@ -1,6 +1,6 @@
from pandas import DataFrame
-from pilot.base_modules.agent.commands.command_mange import command
+from dbgpt.agent.commands.command_mange import command
import logging

View File

@@ -1,6 +1,6 @@
from pandas import DataFrame
-from pilot.base_modules.agent.commands.command_mange import command
+from dbgpt.agent.commands.command_mange import command
import logging

View File

@@ -1,5 +1,4 @@
""" A module for generating custom prompt strings."""
-import json
from typing import Any, Callable, Dict, List, Optional

View File

@@ -1,5 +1,3 @@
-import json
-import time
import logging
from fastapi import (
APIRouter,
@@ -7,12 +5,10 @@ from fastapi import (
UploadFile,
File,
)
-from abc import ABC, abstractmethod
+from abc import ABC
from typing import List
-from pilot.configs.model_config import LOGDIR
-from pilot.openapi.api_view_model import (
+from dbgpt.app.openapi.api_view_model import (
Result,
)
@@ -21,15 +17,14 @@ from .model import (
PagenationFilter,
PagenationResult,
PluginHubFilter,
-MyPluginFilter,
)
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .plugins_util import scan_plugins
from .commands.generator import PluginPromptGenerator
-from pilot.configs.model_config import PLUGINS_DIR
+from dbgpt.configs.model_config import PLUGINS_DIR
-from pilot.component import BaseComponent, ComponentType, SystemApp
+from dbgpt.component import BaseComponent, ComponentType, SystemApp
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -79,6 +74,7 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
if update_param.branch is not None and len(update_param.branch) > 0
else None
)
+# TODO change it to async
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None)
except Exception as e:

View File

@@ -1,10 +1,9 @@
from datetime import datetime
-from typing import List
-from sqlalchemy import Column, Integer, String, Index, DateTime, func
+from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint
-from pilot.base_modules.meta_data.base_dao import BaseDao
+from dbgpt.storage.metadata import BaseDao
-from pilot.base_modules.meta_data.meta_data import (
+from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,

View File

@@ -1,12 +1,10 @@
from datetime import datetime
import pytz
-from typing import List
-from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, DDL
+from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint
-from pilot.base_modules.meta_data.meta_data import Base
-from pilot.base_modules.meta_data.base_dao import BaseDao
+from dbgpt.storage.metadata import BaseDao
-from pilot.base_modules.meta_data.meta_data import (
+from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,

View File

@@ -1,7 +1,7 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
-from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
+from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")

View File

@@ -4,22 +4,19 @@ import json
import os
import glob
import zipfile
-import fnmatch
-import requests
import git
import threading
import datetime
import logging
from pathlib import Path
from typing import List
-from urllib.parse import urlparse
from zipimport import zipimporter
import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.configs.model_config import PLUGINS_DIR
+from dbgpt.configs.model_config import PLUGINS_DIR
logger = logging.getLogger(__name__)

View File

@@ -1,10 +1,12 @@
+"""The app package.
+
+This package will not be uploaded to PyPI, so other packages that depend on dbgpt can't import it.
+"""
import os
import random
import sys
from dotenv import load_dotenv
-from pilot.base_modules.agent import PluginPromptGenerator
if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
print("Setting random seed to 42")
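Note: the guard above makes test and CI runs deterministic. A minimal sketch of the complete pattern; the random.seed(42) call is assumed from the log message and is not shown in this hunk:

import os
import random
import sys

if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
    print("Setting random seed to 42")
    random.seed(42)  # assumed; the seed call itself is outside this hunk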

View File

@@ -1,22 +1,22 @@
import click
import os
-from pilot.server.base import WebWerverParameters
+from dbgpt.app.base import WebServerParameters
-from pilot.configs.model_config import LOGDIR
+from dbgpt.configs.model_config import LOGDIR
-from pilot.utils.parameter_utils import EnvArgumentParser
+from dbgpt.util.parameter_utils import EnvArgumentParser
-from pilot.utils.command_utils import _run_current_with_daemon, _stop_service
+from dbgpt.util.command_utils import _run_current_with_daemon, _stop_service
@click.command(name="webserver")
-@EnvArgumentParser.create_click_option(WebWerverParameters)
+@EnvArgumentParser.create_click_option(WebServerParameters)
def start_webserver(**kwargs):
"""Start webserver(dbgpt_server.py)"""
if kwargs["daemon"]:
log_file = os.path.join(LOGDIR, "webserver_uvicorn.log")
_run_current_with_daemon("WebServer", log_file)
else:
-from pilot.server.dbgpt_server import run_webserver
+from dbgpt.app.dbgpt_server import run_webserver
-run_webserver(WebWerverParameters(**kwargs))
+run_webserver(WebServerParameters(**kwargs))
@click.command(name="webserver")

View File

@@ -2,13 +2,13 @@ import signal
import os
import threading
import sys
-from typing import Optional, Any
+from typing import Optional
from dataclasses import dataclass, field
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.component import SystemApp
+from dbgpt.component import SystemApp
-from pilot.utils.parameter_utils import BaseParameters
+from dbgpt.util.parameter_utils import BaseParameters
-from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
+from dbgpt.storage.metadata.meta_data import ddl_init_and_upgrade
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -21,15 +21,15 @@ def signal_handler(sig, frame):
def async_db_summary(system_app: SystemApp):
"""async db schema into vector db"""
-from pilot.summary.db_summary_client import DBSummaryClient
+from dbgpt.rag.summary.db_summary_client import DBSummaryClient
client = DBSummaryClient(system_app=system_app)
thread = threading.Thread(target=client.init_db_summary)
thread.start()
-def server_init(param: "WebWerverParameters", system_app: SystemApp):
+def server_init(param: "WebServerParameters", system_app: SystemApp):
-from pilot.base_modules.agent.commands.command_mange import CommandRegistry
+from dbgpt.agent.commands.command_mange import CommandRegistry
# logger.info(f"args: {args}")
@@ -44,8 +44,8 @@ def server_init(param: "WebWerverParameters", system_app: SystemApp):
# Loader plugins and commands
command_categories = [
-"pilot.base_modules.agent.commands.built_in.audio_text",
+"dbgpt.agent.commands.built_in.audio_text",
-"pilot.base_modules.agent.commands.built_in.image_gen",
+"dbgpt.agent.commands.built_in.image_gen",
]
# exclude commands
command_categories = [
@@ -58,9 +58,9 @@ def server_init(param: "WebWerverParameters", system_app: SystemApp):
cfg.command_registry = command_registry
command_disply_commands = [
-"pilot.base_modules.agent.commands.disply_type.show_chart_gen",
+"dbgpt.agent.commands.disply_type.show_chart_gen",
-"pilot.base_modules.agent.commands.disply_type.show_table_gen",
+"dbgpt.agent.commands.disply_type.show_table_gen",
-"pilot.base_modules.agent.commands.disply_type.show_text_gen",
+"dbgpt.agent.commands.disply_type.show_text_gen",
]
command_disply_registry = CommandRegistry()
for command in command_disply_commands:
@@ -69,7 +69,7 @@ def server_init(param: "WebWerverParameters", system_app: SystemApp):
def _create_model_start_listener(system_app: SystemApp):
-from pilot.connections.manages.connection_manager import ConnectManager
+from dbgpt.datasource.manages.connection_manager import ConnectManager
cfg = Config()
@@ -84,7 +84,7 @@ def _create_model_start_listener(system_app: SystemApp):
@dataclass
-class WebWerverParameters(BaseParameters):
+class WebServerParameters(BaseParameters):
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Webserver deploy host"}
)

View File

@@ -1,14 +1,14 @@
"""
This code file will be deprecated in the future.
-We have integrated fastchat. For details, see: pilot/model/model_adapter.py
+We have integrated fastchat. For details, see: dbgpt/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import cache
from typing import List, Dict, Tuple
-from pilot.model.conversation import Conversation, get_conv_template
+from dbgpt.model.conversation import Conversation, get_conv_template
-from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
class BaseChatAdpter:
@@ -20,7 +20,7 @@ class BaseChatAdpter:
def get_generate_stream_func(self, model_path: str):
"""Return the generate stream handler func"""
-from pilot.model.inference import generate_stream
+from dbgpt.model.inference import generate_stream
return generate_stream
@@ -134,7 +134,7 @@ class VicunaChatAdapter(BaseChatAdpter):
return None
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.vicuna_base_llm import generate_stream
+from dbgpt.model.llm_out.vicuna_base_llm import generate_stream
if self._is_llama2_based(model_path):
return super().get_generate_stream_func(model_path)
@@ -148,7 +148,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
return "chatglm" in model_path
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
+from dbgpt.model.llm_out.chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
@@ -160,7 +160,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
return "guanaco" in model_path
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
+from dbgpt.model.llm_out.guanaco_llm import guanaco_generate_stream
return guanaco_generate_stream
@@ -172,7 +172,7 @@ class FalconChatAdapter(BaseChatAdpter):
return "falcon" in model_path
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.falcon_llm import falcon_generate_output
+from dbgpt.model.llm_out.falcon_llm import falcon_generate_output
return falcon_generate_output
@@ -182,7 +182,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
return "proxyllm" in model_path
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
+from dbgpt.model.llm_out.proxy_llm import proxyllm_generate_stream
return proxyllm_generate_stream
@@ -192,7 +192,7 @@ class GorillaChatAdapter(BaseChatAdpter):
return "gorilla" in model_path
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.gorilla_llm import generate_stream
+from dbgpt.model.llm_out.gorilla_llm import generate_stream
return generate_stream
@@ -202,7 +202,7 @@ class GPT4AllChatAdapter(BaseChatAdpter):
return "gptj-6b" in model_path
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
+from dbgpt.model.llm_out.gpt4all_llm import gpt4all_generate_stream
return gpt4all_generate_stream
@@ -245,7 +245,7 @@ class WizardLMChatAdapter(BaseChatAdpter):
class LlamaCppChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
-from pilot.model.adapter import LlamaCppAdapater
+from dbgpt.model.adapter import LlamaCppAdapater
if "llama-cpp" == model_path:
return True
@@ -256,7 +256,7 @@ class LlamaCppChatAdapter(BaseChatAdpter):
return get_conv_template("llama-2")
def get_generate_stream_func(self, model_path: str):
-from pilot.model.llm_out.llama_cpp_llm import generate_stream
+from dbgpt.model.llm_out.llama_cpp_llm import generate_stream
return generate_stream
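Note: the adapters above all follow the same dispatch pattern: each subclass implements match(model_path) and returns a model-specific stream function from get_generate_stream_func(). An illustrative, self-contained sketch of that dispatch (simplified names, not the project's actual registry):

from typing import Callable, Iterator, List


class SketchAdapter:
    """Fallback adapter; matches everything and echoes the prompt."""

    def match(self, model_path: str) -> bool:
        return True

    def get_generate_stream_func(self, model_path: str) -> Callable[[str], Iterator[str]]:
        def generate_stream(prompt: str) -> Iterator[str]:
            yield f"echo: {prompt}"

        return generate_stream


class SketchChatGLMAdapter(SketchAdapter):
    def match(self, model_path: str) -> bool:
        return "chatglm" in model_path


# More specific adapters are checked first; the fallback matches last.
_ADAPTERS: List[SketchAdapter] = [SketchChatGLMAdapter(), SketchAdapter()]


def get_adapter(model_path: str) -> SketchAdapter:
    for adapter in _ADAPTERS:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"no adapter for {model_path!r}")


stream = get_adapter("chatglm-6b").get_generate_stream_func("chatglm-6b")
print(list(stream("hello")))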

View File

@@ -2,14 +2,13 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Type
-import os
-from pilot.component import ComponentType, SystemApp
+from dbgpt.component import ComponentType, SystemApp
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
+from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR
-from pilot.utils.executor_utils import DefaultExecutorFactory
+from dbgpt.util.executor_utils import DefaultExecutorFactory
-from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
-from pilot.server.base import WebWerverParameters
+from dbgpt.app.base import WebServerParameters
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@@ -21,23 +20,23 @@ CFG = Config()
def initialize_components(
-param: WebWerverParameters,
+param: WebServerParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
):
-from pilot.model.cluster.controller.controller import controller
+from dbgpt.model.cluster.controller.controller import controller
# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
system_app.register_instance(controller)
# Register global default RAGGraphFactory
-# from pilot.graph_engine.graph_factory import DefaultRAGGraphFactory
+# from dbgpt.graph_engine.graph_factory import DefaultRAGGraphFactory
# system_app.register(DefaultRAGGraphFactory)
-from pilot.base_modules.agent.controller import module_agent
+from dbgpt.agent.controller import module_agent
system_app.register_instance(module_agent)
@@ -49,7 +48,7 @@ def initialize_components(
def _initialize_embedding_model(
-param: WebWerverParameters,
+param: WebServerParameters,
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
@@ -79,8 +78,8 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
-from pilot.model.cluster import WorkerManagerFactory
+from dbgpt.model.cluster import WorkerManagerFactory
-from pilot.model.cluster.embedding.remote_embedding import RemoteEmbeddings
+from dbgpt.model.cluster.embedding.remote_embedding import RemoteEmbeddings
if embedding_cls:
raise NotImplementedError
@@ -116,9 +115,9 @@ class LocalEmbeddingFactory(EmbeddingFactory):
return self._model
def _load_model(self) -> "Embeddings":
-from pilot.model.cluster.embedding.loader import EmbeddingLoader
+from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
+from dbgpt.model.cluster.worker.embedding_worker import _parse_embedding_params
-from pilot.model.parameter import (
+from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
@@ -140,7 +139,7 @@ class LocalEmbeddingFactory(EmbeddingFactory):
def _initialize_model_cache(system_app: SystemApp):
-from pilot.cache import initialize_cache
+from dbgpt.storage.cache import initialize_cache
if not CFG.MODEL_CACHE_ENABLE:
logger.info("Model cache is not enable")
@@ -153,7 +152,7 @@ def _initialize_model_cache(system_app: SystemApp):
def _initialize_awel(system_app: SystemApp):
-from pilot.awel import initialize_awel
+from dbgpt.core.awel import initialize_awel
-from pilot.configs.model_config import _DAG_DEFINITION_DIR
+from dbgpt.configs.model_config import _DAG_DEFINITION_DIR
initialize_awel(system_app, _DAG_DEFINITION_DIR)
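Note: several hunks above register components on a shared SystemApp and later resolve them by ComponentType. An illustrative sketch of that registry idea (heavily simplified; not the actual dbgpt.component implementation):

from typing import Dict, Type, TypeVar

T = TypeVar("T", bound="BaseComponent")


class BaseComponent:
    name: str = "base"

    def init_app(self, system_app: "SystemApp") -> None:
        """Hook called when the component is attached to the app."""


class SystemApp:
    def __init__(self) -> None:
        self._components: Dict[str, BaseComponent] = {}

    def register(self, component_cls: Type[T]) -> T:
        # Instantiate and register a component class.
        return self.register_instance(component_cls())

    def register_instance(self, instance: T) -> T:
        self._components[instance.name] = instance
        instance.init_app(self)
        return instance

    def get_component(self, name: str, expected_cls: Type[T]) -> T:
        component = self._components[name]
        assert isinstance(component, expected_cls)
        return component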

View File

@@ -2,53 +2,54 @@ import os
import argparse
import sys
from typing import List
+import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR
-from pilot.component import SystemApp
+from dbgpt.configs.model_config import (
+LLM_MODEL_CONFIG,
+EMBEDDING_MODEL_CONFIG,
+LOGDIR,
+ROOT_PATH,
+)
+from dbgpt.component import SystemApp
-from pilot.server.base import (
+from dbgpt.app.base import (
server_init,
-WebWerverParameters,
+WebServerParameters,
_create_model_start_listener,
)
-from pilot.server.component_configs import initialize_components
+from dbgpt.app.component_configs import initialize_components
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.openapi.utils import get_openapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
-from pilot.server.knowledge.api import router as knowledge_router
+from dbgpt.app.knowledge.api import router as knowledge_router
-from pilot.server.prompt.api import router as prompt_router
+from dbgpt.app.prompt.api import router as prompt_router
-from pilot.server.llm_manage.api import router as llm_manage_api
+from dbgpt.app.llm_manage.api import router as llm_manage_api
-from pilot.openapi.api_v1.api_v1 import router as api_v1
+from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
-from pilot.openapi.base import validation_exception_handler
+from dbgpt.app.openapi.base import validation_exception_handler
-from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
+from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
-from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
+from dbgpt.app.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
-from pilot.base_modules.agent.commands.disply_type.show_chart_gen import (
+from dbgpt.agent.commands.disply_type.show_chart_gen import (
static_message_img_path,
)
-from pilot.model.cluster import initialize_worker_manager_in_client
+from dbgpt.model.cluster import initialize_worker_manager_in_client
-from pilot.utils.utils import (
+from dbgpt.util.utils import (
setup_logging,
_get_logging_level,
logging_str_to_uvicorn_level,
setup_http_service_logging,
)
-from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
+from dbgpt.util.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
-from pilot.utils.parameter_utils import _get_dict_from_obj
+from dbgpt.util.parameter_utils import _get_dict_from_obj
-from pilot.utils.system_utils import get_system_info
+from dbgpt.util.system_utils import get_system_info
-from pilot.base_modules.agent.controller import router as agent_route
-static_file_path = os.path.join(os.getcwd(), "server/static")
+static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static")
CFG = Config()
@@ -106,15 +107,15 @@ app.add_exception_handler(RequestValidationError, validation_exception_handler)
def _get_webserver_params(args: List[str] = None):
-from pilot.utils.parameter_utils import EnvArgumentParser
+from dbgpt.util.parameter_utils import EnvArgumentParser
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
-WebWerverParameters
+WebServerParameters
)
-return WebWerverParameters(**vars(parser.parse_args(args=args)))
+return WebServerParameters(**vars(parser.parse_args(args=args)))
-def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
+def initialize_app(param: WebServerParameters = None, args: List[str] = None):
"""Initialize app
If you use gunicorn as a process manager, initialize_app can be invoked in the `on_starting` hook.
Args:
@@ -127,7 +128,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
if not param.log_level:
param.log_level = _get_logging_level()
setup_logging(
-"pilot", logging_level=param.log_level, logger_filename=param.log_file
+"dbgpt", logging_level=param.log_level, logger_filename=param.log_file
)
# Before start
@@ -180,7 +181,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
return param
-def run_uvicorn(param: WebWerverParameters):
+def run_uvicorn(param: WebServerParameters):
import uvicorn
setup_http_service_logging()
@@ -192,7 +193,7 @@ def run_uvicorn(param: WebWerverParameters):
)
-def run_webserver(param: WebWerverParameters = None):
+def run_webserver(param: WebServerParameters = None):
if not param:
param = _get_webserver_params()
initialize_tracer(

View File

@@ -3,7 +3,7 @@ import logging
import os
import functools
-from pilot.configs.model_config import DATASETS_DIR
+from dbgpt.configs.model_config import DATASETS_DIR
_DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000"
API_ADDRESS: str = _DEFAULT_API_ADDRESS
@@ -129,7 +129,7 @@ def load(
chunk_overlap: int,
):
"""Load your local documents to DB-GPT"""
-from pilot.server.knowledge._cli.knowledge_client import knowledge_init
+from dbgpt.app.knowledge._cli.knowledge_client import knowledge_init
knowledge_init(
API_ADDRESS,
@@ -165,7 +165,7 @@ def load(
)
def delete(space_name: str, doc_name: str, y: bool):
"""Delete your knowledge space or document in space"""
-from pilot.server.knowledge._cli.knowledge_client import knowledge_delete
+from dbgpt.app.knowledge._cli.knowledge_client import knowledge_delete
knowledge_delete(API_ADDRESS, space_name, doc_name, confirm=y)
@@ -227,7 +227,7 @@ def list(
output: str,
):
"""List knowledge space"""
-from pilot.server.knowledge._cli.knowledge_client import knowledge_list
+from dbgpt.app.knowledge._cli.knowledge_client import knowledge_list
knowledge_list(
API_ADDRESS, space_name, page, page_size, doc_id, show_content, output

View File

@@ -6,18 +6,18 @@ import logging
from urllib.parse import urljoin
from concurrent.futures import ThreadPoolExecutor, as_completed
-from pilot.openapi.api_view_model import Result
+from dbgpt.app.openapi.api_view_model import Result
-from pilot.server.knowledge.request.request import (
+from dbgpt.app.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeDocumentRequest,
ChunkQueryRequest,
DocumentQueryRequest,
)
-from pilot.embedding_engine.knowledge_type import KnowledgeType
+from dbgpt.rag.embedding_engine.knowledge_type import KnowledgeType
-from pilot.server.knowledge.request.request import DocumentSyncRequest
+from dbgpt.app.knowledge.request.request import DocumentSyncRequest
-from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
HTTP_HEADERS = {"Content-Type": "application/json"}

View File

@@ -5,19 +5,19 @@ import logging
from fastapi import APIRouter, File, UploadFile, Form
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.configs.model_config import (
+from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
-from pilot.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
+from dbgpt.app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
-from pilot.openapi.api_view_model import Result
+from dbgpt.app.openapi.api_view_model import Result
-from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
-from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
-from pilot.server.knowledge.service import KnowledgeService
+from dbgpt.app.knowledge.service import KnowledgeService
-from pilot.server.knowledge.request.request import (
+from dbgpt.app.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse,
KnowledgeDocumentRequest,
@@ -29,8 +29,8 @@ from pilot.server.knowledge.request.request import (
DocumentSummaryRequest,
)
-from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
-from pilot.utils.tracer import root_tracer, SpanType
+from dbgpt.util.tracer import root_tracer, SpanType
logger = logging.getLogger(__name__)
@@ -253,8 +253,8 @@ async def document_summary(request: DocumentSummaryRequest):
async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")
try:
-from pilot.scene.base import ChatScene
+from dbgpt.app.scene import ChatScene
-from pilot.common.chat_util import llm_chat_response_nostream
+from dbgpt._private.chat_util import llm_chat_response_nostream
import uuid
chat_param = {

View File

@@ -3,14 +3,14 @@ from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func
-from pilot.base_modules.meta_data.base_dao import BaseDao
+from dbgpt.storage.metadata import BaseDao
-from pilot.base_modules.meta_data.meta_data import (
+from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
CFG = Config()

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, func
-from pilot.base_modules.meta_data.base_dao import BaseDao
+from dbgpt.storage.metadata import BaseDao
-from pilot.base_modules.meta_data.meta_data import (
+from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
CFG = Config()

View File

@@ -1,6 +1,6 @@
from typing import List, Optional
-from pydantic import BaseModel
+from dbgpt._private.pydantic import BaseModel
from fastapi import UploadFile

View File

@@ -1,6 +1,6 @@
from typing import List
-from pydantic import BaseModel
+from dbgpt._private.pydantic import BaseModel
class ChunkQueryResponse(BaseModel):

View File

@@ -2,29 +2,29 @@ import json
import logging
from datetime import datetime
-from pilot.vector_store.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.configs.model_config import (
+from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
-from pilot.component import ComponentType
+from dbgpt.component import ComponentType
-from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
-from pilot.server.knowledge.chunk_db import (
+from dbgpt.app.knowledge.chunk_db import (
DocumentChunkEntity,
DocumentChunkDao,
)
-from pilot.server.knowledge.document_db import (
+from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
-from pilot.server.knowledge.space_db import (
+from dbgpt.app.knowledge.space_db import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
-from pilot.server.knowledge.request.request import (
+from dbgpt.app.knowledge.request.request import (
KnowledgeSpaceRequest,
KnowledgeDocumentRequest,
DocumentQueryRequest,
@@ -35,7 +35,7 @@ from pilot.server.knowledge.request.request import (
)
from enum import Enum
-from pilot.server.knowledge.request.response import (
+from dbgpt.app.knowledge.request.response import (
ChunkQueryResponse,
DocumentQueryResponse,
SpaceQueryResponse,
@@ -192,9 +192,9 @@ class KnowledgeService:
- space: Knowledge Space Name
- sync_request: DocumentSyncRequest
"""
-from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
-from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
-from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
+from dbgpt.rag.embedding_engine.pre_text_splitter import PreTextSplitter
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
@@ -432,7 +432,7 @@ class KnowledgeService:
f"async_knowledge_graph, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to graph store"
)
try:
-from pilot.rag.graph_engine.graph_factory import RAGGraphFactory
+from dbgpt.rag.graph_engine.graph_factory import RAGGraphFactory
rag_engine = CFG.SYSTEM_APP.get_component(
ComponentType.RAG_GRAPH_DEFAULT.value, RAGGraphFactory
@@ -454,10 +454,10 @@ class KnowledgeService:
- doc: KnowledgeDocumentEntity
"""
texts = [doc.page_content for doc in chunk_docs]
-from pilot.common.prompt_util import PromptHelper
+from dbgpt.util.prompt_util import PromptHelper
prompt_helper = PromptHelper()
-from pilot.scene.chat_knowledge.summary.prompt import prompt
+from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
logger.info(
@@ -501,7 +501,7 @@ class KnowledgeService:
return knowledge_document_dao.update_knowledge_document(doc)
def _build_default_context(self):
-from pilot.scene.chat_knowledge.v1.prompt import (
+from dbgpt.app.scene.chat_knowledge.v1.prompt import (
PROMPT_SCENE_DEFINE,
_DEFAULT_TEMPLATE,
)
@@ -556,7 +556,7 @@ class KnowledgeService:
Returns:
chat: BaseChat, refine summary chat.
"""
-from pilot.scene.base import ChatScene
+from dbgpt.app.scene import ChatScene
chat_param = {
"chat_session_id": conn_uid,
@@ -568,7 +568,7 @@ class KnowledgeService:
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
-from pilot.openapi.api_v1.api_v1 import CHAT_FACTORY
+from dbgpt.app.openapi.api_v1.api_v1 import CHAT_FACTORY
chat = await blocking_func_to_async(
executor,
@@ -596,8 +596,8 @@ class KnowledgeService:
Returns:
Document: refine summary context document.
"""
-from pilot.scene.base import ChatScene
+from dbgpt.app.scene import ChatScene
-from pilot.common.chat_util import llm_chat_response_nostream
+from dbgpt._private.chat_util import llm_chat_response_nostream
import uuid
tasks = []
@@ -618,7 +618,7 @@ class KnowledgeService:
ChatScene.ExtractSummary.value(), **{"chat_param": chat_param}
)
)
-from pilot.common.chat_util import run_async_tasks
+from dbgpt._private.chat_util import run_async_tasks
summary_iters = await run_async_tasks(
tasks=tasks, concurrency_limit=concurrency_limit
@@ -629,8 +629,8 @@ class KnowledgeService:
summary_iters,
)
)
-from pilot.common.prompt_util import PromptHelper
+from dbgpt.util.prompt_util import PromptHelper
-from pilot.scene.chat_knowledge.summary.prompt import prompt
+from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
prompt_helper = PromptHelper()
summary_iters = prompt_helper.repack(

View File

@@ -2,15 +2,15 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
-from pilot.base_modules.meta_data.base_dao import BaseDao
+from dbgpt.storage.metadata import BaseDao
-from pilot.base_modules.meta_data.meta_data import (
+from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
CFG = Config()

View File

@@ -1,14 +1,12 @@
-from typing import List
from fastapi import APIRouter
-from pilot.component import ComponentType
+from dbgpt.component import ComponentType
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.model.cluster import WorkerStartupRequest, WorkerManagerFactory
+from dbgpt.model.cluster import WorkerStartupRequest, WorkerManagerFactory
-from pilot.openapi.api_view_model import Result
+from dbgpt.app.openapi.api_view_model import Result
-from pilot.server.llm_manage.request.request import ModelResponse
+from dbgpt.app.llm_manage.request.request import ModelResponse
CFG = Config()
router = APIRouter()
@@ -18,7 +16,7 @@ router = APIRouter()
async def model_params():
print(f"/worker/model/params")
try:
-from pilot.model.cluster import WorkerManagerFactory
+from dbgpt.model.cluster import WorkerManagerFactory
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
@@ -42,7 +40,7 @@ async def model_params():
async def model_list():
print(f"/worker/model/list")
try:
-from pilot.model.cluster.controller.controller import BaseModelController
+from dbgpt.model.cluster.controller.controller import BaseModelController
controller = CFG.SYSTEM_APP.get_component(
ComponentType.MODEL_CONTROLLER, BaseModelController
@@ -85,7 +83,7 @@ async def model_list():
async def model_stop(request: WorkerStartupRequest):
print(f"/v1/worker/model/stop:")
try:
-from pilot.model.cluster.controller.controller import BaseModelController
+from dbgpt.model.cluster.controller.controller import BaseModelController
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory

View File

@@ -7,9 +7,9 @@ import sys
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
+from dbgpt.configs.model_config import LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG
-from pilot.model.cluster import run_worker_manager
+from dbgpt.model.cluster import run_worker_manager
CFG = Config()

View File

@@ -6,23 +6,18 @@ import aiofiles
import logging
from fastapi import (
APIRouter,
-Request,
File,
UploadFile,
-Form,
Body,
-BackgroundTasks,
Depends,
)
from fastapi.responses import StreamingResponse
-from fastapi.exceptions import RequestValidationError
from typing import List, Optional
-import tempfile
from concurrent.futures import Executor
-from pilot.component import ComponentType
+from dbgpt.component import ComponentType
-from pilot.openapi.api_view_model import (
+from dbgpt.app.openapi.api_view_model import (
Result,
ConversationVo,
MessageVo,
@@ -31,24 +26,20 @@ from pilot.openapi.api_view_model import (
DeltaMessage,
ChatCompletionStreamResponse,
)
-from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
+from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.server.knowledge.service import KnowledgeService
+from dbgpt.app.knowledge.service import KnowledgeService
-from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
-from pilot.scene.base_chat import BaseChat
-from pilot.scene.base import ChatScene
-from pilot.scene.chat_factory import ChatFactory
-from pilot.common.schema import DBType
-from pilot.scene.message import OnceConversation
-from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
-from pilot.summary.db_summary_client import DBSummaryClient
-from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
-from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
-from pilot.model.base import FlatSupportedModel
-from pilot.utils.tracer import root_tracer, SpanType
-from pilot.utils.executor_utils import (
+from dbgpt.app.scene import BaseChat, ChatScene, ChatFactory
+from dbgpt.core.interface.message import OnceConversation
+from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
+from dbgpt.rag.summary.db_summary_client import DBSummaryClient
+from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
+from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
+from dbgpt.model.base import FlatSupportedModel
+from dbgpt.util.tracer import root_tracer, SpanType
+from dbgpt.util.executor_utils import (
ExecutorFactory,
blocking_func_to_async,
DefaultExecutorFactory,

View File

@@ -8,14 +8,14 @@ from fastapi import (
from typing import List
import logging
-from pilot.configs.config import Config
+from dbgpt._private.config import Config
-from pilot.scene.chat_factory import ChatFactory
+from dbgpt.app.scene import ChatFactory
-from pilot.openapi.api_view_model import (
+from dbgpt.app.openapi.api_view_model import (
Result,
)
-from pilot.openapi.editor_view_model import (
+from dbgpt.app.openapi.editor_view_model import (
ChatDbRounds,
ChartList,
ChartDetail,
@@ -24,11 +24,15 @@ from pilot.openapi.editor_view_model import (
DbTable,
)
-from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
-from pilot.scene.message import OnceConversation
-from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
-from pilot.scene.chat_db.data_loader import DbDataLoader
-from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
+from dbgpt.app.openapi.api_v1.editor.sql_editor import (
+DataNode,
+ChartRunData,
+SqlRunData,
+)
+from dbgpt.core.interface.message import OnceConversation
+from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
+from dbgpt.app.scene.chat_db.data_loader import DbDataLoader
+from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
router = APIRouter()
CFG = Config()
@@ -56,8 +60,8 @@ async def get_editor_tables(
key=field[0],
type=field[1],
default_value=field[2],
-can_null=field[3],
+can_null=field[3] or "YES",
-comment=field[-1],
+comment=str(field[-1]),
)
)
@@ -145,7 +149,7 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
chat_history_fac = ChatHistory()
-history_mem = chat_history_fac.get_store_instance(sql_edit_context.con_uid)
+history_mem = chat_history_fac.get_store_instance(sql_edit_context.conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)

View File

@@ -1,6 +1,6 @@
from typing import List
-from pydantic import BaseModel, Field, root_validator, validator, Extra
+from dbgpt._private.pydantic import BaseModel
-from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
+from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
class DataNode(BaseModel):

View File

@@ -1,11 +1,10 @@
from fastapi import APIRouter, Body, Request
-from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
+from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
-from pilot.openapi.api_v1.feedback.feed_back_db import (
+from dbgpt.app.openapi.api_v1.feedback.feed_back_db import (
ChatFeedBackDao,
-ChatFeedBackEntity,
)
-from pilot.openapi.api_view_model import Result
+from dbgpt.app.openapi.api_view_model import Result
router = APIRouter()
chat_feed_back = ChatFeedBackDao()

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from pilot.base_modules.meta_data.base_dao import BaseDao from dbgpt.storage.metadata import BaseDao
from pilot.base_modules.meta_data.meta_data import ( from dbgpt.storage.metadata.meta_data import (
Base, Base,
engine, engine,
session, session,
META_DATA_DATABASE, META_DATA_DATABASE,
) )
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody from dbgpt.app.openapi.api_v1.feedback.feed_back_model import FeedBackBody
class ChatFeedBackEntity(Base): class ChatFeedBackEntity(Base):
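The feedback storage now goes through dbgpt.storage.metadata (BaseDao plus the shared declarative Base). The entity below is hypothetical and only illustrates how the imported Column types combine with that Base; the real ChatFeedBackEntity columns are not visible in this hunk.

    from datetime import datetime
    from sqlalchemy import Column, Integer, Text, String, DateTime
    from dbgpt.storage.metadata.meta_data import Base

    class ExampleFeedbackEntity(Base):               # hypothetical table, for illustration only
        __tablename__ = "example_feedback"
        id = Column(Integer, primary_key=True)
        conv_uid = Column(String(128), index=True)   # assumed column, not from this diff
        remark = Column(Text)
        gmt_created = Column(DateTime, default=datetime.now)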

View File

@@ -1,4 +1,4 @@
from pydantic.main import BaseModel from dbgpt._private.pydantic import BaseModel
from typing import Optional from typing import Optional

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field from dbgpt._private.pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any, Optional, Literal, List from typing import TypeVar, Generic, Any, Optional, Literal, List
import uuid import uuid
import time import time
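Pydantic symbols are now imported through dbgpt._private.pydantic rather than from pydantic directly, so the view models depend on one internal import path. At a call site the models read exactly like plain pydantic; the class below is hypothetical and only shows the import style.

    from typing import Optional
    from dbgpt._private.pydantic import BaseModel, Field

    class ChartQuery(BaseModel):                     # hypothetical model, for illustration only
        chart_uid: str = Field(..., description="Unique id of the chart")
        title: Optional[str] = None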

View File

@@ -1,6 +1,6 @@
from fastapi import Request from fastapi import Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from pilot.openapi.api_view_model import Result from dbgpt.app.openapi.api_view_model import Result
async def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError):

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field from dbgpt._private.pydantic import BaseModel, Field
from typing import List, Any from typing import List, Any

View File

@@ -1,8 +1,8 @@
from fastapi import APIRouter, File, UploadFile, Form from fastapi import APIRouter
from pilot.openapi.api_view_model import Result from dbgpt.app.openapi.api_view_model import Result
from pilot.server.prompt.service import PromptManageService from dbgpt.app.prompt.service import PromptManageService
from pilot.server.prompt.request.request import PromptManageRequest from dbgpt.app.prompt.request.request import PromptManageRequest
router = APIRouter() router = APIRouter()

View File

@@ -2,16 +2,16 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from pilot.base_modules.meta_data.base_dao import BaseDao from dbgpt.storage.metadata import BaseDao
from pilot.base_modules.meta_data.meta_data import ( from dbgpt.storage.metadata.meta_data import (
Base, Base,
engine, engine,
session, session,
META_DATA_DATABASE, META_DATA_DATABASE,
) )
from pilot.configs.config import Config from dbgpt._private.config import Config
from pilot.server.prompt.request.request import PromptManageRequest from dbgpt.app.prompt.request.request import PromptManageRequest
CFG = Config() CFG = Config()

View File

@@ -1,8 +1,8 @@
from typing import List from typing import List
from pydantic import BaseModel from dbgpt._private.pydantic import BaseModel
from typing import Optional from typing import Optional
from pydantic import BaseModel from dbgpt._private.pydantic import BaseModel
class PromptManageRequest(BaseModel): class PromptManageRequest(BaseModel):

View File

@@ -1,5 +1,5 @@
from typing import List from typing import List
from pydantic import BaseModel from dbgpt._private.pydantic import BaseModel
class PromptQueryResponse(BaseModel): class PromptQueryResponse(BaseModel):

View File

@@ -1,8 +1,8 @@
from datetime import datetime from datetime import datetime
from pilot.server.prompt.request.request import PromptManageRequest from dbgpt.app.prompt.request.request import PromptManageRequest
from pilot.server.prompt.request.response import PromptQueryResponse from dbgpt.app.prompt.request.response import PromptQueryResponse
from pilot.server.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity from dbgpt.app.prompt.prompt_manage_db import PromptManageDao, PromptManageEntity
prompt_manage_dao = PromptManageDao() prompt_manage_dao = PromptManageDao()

View File

@@ -0,0 +1,3 @@
from dbgpt.app.scene.base_chat import BaseChat
from dbgpt.app.scene.chat_factory import ChatFactory
from dbgpt.app.scene.base import ChatScene
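The three-line file added here is presumably the dbgpt.app.scene package __init__, since later hunks import BaseChat, ChatScene and ChatFactory straight from dbgpt.app.scene instead of reaching into the submodules. A small usage sketch; the no-argument ChatFactory constructor and the get_implementation method name are assumptions, not shown in this diff.

    from dbgpt.app.scene import BaseChat, ChatScene, ChatFactory

    def build_chat(scene: str, **params) -> BaseChat:
        factory = ChatFactory()                               # constructor arguments assumed empty
        return factory.get_implementation(scene, **params)    # method name assumed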

View File

@@ -6,18 +6,18 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Dict from typing import Any, List, Dict
from pilot.configs.config import Config from dbgpt._private.config import Config
from pilot.component import ComponentType from dbgpt.component import ComponentType
from pilot.prompts.prompt_new import PromptTemplate from dbgpt.core.interface.prompt import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation from dbgpt.core.interface.message import OnceConversation
from pilot.utils import get_or_create_event_loop from dbgpt.util import get_or_create_event_loop
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace from dbgpt.util.tracer import root_tracer, trace
from pydantic import Extra from dbgpt._private.pydantic import Extra
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
from pilot.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG from dbgpt.core.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG
from pilot.model.operator.model_operator import ModelOperator, ModelStreamOperator from dbgpt.model.operator.model_operator import ModelOperator, ModelStreamOperator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
headers = {"User-Agent": "dbgpt Client"} headers = {"User-Agent": "dbgpt Client"}
@@ -535,16 +535,16 @@ def _build_model_operator(
Returns: Returns:
BaseOperator: The final operator in the constructed DAG, typically a join node. BaseOperator: The final operator in the constructed DAG, typically a join node.
""" """
from pilot.model.cluster import WorkerManagerFactory from dbgpt.model.cluster import WorkerManagerFactory
from pilot.awel import JoinOperator from dbgpt.core.awel import JoinOperator
from pilot.model.operator.model_operator import ( from dbgpt.model.operator.model_operator import (
ModelCacheBranchOperator, ModelCacheBranchOperator,
CachedModelStreamOperator, CachedModelStreamOperator,
CachedModelOperator, CachedModelOperator,
ModelSaveCacheOperator, ModelSaveCacheOperator,
ModelStreamSaveCacheOperator, ModelStreamSaveCacheOperator,
) )
from pilot.cache import CacheManager from dbgpt.storage.cache import CacheManager
# Fetch worker and cache managers from the system configuration # Fetch worker and cache managers from the system configuration
worker_manager = CFG.SYSTEM_APP.get_component( worker_manager = CFG.SYSTEM_APP.get_component(
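The new imports show that base_chat wires the model call path as an AWEL DAG, and the docstring above notes that the builder returns a join node. The sketch below only illustrates that wiring style using the names imported in this hunk; constructor signatures and the real cached/non-cached branch topology are assumptions, not a copy of _build_model_operator.

    from dbgpt.core.awel import DAG, InputOperator, JoinOperator, SimpleCallDataInputSource

    with DAG("sketch_model_dag") as dag:
        # Entry node fed by the caller's call data (keyword name assumed).
        input_node = InputOperator(input_source=SimpleCallDataInputSource())
        # In the real builder, cached and non-cached model branches both feed the
        # join node; a single upstream keeps this sketch minimal.
        join_node = JoinOperator(combine_function=lambda *outputs: outputs[0])
        input_node >> join_node
    # The builder hands the last operator (the join node) back to the chat code to call.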

View File

@@ -1,17 +1,13 @@
from typing import List, Dict from typing import List, Dict
import logging import logging
from pilot.scene.base_chat import BaseChat from dbgpt.app.scene import BaseChat, ChatScene
from pilot.scene.base import ChatScene from dbgpt._private.config import Config
from pilot.configs.config import Config from dbgpt.agent.commands.command_mange import ApiCall
from pilot.base_modules.agent.commands.command import execute_command from dbgpt.agent import PluginPromptGenerator
from pilot.base_modules.agent.commands.command_mange import ApiCall from dbgpt.component import ComponentType
from pilot.base_modules.agent import PluginPromptGenerator from dbgpt.agent.controller import ModuleAgent
from pilot.common.string_utils import extract_content from dbgpt.util.tracer import root_tracer, trace
from .prompt import prompt
from pilot.component import ComponentType
from pilot.base_modules.agent.controller import ModuleAgent
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()

View File

@@ -1,4 +1,4 @@
from pilot.prompts.example_base import ExampleSelector from dbgpt.core._private.example_base import ExampleSelector
## Two examples are defined by default ## Two examples are defined by default
EXAMPLES = [ EXAMPLES = [

View File

@@ -1,6 +1,5 @@
import json
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
from pilot.out_parser.base import BaseOutputParser, T from dbgpt.core.interface.output_parser import BaseOutputParser
class PluginAction(NamedTuple): class PluginAction(NamedTuple):

View File

@@ -1,11 +1,8 @@
import json from dbgpt.core.interface.prompt import PromptTemplate
from pilot.prompts.prompt_new import PromptTemplate from dbgpt._private.config import Config
from pilot.configs.config import Config from dbgpt.app.scene import ChatScene
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser
from pilot.scene.chat_execution.example import plugin_example
CFG = Config() CFG = Config()
@@ -65,8 +62,6 @@ _PROMPT_SCENE_DEFINE = (
RESPONSE_FORMAT = None RESPONSE_FORMAT = None
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output ### Whether the model service is streaming output
PROMPT_NEED_STREAM_OUT = True PROMPT_NEED_STREAM_OUT = True
@@ -77,9 +72,7 @@ prompt = PromptTemplate(
template_define=_PROMPT_SCENE_DEFINE, template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT, stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser( output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
temperature=1 temperature=1
# example_selector=plugin_example, # example_selector=plugin_example,
) )

View File

@@ -3,17 +3,15 @@ import os
import uuid import uuid
from typing import List, Dict from typing import List, Dict
from pilot.scene.base_chat import BaseChat from dbgpt.app.scene import BaseChat, ChatScene
from pilot.scene.base import ChatScene from dbgpt._private.config import Config
from pilot.configs.config import Config from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import (
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData, ChartData,
ReportData, ReportData,
) )
from pilot.scene.chat_dashboard.prompt import prompt from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from dbgpt.util.executor_utils import blocking_func_to_async
from pilot.utils.executor_utils import blocking_func_to_async from dbgpt.util.tracer import trace
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -57,7 +55,7 @@ class ChatDashboard(BaseChat):
@trace() @trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
try: try:
from pilot.summary.db_summary_client import DBSummaryClient from dbgpt.rag.summary.db_summary_client import DBSummaryClient
except ImportError: except ImportError:
raise ValueError("Could not import DBSummaryClient. ") raise ValueError("Could not import DBSummaryClient. ")

View File

@@ -2,8 +2,8 @@ from typing import List
from decimal import Decimal from decimal import Decimal
import logging import logging
from pilot.configs.config import Config from dbgpt._private.config import Config
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
CFG = Config() CFG = Config()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel from dbgpt._private.pydantic import BaseModel
from typing import List, Any from typing import List, Any

View File

@@ -2,8 +2,8 @@ import json
import logging import logging
from typing import NamedTuple, List from typing import NamedTuple, List
from pilot.out_parser.base import BaseOutputParser, T from dbgpt.core.interface.output_parser import BaseOutputParser
from pilot.scene.base import ChatScene from dbgpt.app.scene import ChatScene
class ChartItem(NamedTuple): class ChartItem(NamedTuple):
@@ -17,8 +17,8 @@ logger = logging.getLogger(__name__)
class ChatDashboardOutputParser(BaseOutputParser): class ChatDashboardOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool): def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(sep=sep, is_stream_out=is_stream_out) super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_prompt_response(self, model_out_text): def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text) clean_str = super().parse_prompt_response(model_out_text)
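Every scene output parser in this commit moves from __init__(self, sep, is_stream_out) to __init__(self, is_stream_out, **kwargs) and forwards the extra keyword arguments to BaseOutputParser. A subclass written against the new contract, with an illustrative parse body:

    from dbgpt.core.interface.output_parser import BaseOutputParser

    class ExampleOutputParser(BaseOutputParser):     # hypothetical parser, for illustration only
        def __init__(self, is_stream_out: bool, **kwargs):
            super().__init__(is_stream_out=is_stream_out, **kwargs)

        def parse_prompt_response(self, model_out_text):
            # Reuse the base cleanup, then apply scene-specific handling as needed.
            return super().parse_prompt_response(model_out_text)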

View File

@@ -1,9 +1,8 @@
import json import json
from pilot.prompts.prompt_new import PromptTemplate from dbgpt.core.interface.prompt import PromptTemplate
from pilot.configs.config import Config from dbgpt._private.config import Config
from pilot.scene.base import ChatScene from dbgpt.app.scene import ChatScene
from pilot.scene.chat_dashboard.out_parser import ChatDashboardOutputParser, ChartItem from dbgpt.app.scene.chat_dashboard.out_parser import ChatDashboardOutputParser
from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@@ -40,7 +39,6 @@ RESPONSE_FORMAT = [
} }
] ]
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = False PROMPT_NEED_STREAM_OUT = False
@@ -51,8 +49,6 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT, stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=ChatDashboardOutputParser( output_parser=ChatDashboardOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)
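Each scene module builds a PromptTemplate and registers it with the global registry, and with the parser change above the output_parser now only receives is_stream_out. A condensed sketch of that flow with placeholder template text; the template_scene and input_variables arguments are assumed, since these hunks only show template_define, template, stream_out, output_parser and temperature.

    from dbgpt._private.config import Config
    from dbgpt.core.interface.output_parser import BaseOutputParser
    from dbgpt.core.interface.prompt import PromptTemplate

    CFG = Config()
    PROMPT_NEED_STREAM_OUT = False

    class _SketchParser(BaseOutputParser):           # minimal parser just for this sketch
        def __init__(self, is_stream_out: bool, **kwargs):
            super().__init__(is_stream_out=is_stream_out, **kwargs)

    prompt = PromptTemplate(
        template_scene="example_scene",              # placeholder, not a real scene name
        input_variables=["input"],                   # assumed argument
        template_define="You are a helpful data assistant.",
        template="Answer the question: {input}",
        stream_out=PROMPT_NEED_STREAM_OUT,
        output_parser=_SketchParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
        temperature=1,
    )
    CFG.prompt_template_registry.register(prompt, is_default=True)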

View File

@@ -1,23 +1,20 @@
import json
import os import os
import asyncio import logging
from typing import List, Any, Dict from typing import Dict
from pilot.scene.base_chat import BaseChat, logger from dbgpt.app.scene import BaseChat, ChatScene
from pilot.scene.base import ChatScene from dbgpt._private.config import Config
from pilot.common.sql_database import Database from dbgpt.agent.commands.command_mange import ApiCall
from pilot.configs.config import Config from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.base_modules.agent.commands.command_mange import ApiCall from dbgpt.app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt from dbgpt.util.path_utils import has_path
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning from dbgpt.util.tracer import root_tracer, trace
from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
logger = logging.getLogger(__name__)
class ChatExcel(BaseChat): class ChatExcel(BaseChat):
"""a Excel analyzer to analyze Excel Data""" """a Excel analyzer to analyze Excel Data"""

View File

@@ -1,8 +1,8 @@
import json import json
import logging import logging
from typing import Dict, NamedTuple, List from typing import NamedTuple
from pilot.out_parser.base import BaseOutputParser, T from dbgpt.core.interface.output_parser import BaseOutputParser
from pilot.configs.config import Config from dbgpt._private.config import Config
CFG = Config() CFG = Config()
@@ -17,8 +17,8 @@ logger = logging.getLogger(__name__)
class ChatExcelOutputParser(BaseOutputParser): class ChatExcelOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool): def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(sep=sep, is_stream_out=is_stream_out) super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_prompt_response(self, model_out_text): def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text) clean_str = super().parse_prompt_response(model_out_text)

View File

@@ -1,11 +1,9 @@
import json from dbgpt.core.interface.prompt import PromptTemplate
from pilot.prompts.prompt_new import PromptTemplate from dbgpt._private.config import Config
from pilot.configs.config import Config from dbgpt.app.scene import ChatScene
from pilot.scene.base import ChatScene from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.out_parser import (
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import (
ChatExcelOutputParser, ChatExcelOutputParser,
) )
from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@@ -53,7 +51,6 @@ _PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
) )
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = True PROMPT_NEED_STREAM_OUT = True
@@ -68,9 +65,7 @@ prompt = PromptTemplate(
template_define=_PROMPT_SCENE_DEFINE, template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT, stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=ChatExcelOutputParser( output_parser=ChatExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
need_historical_messages=True, need_historical_messages=True,
# example_selector=sql_data_example, # example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,

View File

@@ -1,18 +1,11 @@
import json import json
from typing import Any, Dict from typing import Any, Dict
from pilot.scene.base_message import HumanMessage, ViewMessage, AIMessage from dbgpt.core.interface.message import ViewMessage, AIMessage
from pilot.scene.base_chat import BaseChat from dbgpt.app.scene import BaseChat, ChatScene
from pilot.scene.base import ChatScene from dbgpt.util.json_utils import DateTimeEncoder
from pilot.common.sql_database import Database from dbgpt.util.executor_utils import blocking_func_to_async
from pilot.configs.config import Config from dbgpt.util.tracer import trace
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config()
class ExcelLearning(BaseChat): class ExcelLearning(BaseChat):
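ExcelLearning now takes DateTimeEncoder from dbgpt.util.json_utils and blocking_func_to_async from dbgpt.util.executor_utils. The snippet below shows the intended use under two stated assumptions: DateTimeEncoder extends json.JSONEncoder, and blocking_func_to_async takes (executor, func, *args); neither signature is shown in this diff.

    import json
    from datetime import datetime

    from dbgpt.util.executor_utils import blocking_func_to_async
    from dbgpt.util.json_utils import DateTimeEncoder

    # Serialize values that may contain datetimes (encoder assumed to extend json.JSONEncoder).
    payload = json.dumps({"ts": datetime.now(), "rows": 3}, cls=DateTimeEncoder)

    def blocking_read(path: str) -> int:
        # Stand-in for a slow, blocking step such as parsing an Excel file.
        return len(path)

    async def run_off_loop(executor):
        # Push the blocking call off the event loop; argument order is assumed.
        return await blocking_func_to_async(executor, blocking_read, "/tmp/example.xlsx")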

View File

@@ -1,10 +1,7 @@
import json import json
import logging import logging
from typing import Dict, NamedTuple, List from typing import NamedTuple, List
from pilot.out_parser.base import BaseOutputParser, T from dbgpt.core.interface.output_parser import BaseOutputParser
from pilot.configs.config import Config
CFG = Config()
class ExcelResponse(NamedTuple): class ExcelResponse(NamedTuple):
@@ -17,8 +14,8 @@ logger = logging.getLogger(__name__)
class LearningExcelOutputParser(BaseOutputParser): class LearningExcelOutputParser(BaseOutputParser):
def __init__(self, sep: str, is_stream_out: bool): def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(sep=sep, is_stream_out=is_stream_out) super().__init__(is_stream_out=is_stream_out, **kwargs)
self.is_downgraded = False self.is_downgraded = False
def parse_prompt_response(self, model_out_text): def parse_prompt_response(self, model_out_text):

View File

@@ -1,11 +1,10 @@
import json import json
from pilot.prompts.prompt_new import PromptTemplate from dbgpt.core.interface.prompt import PromptTemplate
from pilot.configs.config import Config from dbgpt._private.config import Config
from pilot.scene.base import ChatScene from dbgpt.app.scene import ChatScene
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import ( from dbgpt.app.scene.chat_data.chat_excel.excel_learning.out_parser import (
LearningExcelOutputParser, LearningExcelOutputParser,
) )
from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@@ -66,8 +65,6 @@ PROMPT_SCENE_DEFINE = (
) )
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_STREAM_OUT = False PROMPT_NEED_STREAM_OUT = False
# Temperature is a configuration hyperparameter that controls the randomness of language model output. # Temperature is a configuration hyperparameter that controls the randomness of language model output.
@@ -82,9 +79,7 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT, stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=LearningExcelOutputParser( output_parser=LearningExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
# example_selector=sql_data_example, # example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,
) )

Some files were not shown because too many files have changed in this diff.