mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 22:19:28 +00:00
refactor: Refactor for core SDK (#1092)
This commit is contained in:
@@ -9,7 +9,7 @@ from functools import cache
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
from dbgpt.model.conversation import Conversation, get_conv_template
|
||||
from dbgpt.model.llm.conversation import Conversation, get_conv_template
|
||||
|
||||
|
||||
class BaseChatAdpter:
|
||||
@@ -21,7 +21,7 @@ class BaseChatAdpter:
|
||||
|
||||
def get_generate_stream_func(self, model_path: str):
|
||||
"""Return the generate stream handler func"""
|
||||
from dbgpt.model.inference import generate_stream
|
||||
from dbgpt.model.llm.inference import generate_stream
|
||||
|
||||
return generate_stream
|
||||
|
||||
|
@@ -171,13 +171,13 @@ class BaseChat(ABC):
|
||||
|
||||
async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
|
||||
llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
|
||||
return await llm_task.call(call_data={"data": request})
|
||||
return await llm_task.call(call_data=request)
|
||||
|
||||
async def call_streaming_operator(
|
||||
self, request: ModelRequest
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
|
||||
async for out in await llm_task.call_stream(call_data={"data": request}):
|
||||
async for out in await llm_task.call_stream(call_data=request):
|
||||
yield out
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
@@ -251,11 +251,9 @@ class BaseChat(ABC):
|
||||
str_history=self.prompt_template.str_history,
|
||||
request_context=req_ctx,
|
||||
)
|
||||
node_input = {
|
||||
"data": ChatComposerInput(
|
||||
messages=self.history_messages, prompt_dict=input_values
|
||||
)
|
||||
}
|
||||
node_input = ChatComposerInput(
|
||||
messages=self.history_messages, prompt_dict=input_values
|
||||
)
|
||||
# llm_messages = self.generate_llm_messages()
|
||||
model_request: ModelRequest = await node.call(call_data=node_input)
|
||||
model_request.context.cache_enable = self.model_cache_enable
|
||||
|
@@ -87,7 +87,7 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
|
||||
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
|
||||
# Sub dag, use the same dag context in the parent dag
|
||||
messages = await end_node.call(
|
||||
call_data={"data": input_value}, dag_ctx=self.current_dag_context
|
||||
call_data=input_value, dag_ctx=self.current_dag_context
|
||||
)
|
||||
span_id = self._request_context.span_id
|
||||
model_request = ModelRequest.build_request(
|
||||
|
Reference in New Issue
Block a user