Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-30 15:21:02 +00:00)

refactor: Refactor for core SDK (#1092)

This commit is contained in:
parent ba7248adbb
commit 2d905191f8
Makefile (13 changed lines)
@@ -81,6 +81,19 @@ clean: ## Clean up the environment
	find . -type d -name '.pytest_cache' -delete
	find . -type d -name '.coverage' -delete

.PHONY: clean-dist
clean-dist: ## Clean up the distribution
	rm -rf dist/ *.egg-info build/

.PHONY: package
package: clean-dist ## Package the project for distribution
	IS_DEV_MODE=false python setup.py sdist bdist_wheel

.PHONY: upload
upload: package ## Upload the package to PyPI
	# upload to testpypi: twine upload --repository testpypi dist/*
	twine upload dist/*

.PHONY: help
help: ## Display this help screen
	@echo "Available commands:"
@@ -1,12 +1,18 @@
from dbgpt.component import BaseComponent, SystemApp

__ALL__ = ["SystemApp", "BaseComponent"]
"""DB-GPT: Next Generation Data Interaction Solution with LLMs.
"""
from dbgpt import _version  # noqa: E402
from dbgpt.component import BaseComponent, SystemApp  # noqa: F401

_CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
_SERVE_LIBS = ["serve"]
_LIBS = _CORE_LIBS + _SERVE_LIBS


__version__ = _version.version

__ALL__ = ["__version__", "SystemApp", "BaseComponent"]


def __getattr__(name: str):
    # Lazy load
    import importlib
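Note: dbgpt/__init__.py now exposes the heavy sub-libraries (core, rag, model, serve, ...) through a module-level __getattr__ (PEP 562) instead of importing them eagerly, which is what keeps the core SDK import light. A sketch of the pattern the hunk starts to show (illustrative, not the exact dbgpt body; it belongs in a package __init__.py, where __name__ is the package name):

_CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
_SERVE_LIBS = ["serve"]
_LIBS = _CORE_LIBS + _SERVE_LIBS


def __getattr__(name: str):
    # Lazy load: import dbgpt.<name> only when the attribute is first accessed.
    import importlib

    if name in _LIBS:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")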
dbgpt/_version.py (new file, 1 line)

@@ -0,0 +1 @@
version = "0.4.7"
@@ -579,7 +579,7 @@ if __name__ == "__main__":
    from dbgpt.agent.agents.agent import AgentContext
    from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
    from dbgpt.core.interface.llm import ModelMetadata
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(
@@ -9,7 +9,7 @@ from functools import cache
from typing import Dict, List, Tuple

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.conversation import Conversation, get_conv_template
from dbgpt.model.llm.conversation import Conversation, get_conv_template


class BaseChatAdpter:
@@ -21,7 +21,7 @@ class BaseChatAdpter:

    def get_generate_stream_func(self, model_path: str):
        """Return the generate stream handler func"""
        from dbgpt.model.inference import generate_stream
        from dbgpt.model.llm.inference import generate_stream

        return generate_stream
@@ -171,13 +171,13 @@ class BaseChat(ABC):

    async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
        llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
        return await llm_task.call(call_data={"data": request})
        return await llm_task.call(call_data=request)

    async def call_streaming_operator(
        self, request: ModelRequest
    ) -> AsyncIterator[ModelOutput]:
        llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
        async for out in await llm_task.call_stream(call_data={"data": request}):
        async for out in await llm_task.call_stream(call_data=request):
            yield out

    def do_action(self, prompt_response):
@@ -251,11 +251,9 @@ class BaseChat(ABC):
            str_history=self.prompt_template.str_history,
            request_context=req_ctx,
        )
        node_input = {
            "data": ChatComposerInput(
                messages=self.history_messages, prompt_dict=input_values
            )
        }
        node_input = ChatComposerInput(
            messages=self.history_messages, prompt_dict=input_values
        )
        # llm_messages = self.generate_llm_messages()
        model_request: ModelRequest = await node.call(call_data=node_input)
        model_request.context.cache_enable = self.model_cache_enable
@@ -87,7 +87,7 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
        end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
        # Sub dag, use the same dag context in the parent dag
        messages = await end_node.call(
            call_data={"data": input_value}, dag_ctx=self.current_dag_context
            call_data=input_value, dag_ctx=self.current_dag_context
        )
        span_id = self._request_context.span_id
        model_request = ModelRequest.build_request(
@@ -1,3 +1,7 @@
"""Component module for dbgpt.

Manages the lifecycle and registration of components.
"""
from __future__ import annotations

import asyncio
@@ -22,6 +22,7 @@ from .operator.common_operator import (
    JoinOperator,
    MapOperator,
    ReduceStreamOperator,
    TriggerOperator,
)
from .operator.stream_operator import (
    StreamifyAbsOperator,
@@ -50,6 +51,7 @@ __all__ = [
    "BaseOperator",
    "JoinOperator",
    "ReduceStreamOperator",
    "TriggerOperator",
    "MapOperator",
    "BranchOperator",
    "InputOperator",
@@ -150,4 +152,6 @@ def setup_dev_environment(
    for trigger in dag.trigger_nodes:
        trigger_manager.register_trigger(trigger)
    trigger_manager.after_register()
    uvicorn.run(app, host=host, port=port)
    if trigger_manager.keep_running():
        # Should keep running
        uvicorn.run(app, host=host, port=port)
@@ -28,7 +28,7 @@ from ..task.base import OUT, T, TaskOutput

F = TypeVar("F", bound=FunctionType)

CALL_DATA = Union[Dict, Dict[str, Dict]]
CALL_DATA = Union[Dict[str, Any], Any]


class WorkflowRunner(ABC, Generic[T]):
@@ -197,6 +197,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        Returns:
            OUT: The output of the node after execution.
        """
        if call_data:
            call_data = {"data": call_data}
        out_ctx = await self._runner.execute_workflow(
            self, call_data, exist_dag_ctx=dag_ctx
        )
@@ -242,6 +244,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        Returns:
            AsyncIterator[OUT]: An asynchronous iterator over the output stream.
        """
        if call_data:
            call_data = {"data": call_data}
        out_ctx = await self._runner.execute_workflow(
            self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
        )
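Note: the recurring change in this commit is that callers of call() / call_stream() now pass the raw payload and BaseOperator adds the {"data": ...} envelope itself. A standalone sketch of that wrapping (simplified; not the real dbgpt.core.awel runner):

import asyncio
from typing import Any, Dict, Optional


class SketchOperator:
    """Simplified stand-in for BaseOperator, showing only the call_data wrapping."""

    async def call(self, call_data: Optional[Any] = None) -> Dict[str, Any]:
        if call_data:
            # The framework wraps the raw payload exactly once.
            call_data = {"data": call_data}
        return await self._execute_workflow(call_data)

    async def _execute_workflow(
        self, call_data: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        # Stand-in for WorkflowRunner.execute_workflow: just echo the envelope.
        return call_data or {}


# Before this commit: await task.call(call_data={"data": request})
# After this commit:  await task.call(call_data=request)
print(asyncio.run(SketchOperator().call(call_data={"dialect": "mysql"})))

The trigger and example changes later in this diff are the call sites updated to the new convention.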
@@ -28,6 +28,14 @@ EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()


def is_empty_data(data: Any):
    """Check if the data is empty."""
    if isinstance(data, _EMPTY_DATA_TYPE):
        return data in (EMPTY_DATA, SKIP_DATA)
    return False


MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
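Note: EMPTY_DATA, SKIP_DATA and PLACEHOLDER_DATA are module-level sentinel instances, so emptiness is decided against the known sentinels rather than by value comparison with arbitrary payloads. A small standalone sketch of the same pattern (illustrative names, not the dbgpt ones):

from typing import Any


class _Sentinel:
    """Marker type; instances carry no payload and compare by identity."""


EMPTY = _Sentinel()
SKIP = _Sentinel()


def is_empty(data: Any) -> bool:
    # Only sentinel instances can be "empty"; real payloads (including falsy
    # ones such as 0, "" or []) are never treated as empty.
    if isinstance(data, _Sentinel):
        return data in (EMPTY, SKIP)
    return False


assert is_empty(EMPTY) and is_empty(SKIP)
assert not is_empty(0) and not is_empty("")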
@@ -24,7 +24,6 @@ from .base import (
    EMPTY_DATA,
    OUT,
    PLACEHOLDER_DATA,
    SKIP_DATA,
    InputContext,
    InputSource,
    MapFunc,
@@ -37,6 +36,7 @@ from .base import (
    TaskState,
    TransformFunc,
    UnStreamFunc,
    is_empty_data,
)

logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
    @property
    def is_empty(self) -> bool:
        """Return True if the output data is empty."""
        return self._data == EMPTY_DATA or self._data == SKIP_DATA
        return is_empty_data(self._data)

    @property
    def is_none(self) -> bool:
@@ -171,7 +171,7 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
    @property
    def is_empty(self) -> bool:
        """Return True if the output data is empty."""
        return self._data == EMPTY_DATA or self._data == SKIP_DATA
        return is_empty_data(self._data)

    @property
    def is_none(self) -> bool:
@@ -330,7 +330,7 @@ class SimpleCallDataInputSource(BaseInputSource):
        """
        call_data = task_ctx.call_data
        data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
        if data == EMPTY_DATA:
        if is_empty_data(data):
            raise ValueError("No call data for current SimpleCallDataInputSource")
        return data
@@ -1,12 +1,8 @@
"""Http trigger for AWEL."""
from __future__ import annotations

import logging
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast

from starlette.requests import Request

from dbgpt._private.pydantic import BaseModel

from ..dag.base import DAG
@@ -15,9 +11,10 @@ from .base import Trigger

if TYPE_CHECKING:
    from fastapi import APIRouter
    from starlette.requests import Request

RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
    RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
    StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]

logger = logging.getLogger(__name__)
@@ -32,9 +29,9 @@ class HttpTrigger(Trigger):
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional[RequestBody] = None,
        request_body: Optional["RequestBody"] = None,
        streaming_response: bool = False,
        streaming_predict_func: Optional[StreamingPredictFunc] = None,
        streaming_predict_func: Optional["StreamingPredictFunc"] = None,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
@@ -69,6 +66,7 @@ class HttpTrigger(Trigger):
            router (APIRouter): The router to mount the trigger.
        """
        from fastapi import Depends
        from starlette.requests import Request

        methods = [self._methods] if isinstance(self._methods, str) else self._methods

@@ -114,8 +112,10 @@ class HttpTrigger(Trigger):


async def _parse_request_body(
    request: Request, request_body_cls: Optional[RequestBody]
    request: "Request", request_body_cls: Optional["RequestBody"]
):
    from starlette.requests import Request

    if not request_body_cls:
        return None
    if request_body_cls == Request:
@@ -152,7 +152,7 @@ async def _trigger_dag(
        raise ValueError("HttpTrigger just support one leaf node in dag")
    end_node = cast(BaseOperator, leaf_nodes[0])
    if not streaming_response:
        return await end_node.call(call_data={"data": body})
        return await end_node.call(call_data=body)
    else:
        headers = response_headers
        media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        }
        generator = await end_node.call_stream(call_data={"data": body})
        generator = await end_node.call_stream(call_data=body)
        background_tasks = BackgroundTasks()
        background_tasks.add_task(dag._after_dag_end)
        return StreamingResponse(
@@ -24,6 +24,14 @@ class TriggerManager(ABC):
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager."""

    def keep_running(self) -> bool:
        """Whether keep running.

        Returns:
            bool: Whether keep running, True means keep running, False means stop.
        """
        return False


class HttpTriggerManager(TriggerManager):
    """Http trigger manager.
@@ -64,6 +72,8 @@ class HttpTriggerManager(TriggerManager):
        self._trigger_map[trigger_id] = trigger

    def _init_app(self, system_app: SystemApp):
        if not self.keep_running():
            return
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
@@ -72,6 +82,14 @@ class HttpTriggerManager(TriggerManager):
            raise RuntimeError("System app not initialized")
        app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])

    def keep_running(self) -> bool:
        """Whether keep running.

        Returns:
            bool: Whether keep running, True means keep running, False means stop.
        """
        return len(self._trigger_map) > 0


class DefaultTriggerManager(TriggerManager, BaseComponent):
    """Default trigger manager for AWEL.
@@ -105,3 +123,11 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
        """After register, init the trigger manager."""
        if self.system_app:
            self.http_trigger._init_app(self.system_app)

    def keep_running(self) -> bool:
        """Whether keep running.

        Returns:
            bool: Whether keep running, True means keep running, False means stop.
        """
        return self.http_trigger.keep_running()
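Note: together with the setup_dev_environment change earlier in the diff, the web server now starts only when at least one HTTP trigger was registered, which keep_running() reports. A standalone sketch of that flow (illustrative names, not dbgpt APIs):

from typing import Any, List


class SketchTriggerManager:
    def __init__(self) -> None:
        self._triggers: List[Any] = []

    def register_trigger(self, trigger: Any) -> None:
        self._triggers.append(trigger)

    def keep_running(self) -> bool:
        # True only if something was actually mounted on the HTTP router.
        return len(self._triggers) > 0


def setup_dev_environment_sketch(manager: SketchTriggerManager, dags: List[Any]) -> None:
    for dag in dags:
        for trigger in getattr(dag, "trigger_nodes", []):
            manager.register_trigger(trigger)
    if manager.keep_running():
        # Only now is a blocking web server worth starting (uvicorn.run(...) in dbgpt).
        print("would start uvicorn here")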
@@ -70,7 +70,7 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
        end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
        # Sub dag, use the same dag context in the parent dag
        return await end_node.call(
            call_data={"data": input_value}, dag_ctx=self.current_dag_context
            call_data=input_value, dag_ctx=self.current_dag_context
        )

    def _build_composer_dag(self) -> DAG:
@@ -150,7 +150,7 @@ class PromptBuilderOperator(
                )
            )

            single_input = {"data": {"dialect": "mysql"}}
            single_input = {"dialect": "mysql"}
            single_expected_messages = [
                ModelMessage(
                    content="Please write a mysql SQL count the length of a field",
@@ -1,9 +1,12 @@
from dbgpt.model.cluster.client import DefaultLLMClient
try:
    from dbgpt.model.cluster.client import DefaultLLMClient
except ImportError as exc:
    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
    DefaultLLMClient = None

# from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

__ALL__ = [
    "DefaultLLMClient",
    "OpenAILLMClient",
]
_exports = []
if DefaultLLMClient:
    _exports.append("DefaultLLMClient")

__ALL__ = _exports
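Note: dbgpt.model.cluster pulls in heavy serving dependencies, so the package now degrades gracefully when they are absent and only advertises the names that actually imported. A generic sketch of the pattern (placeholder module name, not a real dbgpt module):

# Try the heavy import, fall back to None, and only export what is available.
try:
    from some_heavy_module import HeavyClient  # placeholder; not a real module
except ImportError:
    HeavyClient = None  # dependency not installed; feature disabled

__all__ = []
if HeavyClient is not None:
    __all__.append("HeavyClient")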
@@ -137,7 +137,7 @@ class ModelLoader:
def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
    import torch

    from dbgpt.model.compression import compress_module
    from dbgpt.model.llm.compression import compress_module

    device = model_params.device
    max_memory = None
@@ -19,7 +19,7 @@ from dbgpt.configs.model_config import get_device
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.base import ModelType
from dbgpt.model.conversation import Conversation
from dbgpt.model.llm.conversation import Conversation
from dbgpt.model.parameter import (
    LlamaCppModelParameters,
    ModelParameters,
@@ -1,13 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, TypedDict
from typing import Dict, List, Optional

from dbgpt.util.model_utils import GPUInfo
from dbgpt.util.parameter_utils import ParameterDescription
@@ -12,9 +12,9 @@ from dbgpt.core import (
    ModelOutput,
)
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Type

from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
from dbgpt.model.adapter.loader import _get_model_real_path
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.loader import _get_model_real_path
from dbgpt.model.parameter import (
    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    BaseEmbeddingModelParameters,
@@ -8,7 +8,7 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List
from typing import Awaitable, Callable, Iterator

from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
@@ -16,12 +16,7 @@ from fastapi.responses import StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import (
    ModelInstance,
    WorkerApplyOutput,
    WorkerApplyType,
    WorkerSupportedModel,
)
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.manager_base import (
    WorkerManager,
@@ -30,8 +25,8 @@ from dbgpt.model.cluster.manager_base import (
)
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.llm_utils import list_supported_models
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
from dbgpt.model.utils.llm_utils import list_supported_models
from dbgpt.util.parameter_utils import (
    EnvArgumentParser,
    ParameterDescription,
@@ -18,7 +18,7 @@ from transformers.generation.logits_process import (
    TopPLogitsWarper,
)

from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
from dbgpt.model.utils.llm_utils import is_partial_stop, is_sentence_complete


def prepare_logits_processor(
@@ -1,9 +1,9 @@
from dbgpt.model.operator.llm_operator import (
from dbgpt.model.operator.llm_operator import (  # noqa: F401
    LLMOperator,
    MixinLLMOperator,
    StreamingLLMOperator,
)
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator  # noqa: F401

__ALL__ = [
    "MixinLLMOperator",
@@ -6,7 +6,6 @@ from dbgpt.component import ComponentType
from dbgpt.core import LLMClient
from dbgpt.core.awel import BaseOperator
from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
from dbgpt.model.cluster import WorkerManagerFactory

logger = logging.getLogger(__name__)

@@ -19,31 +18,30 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):

    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
        super().__init__(default_client)
        self._default_llm_client = default_client

    @property
    def llm_client(self) -> LLMClient:
        if not self._llm_client:
            worker_manager_factory: WorkerManagerFactory = (
                self.system_app.get_component(
                    ComponentType.WORKER_MANAGER_FACTORY,
                    WorkerManagerFactory,
                    default_component=None,
                )
            )
            if worker_manager_factory:
            try:
                from dbgpt.model.cluster import WorkerManagerFactory
                from dbgpt.model.cluster.client import DefaultLLMClient

                self._llm_client = DefaultLLMClient(worker_manager_factory.create())
            else:
                if self._default_llm_client is None:
                    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

                    self._default_llm_client = OpenAILLMClient()
                logger.info(
                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
                worker_manager_factory: WorkerManagerFactory = (
                    self.system_app.get_component(
                        ComponentType.WORKER_MANAGER_FACTORY,
                        WorkerManagerFactory,
                        default_component=None,
                    )
                )
                self._llm_client = self._default_llm_client
                if worker_manager_factory:
                    self._llm_client = DefaultLLMClient(worker_manager_factory.create())
            except Exception as e:
                logger.warning(f"Load worker manager failed: {e}.")
            if not self._llm_client:
                from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

                logger.info("Can't find worker manager factory, use OpenAILLMClient.")
                self._llm_client = OpenAILLMClient()
        return self._llm_client
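Note: the rewritten llm_client property resolves a client lazily, roughly in this order: the explicitly passed default client, then a cluster DefaultLLMClient when a worker manager factory is available, then OpenAILLMClient as a last resort. A standalone sketch of that resolution order (illustrative types, not dbgpt's classes):

from typing import Callable, Optional


class SketchLLMOperator:
    """Illustrative resolution order only; not dbgpt's MixinLLMOperator."""

    def __init__(
        self,
        default_client: Optional[object] = None,
        cluster_factory: Optional[Callable[[], object]] = None,
    ) -> None:
        self._default_client = default_client
        self._cluster_factory = cluster_factory
        self._client: Optional[object] = None

    @property
    def client(self) -> object:
        if self._client is None:
            # 1. Start from the explicitly configured default client, if any.
            self._client = self._default_client
            try:
                # 2. Prefer a cluster-backed client when a worker manager exists.
                if self._cluster_factory is not None:
                    self._client = self._cluster_factory()
            except Exception as exc:  # heavy cluster deps may be missing
                print(f"Load worker manager failed: {exc}.")
            if self._client is None:
                # 3. Last resort: a proxy client (OpenAILLMClient in dbgpt).
                self._client = object()
        return self._client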
@@ -6,11 +6,8 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from dbgpt.model.conversation import conv_templates
from dbgpt.util.parameter_utils import BaseParameters

suported_prompt_templates = ",".join(conv_templates.keys())


class WorkerType(str, Enum):
    LLM = "llm"
@@ -299,7 +296,8 @@ class ModelParameters(BaseModelParameters):
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
            "help": f"Prompt template. If None, the prompt template is automatically "
            f"determined from model path"
        },
    )
    max_context_size: Optional[int] = field(
@@ -450,7 +448,8 @@ class ProxyModelParameters(BaseModelParameters):
    proxyllm_backend: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
            "help": "The model name actually pass to current proxy server url, such "
            "as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
        },
    )
    model_type: Optional[str] = field(
@@ -463,13 +462,15 @@ class ProxyModelParameters(BaseModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically determined"
            "help": "Device to run model. If None, the device is automatically "
            "determined"
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
            "help": f"Prompt template. If None, the prompt template is automatically "
            f"determined from model path"
        },
    )
    max_context_size: Optional[int] = field(
@@ -478,7 +479,8 @@ class ProxyModelParameters(BaseModelParameters):
    llm_client_class: Optional[str] = field(
        default=None,
        metadata={
            "help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel"
            "help": "The class name of llm client, such as "
            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
        },
    )
@@ -37,8 +37,8 @@ def list_supported_models():
def _list_supported_models(
    worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
    from dbgpt.model.adapter.loader import _get_model_real_path
    from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
    from dbgpt.model.loader import _get_model_real_path

    ret = []
    for model_name, model_path in model_config.items():
@@ -67,7 +67,7 @@ class AwelLayoutChatManager(ManagerAgent):
                message=start_message, sender=self, reviewer=reviewer
            )
            final_generate_context: AgentGenerateContext = await last_node.call(
                call_data={"data": start_message_context}
                call_data=start_message_context
            )
            last_message = final_generate_context.rely_messages[-1]
@@ -28,7 +28,7 @@ from dbgpt.core.interface.llm import ModelMetadata
from dbgpt.serve.agent.team.plan.team_auto_plan import AutoPlanChatManager

if __name__ == "__main__":
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -30,7 +30,7 @@ parent_dir = os.path.dirname(current_dir)
test_plugin_dir = os.path.join(parent_dir, "test_files")

if __name__ == "__main__":
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -24,7 +24,7 @@ from dbgpt.agent.memory.gpts_memory import GptsMemory
from dbgpt.core.interface.llm import ModelMetadata

if __name__ == "__main__":
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -27,7 +27,7 @@ from dbgpt.core.interface.llm import ModelMetadata


def summary_example_with_success():
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(
@@ -24,7 +24,7 @@ from dbgpt.agent.memory.gpts_memory import GptsMemory
from dbgpt.core.interface.llm import ModelMetadata

if __name__ == "__main__":
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -25,7 +25,7 @@ from dbgpt.core.interface.llm import ModelMetadata


def summary_example_with_success():
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(conv_id="summarize", llm_provider=llm_client)
@@ -76,7 +76,7 @@ def summary_example_with_success():


def summary_example_with_faliure():
    from dbgpt.model import OpenAILLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    llm_client = OpenAILLMClient()
    context: AgentContext = AgentContext(conv_id="summarize", llm_provider=llm_client)
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field

from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator
@@ -32,7 +32,7 @@ from typing import Dict

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.operator.rewrite import QueryRewriteOperator


@@ -30,7 +30,7 @@ from typing import Dict

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.operator.knowledge import KnowledgeOperator
from dbgpt.rag.operator.summary import SummaryAssemblerOperator
@@ -1,7 +1,7 @@
import asyncio
import os

from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.retriever.rewrite import QueryRewrite

"""Query rewrite example.
@@ -1,6 +1,6 @@
import asyncio

from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
@@ -7,7 +7,7 @@ from dbgpt.core.operator import (
    PromptBuilderOperator,
    RequestBuilderOperator,
)
from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient

with DAG("simple_sdk_llm_example_dag") as dag:
    prompt_task = PromptBuilderOperator(
@@ -20,8 +20,6 @@ with DAG("simple_sdk_llm_example_dag") as dag:

if __name__ == "__main__":
    output = asyncio.run(
        out_parse_task.call(
            call_data={"data": {"dialect": "mysql", "table_name": "user"}}
        )
        out_parse_task.call(call_data={"dialect": "mysql", "table_name": "user"})
    )
    print(f"output: \n\n{output}")
@@ -17,7 +17,7 @@ from dbgpt.core.operator import (
)
from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator


@@ -144,12 +144,10 @@ with DAG("simple_sdk_llm_sql_example") as dag:

if __name__ == "__main__":
    input_data = {
        "data": {
            "db_name": "test_db",
            "dialect": "sqlite",
            "top_k": 5,
            "user_input": "What is the name and age of the user with age less than 18",
        }
        "db_name": "test_db",
        "dialect": "sqlite",
        "top_k": 5,
        "user_input": "What is the name and age of the user with age less than 18",
    }
    output = asyncio.run(sql_result_task.call(call_data=input_data))
    print(f"\nthoughts: {output.get('thoughts')}\n")
setup.py (91 changed lines)
@@ -14,6 +14,11 @@ import functools
with open("README.md", mode="r", encoding="utf-8") as fh:
    long_description = fh.read()

IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
# If you modify the version, please modify the version in the following files:
# dbgpt/_version.py
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")

BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = (
    os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true"
@@ -352,31 +357,41 @@ def llama_cpp_python_cuda_requires():

def core_requires():
    """
    pip install db-gpt or pip install "db-gpt[core]"
    pip install dbgpt or pip install "dbgpt[core]"
    """
    setup_spec.extras["core"] = [
        "aiohttp==3.8.4",
        "chardet==5.1.0",
        "importlib-resources==5.12.0",
        "psutil==5.9.4",
        "python-dotenv==1.0.0",
        "colorama==0.4.6",
        "prettytable",
        "cachetools",
        "pydantic<2,>=1",
    ]
    # Just use by DB-GPT internal, we should find the smallest dependency set for run we core unit test.
    # Simple command line dependencies
    setup_spec.extras["cli"] = setup_spec.extras["core"] + [
        "prettytable",
        "click",
        "psutil==5.9.4",
        "colorama==0.4.6",
    ]
    # Just use by DB-GPT internal, we should find the smallest dependency set for run
    # we core unit test.
    # The dependency "framework" is too large for now.
    setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [
    setup_spec.extras["simple_framework"] = setup_spec.extras["cli"] + [
        "pydantic<2,>=1",
        "httpx",
        "jinja2",
        "fastapi==0.98.0",
        "uvicorn",
        "shortuuid",
        # change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2
        # change from fixed version 2.0.22 to variable version, because other
        # dependencies are >=1.4, such as pydoris is <2
        "SQLAlchemy>=1.4,<3",
        # for cache
        "msgpack",
        # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
        # for cache
        # TODO: pympler has not been updated for a long time and needs to
        # find a new toolkit.
        "pympler",
        "sqlparse==0.4.4",
        "duckdb==0.8.1",
|
||||
|
||||
def knowledge_requires():
|
||||
"""
|
||||
pip install "db-gpt[knowledge]"
|
||||
pip install "dbgpt[knowledge]"
|
||||
"""
|
||||
setup_spec.extras["knowledge"] = [
|
||||
"spacy==3.5.3",
|
||||
@ -435,7 +450,7 @@ def knowledge_requires():
|
||||
|
||||
def llama_cpp_requires():
|
||||
"""
|
||||
pip install "db-gpt[llama_cpp]"
|
||||
pip install "dbgpt[llama_cpp]"
|
||||
"""
|
||||
setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
|
||||
llama_cpp_python_cuda_requires()
|
||||
@ -523,7 +538,7 @@ def quantization_requires():
|
||||
|
||||
def all_vector_store_requires():
|
||||
"""
|
||||
pip install "db-gpt[vstore]"
|
||||
pip install "dbgpt[vstore]"
|
||||
"""
|
||||
setup_spec.extras["vstore"] = [
|
||||
"grpcio==1.47.5", # maybe delete it
|
||||
@ -534,7 +549,7 @@ def all_vector_store_requires():
|
||||
|
||||
def all_datasource_requires():
|
||||
"""
|
||||
pip install "db-gpt[datasource]"
|
||||
pip install "dbgpt[datasource]"
|
||||
"""
|
||||
|
||||
setup_spec.extras["datasource"] = [
|
||||
@ -552,7 +567,7 @@ def all_datasource_requires():
|
||||
|
||||
def openai_requires():
|
||||
"""
|
||||
pip install "db-gpt[openai]"
|
||||
pip install "dbgpt[openai]"
|
||||
"""
|
||||
setup_spec.extras["openai"] = ["tiktoken"]
|
||||
if BUILD_VERSION_OPENAI:
|
||||
@ -567,28 +582,28 @@ def openai_requires():
|
||||
|
||||
def gpt4all_requires():
|
||||
"""
|
||||
pip install "db-gpt[gpt4all]"
|
||||
pip install "dbgpt[gpt4all]"
|
||||
"""
|
||||
setup_spec.extras["gpt4all"] = ["gpt4all"]
|
||||
|
||||
|
||||
def vllm_requires():
|
||||
"""
|
||||
pip install "db-gpt[vllm]"
|
||||
pip install "dbgpt[vllm]"
|
||||
"""
|
||||
setup_spec.extras["vllm"] = ["vllm"]
|
||||
|
||||
|
||||
def cache_requires():
|
||||
"""
|
||||
pip install "db-gpt[cache]"
|
||||
pip install "dbgpt[cache]"
|
||||
"""
|
||||
setup_spec.extras["cache"] = ["rocksdict"]
|
||||
|
||||
|
||||
def default_requires():
|
||||
"""
|
||||
pip install "db-gpt[default]"
|
||||
pip install "dbgpt[default]"
|
||||
"""
|
||||
setup_spec.extras["default"] = [
|
||||
# "tokenizers==0.13.3",
|
||||
@ -637,14 +652,46 @@ default_requires()
|
||||
all_requires()
|
||||
init_install_requires()
|
||||
|
||||
# Packages to exclude when IS_DEV_MODE is False
|
||||
excluded_packages = ["tests", "*.tests", "*.tests.*", "examples"]
|
||||
|
||||
if IS_DEV_MODE:
|
||||
packages = find_packages(exclude=excluded_packages)
|
||||
else:
|
||||
packages = find_packages(
|
||||
exclude=excluded_packages,
|
||||
include=[
|
||||
"dbgpt",
|
||||
"dbgpt._private",
|
||||
"dbgpt._private.*",
|
||||
"dbgpt.cli",
|
||||
"dbgpt.cli.*",
|
||||
"dbgpt.configs",
|
||||
"dbgpt.configs.*",
|
||||
"dbgpt.core",
|
||||
"dbgpt.core.*",
|
||||
"dbgpt.util",
|
||||
"dbgpt.util.*",
|
||||
"dbgpt.model",
|
||||
"dbgpt.model.proxy",
|
||||
"dbgpt.model.proxy.*",
|
||||
"dbgpt.model.operator",
|
||||
"dbgpt.model.operator.*",
|
||||
"dbgpt.model.utils",
|
||||
"dbgpt.model.utils.*",
|
||||
],
|
||||
)
|
||||
|
||||
setuptools.setup(
|
||||
name="db-gpt",
|
||||
packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")),
|
||||
version="0.4.5",
|
||||
name="dbgpt",
|
||||
packages=packages,
|
||||
version=DB_GPT_VERSION,
|
||||
author="csunny",
|
||||
author_email="cfqcsunny@gmail.com",
|
||||
description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment."
|
||||
" With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.",
|
||||
description="DB-GPT is an experimental open-source project that uses localized GPT "
|
||||
"large models to interact with your data and environment."
|
||||
" With this solution, you can be assured that there is no risk of data leakage, "
|
||||
"and your data is 100% private and secure.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
install_requires=setup_spec.install_requires,
|
||||
|
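Note: in development mode the whole source tree is packaged; for a PyPI build (make package sets IS_DEV_MODE=false) only the SDK-facing packages are included and the version comes from DB_GPT_VERSION / dbgpt/_version.py. A condensed sketch of the selection logic (same idea, trimmed include list):

import os

from setuptools import find_packages

IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
excluded = ["tests", "*.tests", "*.tests.*", "examples"]

if IS_DEV_MODE:
    # Local development: package everything except tests and examples.
    packages = find_packages(exclude=excluded)
else:
    # Distribution build: only the core SDK packages (full list in the diff above).
    packages = find_packages(
        exclude=excluded,
        include=["dbgpt", "dbgpt.core", "dbgpt.core.*", "dbgpt.util", "dbgpt.util.*"],
    )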