refactor: Refactor for core SDK (#1092)

This commit is contained in:
Fangyin Cheng
2024-01-21 09:57:57 +08:00
committed by GitHub
parent ba7248adbb
commit 2d905191f8
45 changed files with 236 additions and 133 deletions

View File

@@ -9,7 +9,7 @@ from functools import cache
from typing import Dict, List, Tuple
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.conversation import Conversation, get_conv_template
from dbgpt.model.llm.conversation import Conversation, get_conv_template
class BaseChatAdpter:
@@ -21,7 +21,7 @@ class BaseChatAdpter:
def get_generate_stream_func(self, model_path: str):
"""Return the generate stream handler func"""
from dbgpt.model.inference import generate_stream
from dbgpt.model.llm.inference import generate_stream
return generate_stream

View File

@@ -171,13 +171,13 @@ class BaseChat(ABC):
async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
return await llm_task.call(call_data={"data": request})
return await llm_task.call(call_data=request)
async def call_streaming_operator(
self, request: ModelRequest
) -> AsyncIterator[ModelOutput]:
llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
async for out in await llm_task.call_stream(call_data={"data": request}):
async for out in await llm_task.call_stream(call_data=request):
yield out
def do_action(self, prompt_response):
@@ -251,11 +251,9 @@ class BaseChat(ABC):
str_history=self.prompt_template.str_history,
request_context=req_ctx,
)
node_input = {
"data": ChatComposerInput(
messages=self.history_messages, prompt_dict=input_values
)
}
node_input = ChatComposerInput(
messages=self.history_messages, prompt_dict=input_values
)
# llm_messages = self.generate_llm_messages()
model_request: ModelRequest = await node.call(call_data=node_input)
model_request.context.cache_enable = self.model_cache_enable

View File

@@ -87,7 +87,7 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
# Sub dag, use the same dag context in the parent dag
messages = await end_node.call(
call_data={"data": input_value}, dag_ctx=self.current_dag_context
call_data=input_value, dag_ctx=self.current_dag_context
)
span_id = self._request_context.span_id
model_request = ModelRequest.build_request(