refactor: Refactor for core SDK (#1092)
parent ba7248adbb
commit 2d905191f8

Makefile (+13 lines)
@@ -81,6 +81,19 @@ clean: ## Clean up the environment
 	find . -type d -name '.pytest_cache' -delete
 	find . -type d -name '.coverage' -delete
 
+.PHONY: clean-dist
+clean-dist: ## Clean up the distribution
+	rm -rf dist/ *.egg-info build/
+
+.PHONY: package
+package: clean-dist ## Package the project for distribution
+	IS_DEV_MODE=false python setup.py sdist bdist_wheel
+
+.PHONY: upload
+upload: package ## Upload the package to PyPI
+	# upload to testpypi: twine upload --repository testpypi dist/*
+	twine upload dist/*
+
 .PHONY: help
 help: ## Display this help screen
 	@echo "Available commands:"
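Note: the new `package` target exports `IS_DEV_MODE=false` before invoking setup.py; the setup.py hunks at the end of this diff read that flag to decide which packages to ship. A minimal sketch of the env-flag pattern (illustrative only, not the full setup.py):

```python
import os

# "true" unless the caller exports IS_DEV_MODE=false (as `make package` does).
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"

if IS_DEV_MODE:
    print("dev build: include every package")
else:
    print("release build: include only the slim core SDK packages")
```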
@@ -1,12 +1,18 @@
-from dbgpt.component import BaseComponent, SystemApp
+"""DB-GPT: Next Generation Data Interaction Solution with LLMs.
 
-__ALL__ = ["SystemApp", "BaseComponent"]
+"""
+from dbgpt import _version  # noqa: E402
+from dbgpt.component import BaseComponent, SystemApp  # noqa: F401
 
 _CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
 _SERVE_LIBS = ["serve"]
 _LIBS = _CORE_LIBS + _SERVE_LIBS
 
 
+__version__ = _version.version
+
+__ALL__ = ["__version__", "SystemApp", "BaseComponent"]
+
+
 def __getattr__(name: str):
     # Lazy load
     import importlib
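Note: the hunk above is truncated right after the `__getattr__` definition. For context, a sketch of the PEP 562 lazy-load pattern it begins; the body below is an assumption for illustration, not the verbatim continuation:

```python
import importlib

_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train", "serve"]


def __getattr__(name: str):
    # Lazy load: import dbgpt.<name> only when the attribute is first accessed.
    if name in _LIBS:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```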
dbgpt/_version.py (new file, +1 line)
@@ -0,0 +1 @@
+version = "0.4.7"
@@ -579,7 +579,7 @@ if __name__ == "__main__":
     from dbgpt.agent.agents.agent import AgentContext
     from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
     from dbgpt.core.interface.llm import ModelMetadata
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(
@@ -9,7 +9,7 @@ from functools import cache
 from typing import Dict, List, Tuple
 
 from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
-from dbgpt.model.conversation import Conversation, get_conv_template
+from dbgpt.model.llm.conversation import Conversation, get_conv_template
 
 
 class BaseChatAdpter:
@@ -21,7 +21,7 @@ class BaseChatAdpter:
 
     def get_generate_stream_func(self, model_path: str):
         """Return the generate stream handler func"""
-        from dbgpt.model.inference import generate_stream
+        from dbgpt.model.llm.inference import generate_stream
 
         return generate_stream
 
@@ -171,13 +171,13 @@ class BaseChat(ABC):
 
     async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
         llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
-        return await llm_task.call(call_data={"data": request})
+        return await llm_task.call(call_data=request)
 
     async def call_streaming_operator(
         self, request: ModelRequest
     ) -> AsyncIterator[ModelOutput]:
         llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
-        async for out in await llm_task.call_stream(call_data={"data": request}):
+        async for out in await llm_task.call_stream(call_data=request):
             yield out
 
     def do_action(self, prompt_response):
@@ -251,11 +251,9 @@ class BaseChat(ABC):
             str_history=self.prompt_template.str_history,
             request_context=req_ctx,
         )
-        node_input = {
-            "data": ChatComposerInput(
-                messages=self.history_messages, prompt_dict=input_values
-            )
-        }
+        node_input = ChatComposerInput(
+            messages=self.history_messages, prompt_dict=input_values
+        )
         # llm_messages = self.generate_llm_messages()
         model_request: ModelRequest = await node.call(call_data=node_input)
         model_request.context.cache_enable = self.model_cache_enable
@@ -87,7 +87,7 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
         # Sub dag, use the same dag context in the parent dag
         messages = await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
         )
         span_id = self._request_context.span_id
         model_request = ModelRequest.build_request(
@ -1,3 +1,7 @@
|
|||||||
|
"""Component module for dbgpt.
|
||||||
|
|
||||||
|
Manages the lifecycle and registration of components.
|
||||||
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@@ -22,6 +22,7 @@ from .operator.common_operator import (
     JoinOperator,
     MapOperator,
     ReduceStreamOperator,
+    TriggerOperator,
 )
 from .operator.stream_operator import (
     StreamifyAbsOperator,
@@ -50,6 +51,7 @@ __all__ = [
    "BaseOperator",
    "JoinOperator",
    "ReduceStreamOperator",
+   "TriggerOperator",
    "MapOperator",
    "BranchOperator",
    "InputOperator",
@@ -150,4 +152,6 @@ def setup_dev_environment(
     for trigger in dag.trigger_nodes:
         trigger_manager.register_trigger(trigger)
     trigger_manager.after_register()
-    uvicorn.run(app, host=host, port=port)
+    if trigger_manager.keep_running():
+        # Should keep running
+        uvicorn.run(app, host=host, port=port)
@@ -28,7 +28,7 @@ from ..task.base import OUT, T, TaskOutput
 
 F = TypeVar("F", bound=FunctionType)
 
-CALL_DATA = Union[Dict, Dict[str, Dict]]
+CALL_DATA = Union[Dict[str, Any], Any]
 
 
 class WorkflowRunner(ABC, Generic[T]):
@@ -197,6 +197,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             OUT: The output of the node after execution.
         """
+        if call_data:
+            call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, exist_dag_ctx=dag_ctx
         )
@@ -242,6 +244,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
+        if call_data:
+            call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
         )
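Note: these two hunks are the heart of the refactor. `call` and `call_stream` now wrap `call_data` as `{"data": call_data}` internally, so every call site in the rest of this diff drops the manual envelope. A minimal usage sketch, assuming `MapOperator` accepts a mapping function as the AWEL examples later in this diff do:

```python
import asyncio

from dbgpt.core.awel import DAG, MapOperator

with DAG("call_data_example") as dag:
    double_task = MapOperator(map_function=lambda x: x * 2)

# Old convention: await double_task.call(call_data={"data": 21})
# New convention: pass the raw value; BaseOperator wraps it exactly once.
print(asyncio.run(double_task.call(call_data=21)))  # -> 42
```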
@@ -28,6 +28,14 @@ EMPTY_DATA = _EMPTY_DATA_TYPE()
 SKIP_DATA = _EMPTY_DATA_TYPE()
 PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
 
+
+def is_empty_data(data: Any):
+    """Check if the data is empty."""
+    if isinstance(data, _EMPTY_DATA_TYPE):
+        return data in (EMPTY_DATA, SKIP_DATA)
+    return False
+
+
 MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
 ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
 StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
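Note: `is_empty_data` centralizes the sentinel check previously written as `self._data == EMPTY_DATA or self._data == SKIP_DATA` at each call site; the hunks below switch those sites over. Assuming sentinel equality falls back to identity, the behavior is:

```python
# Behavior sketch, names as defined in the hunk above:
assert is_empty_data(EMPTY_DATA) is True
assert is_empty_data(SKIP_DATA) is True
assert is_empty_data(PLACEHOLDER_DATA) is False  # placeholder is not "empty"
assert is_empty_data({"data": 1}) is False  # ordinary payloads never match
```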
@@ -24,7 +24,6 @@ from .base import (
     EMPTY_DATA,
     OUT,
     PLACEHOLDER_DATA,
-    SKIP_DATA,
     InputContext,
     InputSource,
     MapFunc,
@@ -37,6 +36,7 @@ from .base import (
     TaskState,
     TransformFunc,
     UnStreamFunc,
+    is_empty_data,
 )
 
 logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
     @property
     def is_empty(self) -> bool:
         """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)
 
     @property
     def is_none(self) -> bool:
@@ -171,7 +171,7 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
     @property
     def is_empty(self) -> bool:
         """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)
 
     @property
     def is_none(self) -> bool:
@@ -330,7 +330,7 @@ class SimpleCallDataInputSource(BaseInputSource):
         """
         call_data = task_ctx.call_data
         data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
-        if data == EMPTY_DATA:
+        if is_empty_data(data):
             raise ValueError("No call data for current SimpleCallDataInputSource")
         return data
 
@@ -1,12 +1,8 @@
 """Http trigger for AWEL."""
-from __future__ import annotations
-
 import logging
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast
 
-from starlette.requests import Request
-
 from dbgpt._private.pydantic import BaseModel
 
 from ..dag.base import DAG
@@ -15,9 +11,10 @@ from .base import Trigger
 
 if TYPE_CHECKING:
     from fastapi import APIRouter
+    from starlette.requests import Request
 
 RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
 StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
 
 logger = logging.getLogger(__name__)
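Note: `starlette` becomes a lazy dependency here. The `Request` import moves under `TYPE_CHECKING`, and the functions that need it at runtime import it locally (see the hunks below). A minimal sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from starlette.requests import Request  # evaluated only by type checkers


async def read_body(request: "Request") -> bytes:
    from starlette.requests import Request  # runtime import deferred to call time

    assert isinstance(request, Request)
    return await request.body()
```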
@@ -32,9 +29,9 @@ class HttpTrigger(Trigger):
         self,
         endpoint: str,
         methods: Optional[Union[str, List[str]]] = "GET",
-        request_body: Optional[RequestBody] = None,
+        request_body: Optional["RequestBody"] = None,
         streaming_response: bool = False,
-        streaming_predict_func: Optional[StreamingPredictFunc] = None,
+        streaming_predict_func: Optional["StreamingPredictFunc"] = None,
         response_model: Optional[Type] = None,
         response_headers: Optional[Dict[str, str]] = None,
         response_media_type: Optional[str] = None,
@@ -69,6 +66,7 @@ class HttpTrigger(Trigger):
             router (APIRouter): The router to mount the trigger.
         """
         from fastapi import Depends
+        from starlette.requests import Request
 
         methods = [self._methods] if isinstance(self._methods, str) else self._methods
 
@@ -114,8 +112,10 @@ class HttpTrigger(Trigger):
 
 
 async def _parse_request_body(
-    request: Request, request_body_cls: Optional[RequestBody]
+    request: "Request", request_body_cls: Optional["RequestBody"]
 ):
+    from starlette.requests import Request
+
     if not request_body_cls:
         return None
     if request_body_cls == Request:
@@ -152,7 +152,7 @@ async def _trigger_dag(
         raise ValueError("HttpTrigger just support one leaf node in dag")
     end_node = cast(BaseOperator, leaf_nodes[0])
     if not streaming_response:
-        return await end_node.call(call_data={"data": body})
+        return await end_node.call(call_data=body)
     else:
         headers = response_headers
         media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
             "Connection": "keep-alive",
             "Transfer-Encoding": "chunked",
         }
-        generator = await end_node.call_stream(call_data={"data": body})
+        generator = await end_node.call_stream(call_data=body)
         background_tasks = BackgroundTasks()
         background_tasks.add_task(dag._after_dag_end)
         return StreamingResponse(
@@ -24,6 +24,14 @@ class TriggerManager(ABC):
     def register_trigger(self, trigger: Any) -> None:
         """Register a trigger to current manager."""
 
+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return False
+
 
 class HttpTriggerManager(TriggerManager):
     """Http trigger manager.
@@ -64,6 +72,8 @@ class HttpTriggerManager(TriggerManager):
         self._trigger_map[trigger_id] = trigger
 
     def _init_app(self, system_app: SystemApp):
+        if not self.keep_running():
+            return
         logger.info(
             f"Include router {self._router} to prefix path {self._router_prefix}"
         )
@@ -72,6 +82,14 @@ class HttpTriggerManager(TriggerManager):
             raise RuntimeError("System app not initialized")
         app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])
 
+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return len(self._trigger_map) > 0
+
 
 class DefaultTriggerManager(TriggerManager, BaseComponent):
     """Default trigger manager for AWEL.
@@ -105,3 +123,11 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         """After register, init the trigger manager."""
         if self.system_app:
             self.http_trigger._init_app(self.system_app)
+
+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return self.http_trigger.keep_running()
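Note: taken together, the three `keep_running` implementations let `setup_dev_environment` (earlier in this diff) skip starting uvicorn when no HTTP trigger was registered. A distilled sketch of the control flow, with illustrative names:

```python
class FakeTriggerManager:
    def __init__(self):
        self._trigger_map = {}

    def keep_running(self) -> bool:
        # Mirrors HttpTriggerManager: serve only if at least one trigger exists.
        return len(self._trigger_map) > 0


manager = FakeTriggerManager()
if manager.keep_running():
    print("start uvicorn")
else:
    print("no triggers registered, nothing to serve")
```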
@@ -70,7 +70,7 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
         # Sub dag, use the same dag context in the parent dag
         return await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
         )
 
     def _build_composer_dag(self) -> DAG:
@@ -150,7 +150,7 @@ class PromptBuilderOperator(
             )
         )
 
-        single_input = {"data": {"dialect": "mysql"}}
+        single_input = {"dialect": "mysql"}
         single_expected_messages = [
             ModelMessage(
                 content="Please write a mysql SQL count the length of a field",
@@ -1,9 +1,12 @@
-from dbgpt.model.cluster.client import DefaultLLMClient
+try:
+    from dbgpt.model.cluster.client import DefaultLLMClient
+except ImportError as exc:
+    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
+    DefaultLLMClient = None
 
-# from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
-from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
-
-__ALL__ = [
-    "DefaultLLMClient",
-    "OpenAILLMClient",
-]
+_exports = []
+if DefaultLLMClient:
+    _exports.append("DefaultLLMClient")
+
+__ALL__ = _exports
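Note: `dbgpt.model` can now be imported on a slim SDK install where the cluster extras are missing; `DefaultLLMClient` degrades to `None` instead of raising at import time. A consumer-side sketch:

```python
from dbgpt.model import DefaultLLMClient

if DefaultLLMClient is None:
    print("cluster extras not installed; use dbgpt.model.proxy.OpenAILLMClient instead")
else:
    print("full install detected; DefaultLLMClient is available")
```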
@@ -137,7 +137,7 @@ class ModelLoader:
 def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
     import torch
 
-    from dbgpt.model.compression import compress_module
+    from dbgpt.model.llm.compression import compress_module
 
     device = model_params.device
     max_memory = None
@@ -19,7 +19,7 @@ from dbgpt.configs.model_config import get_device
 from dbgpt.model.adapter.base import LLMModelAdapter
 from dbgpt.model.adapter.template import ConversationAdapter, PromptType
 from dbgpt.model.base import ModelType
-from dbgpt.model.conversation import Conversation
+from dbgpt.model.llm.conversation import Conversation
 from dbgpt.model.parameter import (
     LlamaCppModelParameters,
     ModelParameters,
@@ -1,13 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import time
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Dict, List, Optional
 
-from dbgpt.util.model_utils import GPUInfo
 from dbgpt.util.parameter_utils import ParameterDescription
 
 
@@ -12,9 +12,9 @@ from dbgpt.core import (
     ModelOutput,
 )
 from dbgpt.model.adapter.base import LLMModelAdapter
+from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import ModelLoader, _get_model_real_path
 from dbgpt.model.parameter import ModelParameters
 from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
 from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
@@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Type
 
 from dbgpt.configs.model_config import get_device
 from dbgpt.core import ModelMetadata
+from dbgpt.model.adapter.loader import _get_model_real_path
 from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.loader import _get_model_real_path
 from dbgpt.model.parameter import (
     EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
     BaseEmbeddingModelParameters,
@@ -8,7 +8,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
-from typing import Awaitable, Callable, Dict, Iterator, List
+from typing import Awaitable, Callable, Iterator
 
 from fastapi import APIRouter, FastAPI
 from fastapi.responses import StreamingResponse
@@ -16,12 +16,7 @@ from fastapi.responses import StreamingResponse
 from dbgpt.component import SystemApp
 from dbgpt.configs.model_config import LOGDIR
 from dbgpt.core import ModelMetadata, ModelOutput
-from dbgpt.model.base import (
-    ModelInstance,
-    WorkerApplyOutput,
-    WorkerApplyType,
-    WorkerSupportedModel,
-)
+from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
 from dbgpt.model.cluster.base import *
 from dbgpt.model.cluster.manager_base import (
     WorkerManager,
@@ -30,8 +25,8 @@ from dbgpt.model.cluster.manager_base import (
 )
 from dbgpt.model.cluster.registry import ModelRegistry
 from dbgpt.model.cluster.worker_base import ModelWorker
-from dbgpt.model.llm_utils import list_supported_models
-from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
+from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
+from dbgpt.model.utils.llm_utils import list_supported_models
 from dbgpt.util.parameter_utils import (
     EnvArgumentParser,
     ParameterDescription,
@@ -18,7 +18,7 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
+from dbgpt.model.utils.llm_utils import is_partial_stop, is_sentence_complete
 
 
 def prepare_logits_processor(
@@ -1,9 +1,9 @@
-from dbgpt.model.operator.llm_operator import (
+from dbgpt.model.operator.llm_operator import (  # noqa: F401
     LLMOperator,
     MixinLLMOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator  # noqa: F401
 
 __ALL__ = [
     "MixinLLMOperator",
@@ -6,7 +6,6 @@ from dbgpt.component import ComponentType
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import BaseOperator
 from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
-from dbgpt.model.cluster import WorkerManagerFactory
 
 logger = logging.getLogger(__name__)
 
@@ -19,11 +18,14 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
 
     def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
         super().__init__(default_client)
-        self._default_llm_client = default_client
 
     @property
     def llm_client(self) -> LLMClient:
         if not self._llm_client:
+            try:
+                from dbgpt.model.cluster import WorkerManagerFactory
+                from dbgpt.model.cluster.client import DefaultLLMClient
+
                 worker_manager_factory: WorkerManagerFactory = (
                     self.system_app.get_component(
                         ComponentType.WORKER_MANAGER_FACTORY,
@@ -32,18 +34,14 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
                     )
                 )
                 if worker_manager_factory:
-                    from dbgpt.model.cluster.client import DefaultLLMClient
-
                     self._llm_client = DefaultLLMClient(worker_manager_factory.create())
-            else:
-                if self._default_llm_client is None:
-                    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
-
-                    self._default_llm_client = OpenAILLMClient()
-                logger.info(
-                    f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
-                )
-                self._llm_client = self._default_llm_client
+            except Exception as e:
+                logger.warning(f"Load worker manager failed: {e}.")
+            if not self._llm_client:
+                from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
+
+                logger.info("Can't find worker manager factory, use OpenAILLMClient.")
+                self._llm_client = OpenAILLMClient()
         return self._llm_client
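Note: the property now resolves a client lazily with a layered fallback: an explicitly passed client, then the cluster worker manager, then the OpenAI proxy client. A distilled sketch of that order (hypothetical helper, not the actual method):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_llm_client(explicit_client=None, make_cluster_client=None):
    if explicit_client is not None:
        return explicit_client
    client = None
    try:
        if make_cluster_client is not None:
            client = make_cluster_client()  # DefaultLLMClient path
    except Exception as e:
        logger.warning(f"Load worker manager failed: {e}.")
    return client or "OpenAILLMClient"  # final fallback
```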
@@ -6,11 +6,8 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, Optional, Tuple, Union
 
-from dbgpt.model.conversation import conv_templates
 from dbgpt.util.parameter_utils import BaseParameters
 
-suported_prompt_templates = ",".join(conv_templates.keys())
-
 
 class WorkerType(str, Enum):
     LLM = "llm"
@@ -299,7 +296,8 @@ class ModelParameters(BaseModelParameters):
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": f"Prompt template. If None, the prompt template is automatically "
+            f"determined from model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -450,7 +448,8 @@ class ProxyModelParameters(BaseModelParameters):
     proxyllm_backend: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
+            "help": "The model name actually pass to current proxy server url, such "
+            "as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
         },
     )
     model_type: Optional[str] = field(
@@ -463,13 +462,15 @@ class ProxyModelParameters(BaseModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Device to run model. If None, the device is automatically determined"
+            "help": "Device to run model. If None, the device is automatically "
+            "determined"
         },
     )
     prompt_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+            "help": f"Prompt template. If None, the prompt template is automatically "
+            f"determined from model path"
         },
     )
     max_context_size: Optional[int] = field(
@@ -478,7 +479,8 @@ class ProxyModelParameters(BaseModelParameters):
     llm_client_class: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel"
+            "help": "The class name of llm client, such as "
+            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
         },
    )
 
@@ -37,8 +37,8 @@ def list_supported_models():
 def _list_supported_models(
     worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
+    from dbgpt.model.adapter.loader import _get_model_real_path
     from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
-    from dbgpt.model.loader import _get_model_real_path
 
     ret = []
     for model_name, model_path in model_config.items():
@@ -67,7 +67,7 @@ class AwelLayoutChatManager(ManagerAgent):
             message=start_message, sender=self, reviewer=reviewer
         )
         final_generate_context: AgentGenerateContext = await last_node.call(
-            call_data={"data": start_message_context}
+            call_data=start_message_context
         )
         last_message = final_generate_context.rely_messages[-1]
 
@@ -28,7 +28,7 @@ from dbgpt.core.interface.llm import ModelMetadata
 from dbgpt.serve.agent.team.plan.team_auto_plan import AutoPlanChatManager
 
 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -30,7 +30,7 @@ parent_dir = os.path.dirname(current_dir)
 test_plugin_dir = os.path.join(parent_dir, "test_files")
 
 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -24,7 +24,7 @@ from dbgpt.agent.memory.gpts_memory import GptsMemory
 from dbgpt.core.interface.llm import ModelMetadata
 
 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -27,7 +27,7 @@ from dbgpt.core.interface.llm import ModelMetadata
 
 
 def summary_example_with_success():
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(
@@ -24,7 +24,7 @@ from dbgpt.agent.memory.gpts_memory import GptsMemory
 from dbgpt.core.interface.llm import ModelMetadata
 
 if __name__ == "__main__":
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="test456", llm_provider=llm_client)
@@ -25,7 +25,7 @@ from dbgpt.core.interface.llm import ModelMetadata
 
 
 def summary_example_with_success():
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="summarize", llm_provider=llm_client)
@@ -76,7 +76,7 @@ def summary_example_with_success():
 
 
 def summary_example_with_faliure():
-    from dbgpt.model import OpenAILLMClient
+    from dbgpt.model.proxy import OpenAILLMClient
 
     llm_client = OpenAILLMClient()
     context: AgentContext = AgentContext(conv_id="summarize", llm_provider=llm_client)
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
 
 from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 from dbgpt.rag.operator.embedding import EmbeddingRetrieverOperator
@@ -32,7 +32,7 @@ from typing import Dict
 
 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.operator.rewrite import QueryRewriteOperator
 
 
@@ -30,7 +30,7 @@ from typing import Dict
 
 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.knowledge.base import KnowledgeType
 from dbgpt.rag.operator.knowledge import KnowledgeOperator
 from dbgpt.rag.operator.summary import SummaryAssemblerOperator
@@ -1,7 +1,7 @@
 import asyncio
 import os
 
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 
 """Query rewrite example.
@@ -1,6 +1,6 @@
 import asyncio
 
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.chunk_manager import ChunkParameters
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
 from dbgpt.serve.rag.assembler.summary import SummaryAssembler
@@ -7,7 +7,7 @@ from dbgpt.core.operator import (
     PromptBuilderOperator,
     RequestBuilderOperator,
 )
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 
 with DAG("simple_sdk_llm_example_dag") as dag:
     prompt_task = PromptBuilderOperator(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
output = asyncio.run(
|
output = asyncio.run(
|
||||||
out_parse_task.call(
|
out_parse_task.call(call_data={"dialect": "mysql", "table_name": "user"})
|
||||||
call_data={"data": {"dialect": "mysql", "table_name": "user"}}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
print(f"output: \n\n{output}")
|
print(f"output: \n\n{output}")
|
||||||
|
@@ -17,7 +17,7 @@ from dbgpt.core.operator import (
 )
 from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
-from dbgpt.model import OpenAILLMClient
+from dbgpt.model.proxy import OpenAILLMClient
 from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator
 
 
@@ -144,13 +144,11 @@ with DAG("simple_sdk_llm_sql_example") as dag:
 
 if __name__ == "__main__":
     input_data = {
-        "data": {
-            "db_name": "test_db",
-            "dialect": "sqlite",
-            "top_k": 5,
-            "user_input": "What is the name and age of the user with age less than 18",
-        }
+        "db_name": "test_db",
+        "dialect": "sqlite",
+        "top_k": 5,
+        "user_input": "What is the name and age of the user with age less than 18",
     }
-    }
     output = asyncio.run(sql_result_task.call(call_data=input_data))
     print(f"\nthoughts: {output.get('thoughts')}\n")
     print(f"sql: {output.get('sql')}\n")
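Note: even dict payloads are now passed un-nested; `BaseOperator.call` adds the `{"data": ...}` envelope exactly once, so examples no longer risk double wrapping. A short sketch of what a downstream operator receives (same `MapOperator` assumption as the earlier note):

```python
import asyncio

from dbgpt.core.awel import DAG, MapOperator

with DAG("dict_payload_example") as dag:
    task = MapOperator(map_function=lambda d: d["dialect"])

payload = {"dialect": "sqlite", "top_k": 5}
print(asyncio.run(task.call(call_data=payload)))  # -> "sqlite"
```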
setup.py (91 changed lines)
@@ -14,6 +14,11 @@ import functools
 with open("README.md", mode="r", encoding="utf-8") as fh:
     long_description = fh.read()
 
+IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
+# If you modify the version, please modify the version in the following files:
+# dbgpt/_version.py
+DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
+
 BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
 LLAMA_CPP_GPU_ACCELERATION = (
     os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true"
@@ -352,31 +357,41 @@ def llama_cpp_python_cuda_requires():
 
 def core_requires():
     """
-    pip install db-gpt or pip install "db-gpt[core]"
+    pip install dbgpt or pip install "dbgpt[core]"
     """
     setup_spec.extras["core"] = [
         "aiohttp==3.8.4",
         "chardet==5.1.0",
         "importlib-resources==5.12.0",
-        "psutil==5.9.4",
         "python-dotenv==1.0.0",
-        "colorama==0.4.6",
-        "prettytable",
         "cachetools",
+        "pydantic<2,>=1",
     ]
-    # Just use by DB-GPT internal, we should find the smallest dependency set for run we core unit test.
+    # Simple command line dependencies
+    setup_spec.extras["cli"] = setup_spec.extras["core"] + [
+        "prettytable",
+        "click",
+        "psutil==5.9.4",
+        "colorama==0.4.6",
+    ]
+    # Just use by DB-GPT internal, we should find the smallest dependency set for run
+    # we core unit test.
     # The dependency "framework" is too large for now.
-    setup_spec.extras["simple_framework"] = setup_spec.extras["core"] + [
+    setup_spec.extras["simple_framework"] = setup_spec.extras["cli"] + [
         "pydantic<2,>=1",
         "httpx",
         "jinja2",
         "fastapi==0.98.0",
+        "uvicorn",
         "shortuuid",
-        # change from fixed version 2.0.22 to variable version, because other dependencies are >=1.4, such as pydoris is <2
+        # change from fixed version 2.0.22 to variable version, because other
+        # dependencies are >=1.4, such as pydoris is <2
         "SQLAlchemy>=1.4,<3",
         # for cache
         "msgpack",
-        # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
+        # for cache
+        # TODO: pympler has not been updated for a long time and needs to
+        # find a new toolkit.
         "pympler",
         "sqlparse==0.4.4",
         "duckdb==0.8.1",
@@ -418,7 +433,7 @@ def core_requires():
 
 def knowledge_requires():
     """
-    pip install "db-gpt[knowledge]"
+    pip install "dbgpt[knowledge]"
     """
     setup_spec.extras["knowledge"] = [
         "spacy==3.5.3",
@@ -435,7 +450,7 @@ def knowledge_requires():
 
 def llama_cpp_requires():
     """
-    pip install "db-gpt[llama_cpp]"
+    pip install "dbgpt[llama_cpp]"
     """
     setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
     llama_cpp_python_cuda_requires()
@@ -523,7 +538,7 @@ def quantization_requires():
 
 def all_vector_store_requires():
     """
-    pip install "db-gpt[vstore]"
+    pip install "dbgpt[vstore]"
     """
     setup_spec.extras["vstore"] = [
         "grpcio==1.47.5",  # maybe delete it
@@ -534,7 +549,7 @@ def all_vector_store_requires():
 
 def all_datasource_requires():
     """
-    pip install "db-gpt[datasource]"
+    pip install "dbgpt[datasource]"
     """
 
     setup_spec.extras["datasource"] = [
@@ -552,7 +567,7 @@ def all_datasource_requires():
 
 def openai_requires():
     """
-    pip install "db-gpt[openai]"
+    pip install "dbgpt[openai]"
     """
     setup_spec.extras["openai"] = ["tiktoken"]
     if BUILD_VERSION_OPENAI:
@@ -567,28 +582,28 @@ def openai_requires():
 
 def gpt4all_requires():
     """
-    pip install "db-gpt[gpt4all]"
+    pip install "dbgpt[gpt4all]"
     """
     setup_spec.extras["gpt4all"] = ["gpt4all"]
 
 
 def vllm_requires():
     """
-    pip install "db-gpt[vllm]"
+    pip install "dbgpt[vllm]"
     """
     setup_spec.extras["vllm"] = ["vllm"]
 
 
 def cache_requires():
     """
-    pip install "db-gpt[cache]"
+    pip install "dbgpt[cache]"
     """
     setup_spec.extras["cache"] = ["rocksdict"]
 
 
 def default_requires():
     """
-    pip install "db-gpt[default]"
+    pip install "dbgpt[default]"
     """
     setup_spec.extras["default"] = [
         # "tokenizers==0.13.3",
@@ -637,14 +652,46 @@ default_requires()
 all_requires()
 init_install_requires()
 
+# Packages to exclude when IS_DEV_MODE is False
+excluded_packages = ["tests", "*.tests", "*.tests.*", "examples"]
+
+if IS_DEV_MODE:
+    packages = find_packages(exclude=excluded_packages)
+else:
+    packages = find_packages(
+        exclude=excluded_packages,
+        include=[
+            "dbgpt",
+            "dbgpt._private",
+            "dbgpt._private.*",
+            "dbgpt.cli",
+            "dbgpt.cli.*",
+            "dbgpt.configs",
+            "dbgpt.configs.*",
+            "dbgpt.core",
+            "dbgpt.core.*",
+            "dbgpt.util",
+            "dbgpt.util.*",
+            "dbgpt.model",
+            "dbgpt.model.proxy",
+            "dbgpt.model.proxy.*",
+            "dbgpt.model.operator",
+            "dbgpt.model.operator.*",
+            "dbgpt.model.utils",
+            "dbgpt.model.utils.*",
+        ],
+    )
+
 setuptools.setup(
-    name="db-gpt",
-    packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")),
-    version="0.4.5",
+    name="dbgpt",
+    packages=packages,
+    version=DB_GPT_VERSION,
     author="csunny",
     author_email="cfqcsunny@gmail.com",
-    description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment."
-    " With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.",
+    description="DB-GPT is an experimental open-source project that uses localized GPT "
+    "large models to interact with your data and environment."
+    " With this solution, you can be assured that there is no risk of data leakage, "
+    "and your data is 100% private and secure.",
     long_description=long_description,
     long_description_content_type="text/markdown",
     install_requires=setup_spec.install_requires,
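Note: the published distribution name changes from "db-gpt" to "dbgpt", and the version is now sourced from dbgpt/_version.py (overridable via DB_GPT_VERSION). A quick, hypothetical way to check which name is installed in an environment:

```python
from importlib.metadata import PackageNotFoundError, version

for dist in ("dbgpt", "db-gpt"):
    try:
        print(dist, "->", version(dist))
    except PackageNotFoundError:
        print(dist, "-> not installed")
```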