refactor: Refactor for core SDK (#1092)

Fangyin Cheng 2024-01-21 09:57:57 +08:00 committed by GitHub
parent ba7248adbb
commit 2d905191f8
45 changed files with 236 additions and 133 deletions


@@ -81,6 +81,19 @@ clean: ## Clean up the environment
 	find . -type d -name '.pytest_cache' -delete
 	find . -type d -name '.coverage' -delete

+.PHONY: clean-dist
+clean-dist: ## Clean up the distribution
+	rm -rf dist/ *.egg-info build/
+
+.PHONY: package
+package: clean-dist ## Package the project for distribution
+	IS_DEV_MODE=false python setup.py sdist bdist_wheel
+
+.PHONY: upload
+upload: package ## Upload the package to PyPI
+	# upload to testpypi: twine upload --repository testpypi dist/*
+	twine upload dist/*
+
 .PHONY: help
 help: ## Display this help screen
 	@echo "Available commands:"
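Taken together, these targets give a two-step release flow: make package builds with IS_DEV_MODE=false so only the trimmed SDK package set from setup.py (see the last file in this diff) is bundled, and make upload pushes the artifacts, with the commented twine line available for a TestPyPI dry run first.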


@@ -1,12 +1,18 @@
-from dbgpt.component import BaseComponent, SystemApp
-
-__ALL__ = ["SystemApp", "BaseComponent"]
+"""DB-GPT: Next Generation Data Interaction Solution with LLMs.
+"""
+from dbgpt import _version  # noqa: E402
+from dbgpt.component import BaseComponent, SystemApp  # noqa: F401
+
+_CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
+_SERVE_LIBS = ["serve"]
+_LIBS = _CORE_LIBS + _SERVE_LIBS
+
+__version__ = _version.version
+
+__ALL__ = ["__version__", "SystemApp", "BaseComponent"]
+
+
+def __getattr__(name: str):
+    # Lazy load
+    import importlib
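The hunk is truncated after import importlib; a minimal sketch of the PEP 562 module-level __getattr__ it introduces (body assumed, not shown in this diff):

    def __getattr__(name: str):
        # Lazy load: heavy submodules such as dbgpt.rag or dbgpt.model are
        # imported only when first accessed as attributes of dbgpt.
        import importlib

        if name in _LIBS:
            return importlib.import_module("." + name, __name__)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")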

dbgpt/_version.py Normal file

@@ -0,0 +1 @@
+version = "0.4.7"


@@ -579,7 +579,7 @@ if __name__ == "__main__":
     from dbgpt.agent.agents.agent import AgentContext
     from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
     from dbgpt.core.interface.llm import ModelMetadata
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(


@@ -9,7 +9,7 @@ from functools import cache
 from typing import Dict, List, Tuple

 from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-from dbgpt.model.conversation import Conversation, get_conv_template
+from dbgpt.model.llm.conversation import Conversation, get_conv_template


 class BaseChatAdpter:
@@ -21,7 +21,7 @@ class BaseChatAdpter:
     def get_generate_stream_func(self, model_path: str):
         """Return the generate stream handler func"""
-        from dbgpt.model.inference import generate_stream
+        from dbgpt.model.llm.inference import generate_stream

         return generate_stream


@@ -171,13 +171,13 @@ class BaseChat(ABC):
     async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
         llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
-        return await llm_task.call(call_data={"data": request})
+        return await llm_task.call(call_data=request)

     async def call_streaming_operator(
         self, request: ModelRequest
     ) -> AsyncIterator[ModelOutput]:
         llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
-        async for out in await llm_task.call_stream(call_data={"data": request}):
+        async for out in await llm_task.call_stream(call_data=request):
             yield out

     def do_action(self, prompt_response):
@@ -251,11 +251,9 @@ class BaseChat(ABC):
             str_history=self.prompt_template.str_history,
             request_context=req_ctx,
         )
-        node_input = {
-            "data": ChatComposerInput(
-                messages=self.history_messages, prompt_dict=input_values
-            )
-        }
+        node_input = ChatComposerInput(
+            messages=self.history_messages, prompt_dict=input_values
+        )
         # llm_messages = self.generate_llm_messages()
         model_request: ModelRequest = await node.call(call_data=node_input)
         model_request.context.cache_enable = self.model_cache_enable


@@ -87,7 +87,7 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
         # Sub dag, use the same dag context in the parent dag
         messages = await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
         )
         span_id = self._request_context.span_id
         model_request = ModelRequest.build_request(


@@ -1,3 +1,7 @@
+"""Component module for dbgpt.
+
+Manages the lifecycle and registration of components.
+"""
 from __future__ import annotations

 import asyncio


@@ -22,6 +22,7 @@ from .operator.common_operator import (
     JoinOperator,
     MapOperator,
     ReduceStreamOperator,
+    TriggerOperator,
 )
 from .operator.stream_operator import (
     StreamifyAbsOperator,
@@ -50,6 +51,7 @@ __all__ = [
     "BaseOperator",
     "JoinOperator",
     "ReduceStreamOperator",
+    "TriggerOperator",
     "MapOperator",
     "BranchOperator",
     "InputOperator",
@@ -150,4 +152,6 @@ def setup_dev_environment(
     for trigger in dag.trigger_nodes:
         trigger_manager.register_trigger(trigger)
     trigger_manager.after_register()
-    uvicorn.run(app, host=host, port=port)
+    if trigger_manager.keep_running():
+        # Should keep running
+        uvicorn.run(app, host=host, port=port)
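With keep_running in place, setup_dev_environment only blocks in uvicorn when some trigger actually needs a server, so trigger-less DAGs can be driven directly. A hedged sketch (assuming setup_dev_environment accepts a list of DAGs, as the loop over dag.trigger_nodes suggests):

    import asyncio

    from dbgpt.core.awel import DAG, MapOperator, setup_dev_environment

    with DAG("no_trigger_dag") as dag:
        task = MapOperator(map_function=lambda x: x * 2)

    # No HttpTrigger registered: keep_running() is False, so this call
    # returns instead of blocking in uvicorn.run(...).
    setup_dev_environment([dag])
    print(asyncio.run(task.call(call_data=21)))  # 42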


@@ -28,7 +28,7 @@ from ..task.base import OUT, T, TaskOutput

 F = TypeVar("F", bound=FunctionType)

-CALL_DATA = Union[Dict, Dict[str, Dict]]
+CALL_DATA = Union[Dict[str, Any], Any]


 class WorkflowRunner(ABC, Generic[T]):
@@ -197,6 +197,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             OUT: The output of the node after execution.
         """
+        if call_data:
+            call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, exist_dag_ctx=dag_ctx
         )
@@ -242,6 +244,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
+        if call_data:
+            call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
         )
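Because call and call_stream now apply the {"data": ...} envelope themselves, every call site in this commit drops the manual wrapping. A minimal sketch of the new convention (operator names as exported by dbgpt.core.awel in this diff):

    import asyncio

    from dbgpt.core.awel import DAG, MapOperator

    with DAG("call_data_example_dag"):
        task = MapOperator(map_function=lambda x: x + 1)

    # Before this commit: task.call(call_data={"data": 41})
    # After: pass the payload directly; the operator adds the envelope.
    print(asyncio.run(task.call(call_data=41)))  # 42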


@@ -28,6 +28,14 @@ EMPTY_DATA = _EMPTY_DATA_TYPE()
 SKIP_DATA = _EMPTY_DATA_TYPE()
 PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()

+
+def is_empty_data(data: Any):
+    """Check if the data is empty."""
+    if isinstance(data, _EMPTY_DATA_TYPE):
+        return data in (EMPTY_DATA, SKIP_DATA)
+    return False
+
+
 MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
 ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
 StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
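is_empty_data centralizes the scattered == comparisons against the sentinels (see task_impl.py below); the isinstance guard matters because task payloads can override __eq__ and previously could masquerade as EMPTY_DATA. A self-contained sketch of the pattern, not the dbgpt module itself:

    from typing import Any


    class _EmptyDataType:
        """Unique sentinel type; instances compare by identity."""


    EMPTY_DATA = _EmptyDataType()
    SKIP_DATA = _EmptyDataType()


    def is_empty_data(data: Any) -> bool:
        # A user object with a permissive __eq__ no longer matches the sentinels.
        return isinstance(data, _EmptyDataType) and data in (EMPTY_DATA, SKIP_DATA)


    assert is_empty_data(SKIP_DATA) and not is_empty_data({})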


@@ -24,7 +24,6 @@ from .base import (
     EMPTY_DATA,
     OUT,
     PLACEHOLDER_DATA,
-    SKIP_DATA,
     InputContext,
     InputSource,
     MapFunc,
@@ -37,6 +36,7 @@ from .base import (
     TaskState,
     TransformFunc,
     UnStreamFunc,
+    is_empty_data,
 )

 logger = logging.getLogger(__name__)

@@ -99,7 +99,7 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
     @property
     def is_empty(self) -> bool:
         """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)

     @property
     def is_none(self) -> bool:
@@ -171,7 +171,7 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
     @property
     def is_empty(self) -> bool:
         """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)

     @property
     def is_none(self) -> bool:
@@ -330,7 +330,7 @@ class SimpleCallDataInputSource(BaseInputSource):
         """
         call_data = task_ctx.call_data
         data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
-        if data == EMPTY_DATA:
+        if is_empty_data(data):
             raise ValueError("No call data for current SimpleCallDataInputSource")
         return data


@@ -1,12 +1,8 @@
 """Http trigger for AWEL."""
 from __future__ import annotations

 import logging
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast

-from starlette.requests import Request
-
 from dbgpt._private.pydantic import BaseModel

 from ..dag.base import DAG
@@ -15,9 +11,10 @@ from .base import Trigger

 if TYPE_CHECKING:
     from fastapi import APIRouter
+    from starlette.requests import Request

-RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
-StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
+    RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
+    StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]

 logger = logging.getLogger(__name__)

@@ -32,9 +29,9 @@ class HttpTrigger(Trigger):
         self,
         endpoint: str,
         methods: Optional[Union[str, List[str]]] = "GET",
-        request_body: Optional[RequestBody] = None,
+        request_body: Optional["RequestBody"] = None,
         streaming_response: bool = False,
-        streaming_predict_func: Optional[StreamingPredictFunc] = None,
+        streaming_predict_func: Optional["StreamingPredictFunc"] = None,
         response_model: Optional[Type] = None,
         response_headers: Optional[Dict[str, str]] = None,
         response_media_type: Optional[str] = None,
@@ -69,6 +66,7 @@ class HttpTrigger(Trigger):
             router (APIRouter): The router to mount the trigger.
         """
         from fastapi import Depends
+        from starlette.requests import Request

         methods = [self._methods] if isinstance(self._methods, str) else self._methods
@@ -114,8 +112,10 @@ class HttpTrigger(Trigger):
 async def _parse_request_body(
-    request: Request, request_body_cls: Optional[RequestBody]
+    request: "Request", request_body_cls: Optional["RequestBody"]
 ):
+    from starlette.requests import Request
+
     if not request_body_cls:
         return None
     if request_body_cls == Request:
@@ -152,7 +152,7 @@ async def _trigger_dag(
         raise ValueError("HttpTrigger just support one leaf node in dag")
     end_node = cast(BaseOperator, leaf_nodes[0])
     if not streaming_response:
-        return await end_node.call(call_data={"data": body})
+        return await end_node.call(call_data=body)
     else:
         headers = response_headers
         media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
                 "Connection": "keep-alive",
                 "Transfer-Encoding": "chunked",
             }
-        generator = await end_node.call_stream(call_data={"data": body})
+        generator = await end_node.call_stream(call_data=body)
         background_tasks = BackgroundTasks()
         background_tasks.add_task(dag._after_dag_end)
         return StreamingResponse(
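Importing starlette only under TYPE_CHECKING (plus locally inside the functions that need it) keeps dbgpt.core importable without the web extras installed, while the quoted "Request"/"RequestBody" annotations keep the signatures resolvable. A self-contained sketch of the pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by type checkers only; no runtime starlette dependency.
        from starlette.requests import Request


    def handle(request: "Request") -> str:
        # Deferred import: pay the dependency cost only when the code runs.
        from starlette.requests import Request

        assert isinstance(request, Request)
        return request.url.path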


@@ -24,6 +24,14 @@ class TriggerManager(ABC):
     def register_trigger(self, trigger: Any) -> None:
         """Register a trigger to current manager."""

+    def keep_running(self) -> bool:
+        """Whether to keep running.
+
+        Returns:
+            bool: True means keep running, False means stop.
+        """
+        return False
+

 class HttpTriggerManager(TriggerManager):
     """Http trigger manager.
@@ -64,6 +72,8 @@ class HttpTriggerManager(TriggerManager):
         self._trigger_map[trigger_id] = trigger

     def _init_app(self, system_app: SystemApp):
+        if not self.keep_running():
+            return
         logger.info(
             f"Include router {self._router} to prefix path {self._router_prefix}"
         )
@@ -72,6 +82,14 @@ class HttpTriggerManager(TriggerManager):
             raise RuntimeError("System app not initialized")
         app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])

+    def keep_running(self) -> bool:
+        """Whether to keep running.
+
+        Returns:
+            bool: True means keep running, False means stop.
+        """
+        return len(self._trigger_map) > 0
+

 class DefaultTriggerManager(TriggerManager, BaseComponent):
     """Default trigger manager for AWEL.
@@ -105,3 +123,11 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         """After register, init the trigger manager."""
        if self.system_app:
             self.http_trigger._init_app(self.system_app)
+
+    def keep_running(self) -> bool:
+        """Whether to keep running.
+
+        Returns:
+            bool: True means keep running, False means stop.
+        """
+        return self.http_trigger.keep_running()


@@ -70,7 +70,7 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
         # Sub dag, use the same dag context in the parent dag
         return await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
         )

     def _build_composer_dag(self) -> DAG:
def _build_composer_dag(self) -> DAG:


@@ -150,7 +150,7 @@ class PromptBuilderOperator(
             )
         )

-        single_input = {"data": {"dialect": "mysql"}}
+        single_input = {"dialect": "mysql"}
         single_expected_messages = [
             ModelMessage(
                 content="Please write a mysql SQL count the length of a field",


@@ -1,9 +1,12 @@
-from dbgpt.model.cluster.client import DefaultLLMClient
+try:
+    from dbgpt.model.cluster.client import DefaultLLMClient
+except ImportError as exc:
+    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
+    DefaultLLMClient = None

 # from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
-from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

-__ALL__ = [
-    "DefaultLLMClient",
-    "OpenAILLMClient",
-]
+_exports = []
+if DefaultLLMClient:
+    _exports.append("DefaultLLMClient")
+__ALL__ = _exports
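Downstream code should now treat DefaultLLMClient as optional and import OpenAILLMClient from dbgpt.model.proxy, as the example updates later in this diff do. A hedged usage sketch (the fallback assumes an OPENAI_API_KEY in the environment):

    from dbgpt.model import DefaultLLMClient
    from dbgpt.model.proxy import OpenAILLMClient

    # Core-only installs resolve DefaultLLMClient to None because the cluster
    # extras are absent, so fall back to the lightweight proxy client.
    if DefaultLLMClient is None:
        llm_client = OpenAILLMClient()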


@@ -137,7 +137,7 @@ class ModelLoader:
 def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
     import torch

-    from dbgpt.model.compression import compress_module
+    from dbgpt.model.llm.compression import compress_module

     device = model_params.device
     max_memory = None


@@ -19,7 +19,7 @@ from dbgpt.configs.model_config import get_device
 from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
-from dbgpt.model.conversation import Conversation
+from dbgpt.model.llm.conversation import Conversation
 from dbgpt.model.parameter import (
     LlamaCppModelParameters,
     ModelParameters,


@@ -1,13 +1,11 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
 import time
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Dict, List, Optional

 from dbgpt.util.model_utils import GPUInfo
 from dbgpt.util.parameter_utils import ParameterDescription


@@ -12,9 +12,9 @@ from dbgpt.core import (
     ModelOutput,
 )
 from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.parameter import ModelParameters
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj


@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Type

 from dbgpt.configs.model_config import get_device
 from dbgpt.core import ModelMetadata
+from dbgpt.model.adapter.loader import _get_model_real_path
 from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import _get_model_real_path
 from dbgpt.model.parameter import (
     EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
     BaseEmbeddingModelParameters,


@@ -8,7 +8,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
-from typing import Awaitable, Callable, Dict, Iterator, List
+from typing import Awaitable, Callable, Iterator

 from fastapi import APIRouter, FastAPI
 from fastapi.responses import StreamingResponse
@@ -16,12 +16,7 @@ from fastapi.responses import StreamingResponse
 from dbgpt.component import SystemApp
 from dbgpt.configs.model_config import LOGDIR
 from dbgpt.core import ModelMetadata, ModelOutput
-from dbgpt.model.base import (
-    ModelInstance,
-    WorkerApplyOutput,
-    WorkerApplyType,
-    WorkerSupportedModel,
-)
+from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
 from dbgpt.model.cluster.base import *
 from dbgpt.model.cluster.manager_base import (
     WorkerManager,
@@ -30,8 +25,8 @@ from dbgpt.model.cluster.manager_base import (
 )
 from dbgpt.model.cluster.registry import ModelRegistry
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.llm_utils import list_supported_models
-from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
+from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
+from dbgpt.model.utils.llm_utils import list_supported_models
 from dbgpt.util.parameter_utils import (
     EnvArgumentParser,
     ParameterDescription,


@@ -18,7 +18,7 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )

-from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
+from dbgpt.model.utils.llm_utils import is_partial_stop, is_sentence_complete


 def prepare_logits_processor(


@@ -1,9 +1,9 @@
-from dbgpt.model.operator.llm_operator import (
+from dbgpt.model.operator.llm_operator import (  # noqa: F401
     LLMOperator,
     MixinLLMOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator  # noqa: F401

 __ALL__ = [
     "MixinLLMOperator",


@@ -6,7 +6,6 @@ from dbgpt.component import ComponentType
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import BaseOperator
 from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
-from dbgpt.model.cluster import WorkerManagerFactory

 logger = logging.getLogger(__name__)
@@ -19,31 +18,30 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
     def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
         super().__init__(default_client)
-        self._default_llm_client = default_client

     @property
     def llm_client(self) -> LLMClient:
         if not self._llm_client:
-            worker_manager_factory: WorkerManagerFactory = (
-                self.system_app.get_component(
-                    ComponentType.WORKER_MANAGER_FACTORY,
-                    WorkerManagerFactory,
-                    default_component=None,
-                )
-            )
-            if worker_manager_factory:
-                self._llm_client = DefaultLLMClient(worker_manager_factory.create())
-            else:
-                if self._default_llm_client is None:
-                    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
-
-                    self._default_llm_client = OpenAILLMClient()
-                logger.info(
-                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
-                )
-                self._llm_client = self._default_llm_client
+            try:
+                from dbgpt.model.cluster import WorkerManagerFactory
+                from dbgpt.model.cluster.client import DefaultLLMClient
+
+                worker_manager_factory: WorkerManagerFactory = (
+                    self.system_app.get_component(
+                        ComponentType.WORKER_MANAGER_FACTORY,
+                        WorkerManagerFactory,
+                        default_component=None,
+                    )
+                )
+                if worker_manager_factory:
+                    self._llm_client = DefaultLLMClient(worker_manager_factory.create())
+            except Exception as e:
+                logger.warning(f"Load worker manager failed: {e}.")
+            if not self._llm_client:
+                from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
+
+                logger.info("Can't find worker manager factory, use OpenAILLMClient.")
+                self._llm_client = OpenAILLMClient()
         return self._llm_client
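The resolution order is now: any cached client, then the cluster's DefaultLLMClient when the worker-manager extras import cleanly and a factory is registered, then OpenAILLMClient. A standalone sketch of that chain (illustrative names, not the dbgpt API):

    from typing import Callable, Optional


    def resolve_client(
        cached: Optional[object],
        cluster_factory: Optional[Callable[[], object]],
        fallback_factory: Callable[[], object],
    ) -> object:
        if cached is not None:
            return cached  # 1) reuse an already-resolved client
        if cluster_factory is not None:
            try:
                return cluster_factory()  # 2) prefer the in-cluster client
            except Exception:
                pass  # cluster extras missing or not initialized
        return fallback_factory()  # 3) last resort: the proxy client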


@@ -6,11 +6,8 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, Optional, Tuple, Union

-from dbgpt.model.conversation import conv_templates
 from dbgpt.util.parameter_utils import BaseParameters

-suported_prompt_templates = ",".join(conv_templates.keys())
-

 class WorkerType(str, Enum):
     LLM = "llm"
@@ -299,7 +296,8 @@ class ModelParameters(BaseModelParameters):
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": "Prompt template. If None, the prompt template is automatically "
+            "determined from the model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -450,7 +448,8 @@ class ProxyModelParameters(BaseModelParameters):
     proxyllm_backend: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
+            "help": "The model name actually passed to the current proxy server url, "
+            "such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
         },
     )
     model_type: Optional[str] = field(
@@ -463,13 +462,15 @@ class ProxyModelParameters(BaseModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Device to run model. If None, the device is automatically determined"
+            "help": "Device to run model. If None, the device is automatically "
+            "determined"
         },
     )
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": "Prompt template. If None, the prompt template is automatically "
+            "determined from the model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -478,7 +479,8 @@ class ProxyModelParameters(BaseModelParameters):
     llm_client_class: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel"
+            "help": "The class name of llm client, such as "
+            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
         },
     )


@@ -37,8 +37,8 @@ def list_supported_models():
 def _list_supported_models(
     worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
+    from dbgpt.model.adapter.loader import _get_model_real_path
     from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
-    from dbgpt.model.loader import _get_model_real_path

     ret = []
     for model_name, model_path in model_config.items():


@@ -67,7 +67,7 @@ class AwelLayoutChatManager(ManagerAgent):
                 message=start_message, sender=self, reviewer=reviewer
             )
             final_generate_context: AgentGenerateContext = await last_node.call(
-                call_data={"data": start_message_context}
+                call_data=start_message_context
             )
             last_message = final_generate_context.rely_messages[-1]


@@ -28,7 +28,7 @@ from dbgpt.core.interface.llm import ModelMetadata
 from dbgpt.serve.agent.team.plan.team_auto_plan import AutoPlanChatManager

 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)


@@ -30,7 +30,7 @@ parent_dir = os.path.dirname(current_dir)
 test_plugin_dir = os.path.join(parent_dir, "test_files")

 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)


@@ -24,7 +24,7 @@ from dbgpt.agent.memory.gpts_memory import GptsMemory
 from dbgpt.core.interface.llm import ModelMetadata

 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)


@@ -27,7 +27,7 @@ from dbgpt.core.interface.llm import ModelMetadata

 def summary_example_with_success():
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(


@@ -24,7 +24,7 @@ from dbgpt.agent.memory.gpts_memory import GptsMemory
 from dbgpt.core.interface.llm import ModelMetadata

 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)


@@ -25,7 +25,7 @@ from dbgpt.core.interface.llm import ModelMetadata

 def summary_example_with_success():
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="summarize", llm_provider=llm_client)
@@ -76,7 +76,7 @@ def summary_example_with_success():

 def summary_example_with_faliure():
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient

     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="summarize", llm_provider=llm_client)


@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field

 from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator


@@ -32,7 +32,7 @@ from typing import Dict

 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.operator.rewrite import QueryRewriteOperator


@@ -30,7 +30,7 @@ from typing import Dict

 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.operator.knowledge import KnowledgeOperator
 from dbgpt.rag.operator.summary import SummaryAssemblerOperator


@@ -1,7 +1,7 @@
 import asyncio
 import os

-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.retriever.rewrite import QueryRewrite

 """Query rewrite example.


@@ -1,6 +1,6 @@
 import asyncio

-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.chunk_manager import ChunkParameters
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
 from dbgpt.serve.rag.assembler.summary import SummaryAssembler


@@ -7,7 +7,7 @@ from dbgpt.core.operator import (
     PromptBuilderOperator,
     RequestBuilderOperator,
 )
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient

 with DAG("simple_sdk_llm_example_dag") as dag:
     prompt_task = PromptBuilderOperator(
@@ -20,8 +20,6 @@ with DAG("simple_sdk_llm_example_dag") as dag:
 if __name__ == "__main__":
     output = asyncio.run(
-        out_parse_task.call(
-            call_data={"data": {"dialect": "mysql", "table_name": "user"}}
-        )
+        out_parse_task.call(call_data={"dialect": "mysql", "table_name": "user"})
     )
     print(f"output: \n\n{output}")


@@ -17,7 +17,7 @@ from dbgpt.core.operator import (
 )
 from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator

@@ -144,12 +144,10 @@ with DAG("simple_sdk_llm_sql_example") as dag:

 if __name__ == "__main__":
     input_data = {
-        "data": {
-            "db_name": "test_db",
-            "dialect": "sqlite",
-            "top_k": 5,
-            "user_input": "What is the name and age of the user with age less than 18",
-        }
+        "db_name": "test_db",
+        "dialect": "sqlite",
+        "top_k": 5,
+        "user_input": "What is the name and age of the user with age less than 18",
     }
     output = asyncio.run(sql_result_task.call(call_data=input_data))
     print(f"\nthoughts: {output.get('thoughts')}\n")


@@ -14,6 +14,11 @@ import functools
 with open("README.md", mode="r", encoding="utf-8") as fh:
     long_description = fh.read()

+IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
+# If you modify the version, please modify the version in the following files:
+# dbgpt/_version.py
+DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
+
 BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
 LLAMA_CPP_GPU_ACCELERATION = (
     os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true"
@@ -352,31 +357,41 @@ def llama_cpp_python_cuda_requires():
 def core_requires():
     """
-    pip install db-gpt or pip install "db-gpt[core]"
+    pip install dbgpt or pip install "dbgpt[core]"
     """
     setup_spec.extras["core"] = [
         "aiohttp==3.8.4",
         "chardet==5.1.0",
         "importlib-resources==5.12.0",
-        "psutil==5.9.4",
         "python-dotenv==1.0.0",
-        "colorama==0.4.6",
-        "prettytable",
         "cachetools",
+        "pydantic<2,>=1",
     ]
-    # Just use by DB-GPT internal, we should find the smallest dependency set for run we core unit test.
+    # Simple command line dependencies
+    setup_spec.extras["cli"] = setup_spec.extras["core"] + [
+        "prettytable",
+        "click",
+        "psutil==5.9.4",
+        "colorama==0.4.6",
+    ]
+    # Used only by DB-GPT internals; we should find the smallest dependency set
+    # needed to run the core unit tests.
     # The dependency "framework" is too large for now.
-    setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [
-        "pydantic<2,>=1",
+    setup_spec.extras["simple_framework"] = setup_spec.extras["cli"] + [
         "httpx",
         "jinja2",
         "fastapi==0.98.0",
         "uvicorn",
         "shortuuid",
-        # change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2
+        # changed from the fixed version 2.0.22 to a version range, because other
+        # dependencies require >=1.4, while e.g. pydoris requires <2
         "SQLAlchemy>=1.4,<3",
         # for cache
         "msgpack",
-        # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
+        # for cache
+        # TODO: pympler has not been updated for a long time; we need to find
+        # a replacement toolkit.
        "pympler",
         "sqlparse==0.4.4",
         "duckdb==0.8.1",
@@ -418,7 +433,7 @@ def core_requires():
 def knowledge_requires():
     """
-    pip install "db-gpt[knowledge]"
+    pip install "dbgpt[knowledge]"
     """
     setup_spec.extras["knowledge"] = [
         "spacy==3.5.3",
@@ -435,7 +450,7 @@ def knowledge_requires():
 def llama_cpp_requires():
     """
-    pip install "db-gpt[llama_cpp]"
+    pip install "dbgpt[llama_cpp]"
     """
     setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
     llama_cpp_python_cuda_requires()
@@ -523,7 +538,7 @@ def quantization_requires():
 def all_vector_store_requires():
     """
-    pip install "db-gpt[vstore]"
+    pip install "dbgpt[vstore]"
     """
     setup_spec.extras["vstore"] = [
         "grpcio==1.47.5",  # maybe delete it
@@ -534,7 +549,7 @@ def all_vector_store_requires():
 def all_datasource_requires():
     """
-    pip install "db-gpt[datasource]"
+    pip install "dbgpt[datasource]"
     """
     setup_spec.extras["datasource"] = [
@@ -552,7 +567,7 @@ def all_datasource_requires():
 def openai_requires():
     """
-    pip install "db-gpt[openai]"
+    pip install "dbgpt[openai]"
     """
     setup_spec.extras["openai"] = ["tiktoken"]
     if BUILD_VERSION_OPENAI:
@@ -567,28 +582,28 @@ def openai_requires():
 def gpt4all_requires():
     """
-    pip install "db-gpt[gpt4all]"
+    pip install "dbgpt[gpt4all]"
     """
     setup_spec.extras["gpt4all"] = ["gpt4all"]


 def vllm_requires():
     """
-    pip install "db-gpt[vllm]"
+    pip install "dbgpt[vllm]"
     """
     setup_spec.extras["vllm"] = ["vllm"]


 def cache_requires():
     """
-    pip install "db-gpt[cache]"
+    pip install "dbgpt[cache]"
     """
     setup_spec.extras["cache"] = ["rocksdict"]


 def default_requires():
     """
-    pip install "db-gpt[default]"
+    pip install "dbgpt[default]"
     """
     setup_spec.extras["default"] = [
         # "tokenizers==0.13.3",
@@ -637,14 +652,46 @@ default_requires()
 all_requires()
 init_install_requires()

+# Packages to exclude when IS_DEV_MODE is False
+excluded_packages = ["tests", "*.tests", "*.tests.*", "examples"]
+
+if IS_DEV_MODE:
+    packages = find_packages(exclude=excluded_packages)
+else:
+    packages = find_packages(
+        exclude=excluded_packages,
+        include=[
+            "dbgpt",
+            "dbgpt._private",
+            "dbgpt._private.*",
+            "dbgpt.cli",
+            "dbgpt.cli.*",
+            "dbgpt.configs",
+            "dbgpt.configs.*",
+            "dbgpt.core",
+            "dbgpt.core.*",
+            "dbgpt.util",
+            "dbgpt.util.*",
+            "dbgpt.model",
+            "dbgpt.model.proxy",
+            "dbgpt.model.proxy.*",
+            "dbgpt.model.operator",
+            "dbgpt.model.operator.*",
+            "dbgpt.model.utils",
+            "dbgpt.model.utils.*",
+        ],
+    )
+
 setuptools.setup(
-    name="db-gpt",
-    packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")),
-    version="0.4.5",
+    name="dbgpt",
+    packages=packages,
+    version=DB_GPT_VERSION,
     author="csunny",
     author_email="cfqcsunny@gmail.com",
-    description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment."
-    " With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.",
+    description="DB-GPT is an experimental open-source project that uses localized GPT "
+    "large models to interact with your data and environment."
+    " With this solution, you can be assured that there is no risk of data leakage, "
+    "and your data is 100% private and secure.",
     long_description=long_description,
     long_description_content_type="text/markdown",
     install_requires=setup_spec.install_requires,
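A hedged way to see what the non-dev build drops, run from the repository root (an illustrative check, not part of the commit):

    from setuptools import find_packages

    excluded = ["tests", "*.tests", "*.tests.*", "examples"]
    # Dev mode ships everything under dbgpt; SDK mode only the include list.
    dev_packages = set(find_packages(exclude=excluded))
    sdk_packages = set(
        find_packages(exclude=excluded, include=["dbgpt", "dbgpt.core", "dbgpt.core.*"])
    )
    print(sorted(dev_packages - sdk_packages)[:5])  # e.g. app/serve modules stay out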