Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-03 01:54:44 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
.gitignore (vendored) | 6

@@ -179,4 +179,8 @@ thirdparty
 
 # typescript
 *.tsbuildinfo
 /web/next-env.d.ts
+
+# Ignore awel DAG visualization files
+/examples/**/*.gv
+/examples/**/*.gv.pdf
Makefile | 5

@@ -67,6 +67,11 @@ pre-commit: fmt test ## Run formatting and unit tests before committing
 test: $(VENV)/.testenv ## Run unit tests
 	$(VENV_BIN)/pytest dbgpt
 
+.PHONY: test-doc
+test-doc: $(VENV)/.testenv ## Run doctests
+	# -k "not test_" skips tests that are not doctests.
+	$(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core
+
 .PHONY: coverage
 coverage: setup ## Run tests and report coverage
 	$(VENV_BIN)/pytest dbgpt --cov=dbgpt
@@ -102,6 +102,11 @@ class BaseChat(ABC):
             is_stream=True, dag_name="llm_stream_model_dag"
         )
 
+        # Get the message version, default is v1 in app
+        # In v1, we will transform the message to compatible format of specific model
+        # In the future, we will upgrade the message version to v2, and the message will be compatible with all models
+        self._message_version = chat_param.get("message_version", "v1")
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -185,6 +190,7 @@ class BaseChat(ABC):
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
             "echo": self.llm_echo,
+            "version": self._message_version,
         }
         return payload
@@ -6,7 +6,10 @@ from dbgpt.core.interface.cache import (
     CacheValue,
 )
 from dbgpt.core.interface.llm import (
+    DefaultMessageConverter,
     LLMClient,
+    MessageConverter,
+    ModelExtraMedata,
     ModelInferenceMetrics,
     ModelMetadata,
     ModelOutput,

@@ -14,19 +17,28 @@ from dbgpt.core.interface.llm import (
     ModelRequestContext,
 )
 from dbgpt.core.interface.message import (
+    AIMessage,
+    BaseMessage,
     ConversationIdentifier,
+    HumanMessage,
     MessageIdentifier,
     MessageStorageItem,
     ModelMessage,
     ModelMessageRoleType,
     OnceConversation,
     StorageConversation,
+    SystemMessage,
 )
 from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
 from dbgpt.core.interface.prompt import (
+    BasePromptTemplate,
+    ChatPromptTemplate,
+    HumanPromptTemplate,
+    MessagesPlaceholder,
     PromptManager,
     PromptTemplate,
     StoragePromptTemplate,
+    SystemPromptTemplate,
 )
 from dbgpt.core.interface.serialization import Serializable, Serializer
 from dbgpt.core.interface.storage import (

@@ -49,14 +61,26 @@ __ALL__ = [
     "ModelMessage",
     "LLMClient",
     "ModelMessageRoleType",
+    "ModelExtraMedata",
+    "MessageConverter",
+    "DefaultMessageConverter",
     "OnceConversation",
     "StorageConversation",
+    "BaseMessage",
+    "SystemMessage",
+    "AIMessage",
+    "HumanMessage",
     "MessageStorageItem",
     "ConversationIdentifier",
     "MessageIdentifier",
     "PromptTemplate",
     "PromptManager",
     "StoragePromptTemplate",
+    "BasePromptTemplate",
+    "ChatPromptTemplate",
+    "MessagesPlaceholder",
+    "SystemPromptTemplate",
+    "HumanPromptTemplate",
     "BaseOutputParser",
     "SQLOutputParser",
     "Serializable",
@@ -7,6 +7,7 @@ The stability of this API cannot be guaranteed at present.
 
 """
 
+import logging
 from typing import List, Optional
 
 from dbgpt.component import SystemApp

@@ -39,6 +40,8 @@ from .task.task_impl import (
 )
 from .trigger.http_trigger import HttpTrigger
 
+logger = logging.getLogger(__name__)
+
 __all__ = [
     "initialize_awel",
     "DAGContext",

@@ -89,14 +92,24 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):
 
 def setup_dev_environment(
     dags: List[DAG],
-    host: Optional[str] = "0.0.0.0",
+    host: Optional[str] = "127.0.0.1",
     port: Optional[int] = 5555,
     logging_level: Optional[str] = None,
     logger_filename: Optional[str] = None,
+    show_dag_graph: Optional[bool] = True,
 ) -> None:
     """Setup a development environment for AWEL.
 
     Just using in development environment, not production environment.
+
+    Args:
+        dags (List[DAG]): The DAGs.
+        host (Optional[str], optional): The host. Defaults to "127.0.0.1"
+        port (Optional[int], optional): The port. Defaults to 5555.
+        logging_level (Optional[str], optional): The logging level. Defaults to None.
+        logger_filename (Optional[str], optional): The logger filename. Defaults to None.
+        show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
+            If True, the DAG graph will be saved to a file and open it automatically.
     """
     import uvicorn
     from fastapi import FastAPI

@@ -118,6 +131,15 @@ setup_dev_environment(
     system_app.register_instance(trigger_manager)
 
     for dag in dags:
+        if show_dag_graph:
+            try:
+                dag_graph_file = dag.visualize_dag()
+                if dag_graph_file:
+                    logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
+            except Exception as e:
+                logger.warning(
+                    f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
+                )
         for trigger in dag.trigger_nodes:
            trigger_manager.register_trigger(trigger)
     trigger_manager.after_register()
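Example (not part of the diff): a minimal sketch of the new `show_dag_graph` option. The DAG body and port are illustrative, and rendering the graph requires graphviz to be installed.

from dbgpt.core.awel import DAG, MapOperator, setup_dev_environment

with DAG("example_dag") as dag:
    # A trivial two-step pipeline, purely for illustration
    first_task = MapOperator(lambda x: x.upper())
    second_task = MapOperator(lambda x: f"result: {x}")
    first_task >> second_task

# Renders each DAG to a .gv file (and opens it when show_dag_graph=True),
# then starts the AWEL development server on the given host/port.
setup_dev_environment([dag], port=5555, show_dag_graph=True)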
@@ -6,8 +6,7 @@ import uuid
 from abc import ABC, abstractmethod
 from collections import deque
 from concurrent.futures import Executor
-from functools import cache
-from typing import Any, Dict, List, Optional, Sequence, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union
 
 from dbgpt.component import SystemApp
 

@@ -177,7 +176,10 @@ class DAGLifecycle:
         pass
 
     async def after_dag_end(self):
-        """The callback after DAG end"""
+        """The callback after DAG end,
+
+        This method may be called multiple times, please make sure it is idempotent.
+        """
         pass
 
 
@@ -299,6 +301,20 @@ class DAGNode(DAGLifecycle, DependencyMixin, ABC):
         self._downstream.append(node)
         node._upstream.append(self)
 
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        if self.node_id and self.node_name:
+            return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
+        if self.node_id:
+            return f"{cls_name}(node_id={self.node_id})"
+        if self.node_name:
+            return f"{cls_name}(node_name={self.node_name})"
+        else:
+            return f"{cls_name}"
+
+    def __str__(self):
+        return self.__repr__()
+
 
 def _build_task_key(task_name: str, key: str) -> str:
     return f"{task_name}___$$$$$$___{key}"
@@ -496,6 +512,15 @@ class DAG:
             tasks.append(node.after_dag_end())
         await asyncio.gather(*tasks)
 
+    def print_tree(self) -> None:
+        """Print the DAG tree"""
+        _print_format_dag_tree(self)
+
+    def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
+        """Create the DAG graph"""
+        self.print_tree()
+        return _visualize_dag(self, view=view, **kwargs)
+
     def __enter__(self):
         DAGVar.enter_dag(self)
         return self
@@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
     for node in stream_nodes:
         nodes = nodes.union(_get_nodes(node, is_upstream))
     return nodes
+
+
+def _print_format_dag_tree(dag: DAG) -> None:
+    for node in dag.root_nodes:
+        _print_dag(node)
+
+
+def _print_dag(
+    node: DAGNode,
+    level: int = 0,
+    prefix: str = "",
+    last: bool = True,
+    level_dict: Dict[str, Any] = None,
+):
+    if level_dict is None:
+        level_dict = {}
+
+    connector = " -> " if level != 0 else ""
+    new_prefix = prefix
+    if last:
+        if level != 0:
+            new_prefix += " "
+        print(prefix + connector + str(node))
+    else:
+        if level != 0:
+            new_prefix += "| "
+        print(prefix + connector + str(node))
+
+    level_dict[level] = level_dict.get(level, 0) + 1
+    num_children = len(node.downstream)
+    for i, child in enumerate(node.downstream):
+        _print_dag(child, level + 1, new_prefix, i == num_children - 1, level_dict)
+
+
+def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
+    def _print_node(node: DAGNode, level: int) -> None:
+        print(f"{level_sep * level}{node}")
+
+    _apply_root_node(root_nodes, _print_node)
+
+
+def _apply_root_node(
+    root_nodes: List[DAGNode],
+    func: Callable[[DAGNode, int], None],
+) -> None:
+    for dag_node in root_nodes:
+        _handle_dag_nodes(False, 0, dag_node, func)
+
+
+def _handle_dag_nodes(
+    is_down_to_up: bool,
+    level: int,
+    dag_node: DAGNode,
+    func: Callable[[DAGNode, int], None],
+):
+    if not dag_node:
+        return
+    func(dag_node, level)
+    stream_nodes = dag_node.upstream if is_down_to_up else dag_node.downstream
+    level += 1
+    for node in stream_nodes:
+        _handle_dag_nodes(is_down_to_up, level, node, func)
+
+
+def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
+    """Visualize the DAG
+
+    Args:
+        dag (DAG): The DAG to visualize
+        view (bool, optional): Whether view the DAG graph. Defaults to True.
+
+    Returns:
+        Optional[str]: The filename of the DAG graph
+    """
+    try:
+        from graphviz import Digraph
+    except ImportError:
+        logger.warn("Can't import graphviz, skip visualize DAG")
+        return None
+
+    dot = Digraph(name=dag.dag_id)
+    # Record the added edges to avoid adding duplicate edges
+    added_edges = set()
+
+    def add_edges(node: DAGNode):
+        if node.downstream:
+            for downstream_node in node.downstream:
+                # Check if the edge has been added
+                if (str(node), str(downstream_node)) not in added_edges:
+                    dot.edge(str(node), str(downstream_node))
+                    added_edges.add((str(node), str(downstream_node)))
+                add_edges(downstream_node)
+
+    for root in dag.root_nodes:
+        add_edges(root)
+    filename = f"dag-vis-{dag.dag_id}.gv"
+    if "filename" in kwargs:
+        filename = kwargs["filename"]
+        del kwargs["filename"]
+
+    if not "directory" in kwargs:
+        from dbgpt.configs.model_config import LOGDIR
+
+        kwargs["directory"] = LOGDIR
+
+    return dot.render(filename, view=view, **kwargs)
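Example (not part of the diff): a minimal sketch of the new tree-printing and visualization helpers; operator names are illustrative and graphviz must be installed for visualize_dag to return a file.

from dbgpt.core.awel import DAG, MapOperator

with DAG("vis_example_dag") as dag:
    first_task = MapOperator(lambda x: x + 1)
    second_task = MapOperator(lambda x: x * 2)
    first_task >> second_task

dag.print_tree()  # Prints an ASCII tree of the DAG to stdout
# Writes dag-vis-<dag_id>.gv under LOGDIR and returns the rendered file path,
# or None when graphviz is not available.
graph_file = dag.visualize_dag(view=False)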
@@ -46,6 +46,7 @@ class WorkflowRunner(ABC, Generic[T]):
         node: "BaseOperator",
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
+        dag_ctx: Optional[DAGContext] = None,
     ) -> DAGContext:
         """Execute the workflow starting from a given operator.

@@ -53,7 +54,7 @@ class WorkflowRunner(ABC, Generic[T]):
             node (RunnableDAGNode): The starting node of the workflow to be executed.
             call_data (CALL_DATA): The data pass to root operator node.
             streaming_call (bool): Whether the call is a streaming call.
-
+            dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
         Returns:
             DAGContext: The context after executing the workflow, containing the final state and data.
         """

@@ -174,18 +175,22 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
             TaskOutput[OUT]: The task output after this node has been run.
         """
 
-    async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT:
+    async def call(
+        self,
+        call_data: Optional[CALL_DATA] = None,
+        dag_ctx: Optional[DAGContext] = None,
+    ) -> OUT:
         """Execute the node and return the output.
 
         This method is a high-level wrapper for executing the node.
 
         Args:
             call_data (CALL_DATA): The data pass to root operator node.
-
+            dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
         Returns:
             OUT: The output of the node after execution.
         """
-        out_ctx = await self._runner.execute_workflow(self, call_data)
+        out_ctx = await self._runner.execute_workflow(self, call_data, dag_ctx=dag_ctx)
         return out_ctx.current_task_context.task_output.output
 
     def _blocking_call(

@@ -209,7 +214,9 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         return loop.run_until_complete(self.call(call_data))
 
     async def call_stream(
-        self, call_data: Optional[CALL_DATA] = None
+        self,
+        call_data: Optional[CALL_DATA] = None,
+        dag_ctx: Optional[DAGContext] = None,
     ) -> AsyncIterator[OUT]:
         """Execute the node and return the output as a stream.

@@ -217,12 +224,13 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
 
         Args:
             call_data (CALL_DATA): The data pass to root operator node.
+            dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None.
 
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
         out_ctx = await self._runner.execute_workflow(
-            self, call_data, streaming_call=True
+            self, call_data, streaming_call=True, dag_ctx=dag_ctx
         )
         return out_ctx.current_task_context.task_output.output_stream
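Example (not part of the diff): a minimal sketch of calling an operator directly, assuming the default runner and the {"data": ...} call-data convention used by the composer operator later in this change. The new dag_ctx argument is only needed when a sub-DAG should reuse a parent DAG's context.

import asyncio

from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

with DAG("call_example_dag") as dag:
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    end_task = input_task >> MapOperator(lambda x: x * 2)

# Run the whole workflow that ends at `end_task`
result = asyncio.run(end_task.call(call_data={"data": 21}))
print(result)  # 42
# Inside another operator, one would pass dag_ctx=self.current_dag_context instead.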
@@ -19,17 +19,21 @@ class DefaultWorkflowRunner(WorkflowRunner):
         node: BaseOperator,
         call_data: Optional[CALL_DATA] = None,
         streaming_call: bool = False,
+        dag_ctx: Optional[DAGContext] = None,
     ) -> DAGContext:
         # Save node output
         # dag = node.dag
-        node_outputs: Dict[str, TaskContext] = {}
         job_manager = JobManager.build_from_end_node(node, call_data)
-        # Create DAG context
-        dag_ctx = DAGContext(
-            streaming_call=streaming_call,
-            node_to_outputs=node_outputs,
-            node_name_to_ids=job_manager._node_name_to_ids,
-        )
+        if not dag_ctx:
+            # Create DAG context
+            node_outputs: Dict[str, TaskContext] = {}
+            dag_ctx = DAGContext(
+                streaming_call=streaming_call,
+                node_to_outputs=node_outputs,
+                node_name_to_ids=job_manager._node_name_to_ids,
+            )
+        else:
+            node_outputs = dag_ctx._node_to_outputs
         logger.info(
             f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
         )
@@ -1,14 +1,21 @@
+import collections
 import copy
+import logging
 import time
 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass, field
 from typing import Any, AsyncIterator, Dict, List, Optional, Union
 
+from cachetools import TTLCache
+
+from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
 from dbgpt.util import BaseParameters
 from dbgpt.util.annotations import PublicAPI
 from dbgpt.util.model_utils import GPUInfo
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 @PublicAPI(stability="beta")
@@ -223,6 +230,29 @@ class ModelRequest:
             raise ValueError("The messages is not a single user message")
         return messages[0]
 
+    @staticmethod
+    def build_request(
+        model: str,
+        messages: List[ModelMessage],
+        context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
+        stream: Optional[bool] = False,
+        **kwargs,
+    ):
+        context_dict = None
+        if isinstance(context, dict):
+            context_dict = context
+        elif isinstance(context, BaseModel):
+            context_dict = context.dict()
+        if context_dict and "stream" not in context_dict:
+            context_dict["stream"] = stream
+            context = ModelRequestContext(**context_dict)
+        return ModelRequest(
+            model=model,
+            messages=messages,
+            context=context,
+            **kwargs,
+        )
+
     @staticmethod
     def _build(model: str, prompt: str, **kwargs):
         return ModelRequest(
@@ -271,6 +301,43 @@ class ModelRequest:
         return ModelMessage.to_openai_messages(messages)
 
 
+@dataclass
+class ModelExtraMedata(BaseParameters):
+    """A class to represent the extra metadata of a LLM."""
+
+    prompt_roles: Optional[List[str]] = field(
+        default_factory=lambda: [
+            ModelMessageRoleType.SYSTEM,
+            ModelMessageRoleType.HUMAN,
+            ModelMessageRoleType.AI,
+        ],
+        metadata={"help": "The roles of the prompt"},
+    )
+
+    prompt_sep: Optional[str] = field(
+        default="\n",
+        metadata={"help": "The separator of the prompt between multiple rounds"},
+    )
+
+    # You can see the chat template in your model repo tokenizer config,
+    # typically in the tokenizer_config.json
+    prompt_chat_template: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
+        },
+    )
+
+    @property
+    def support_system_message(self) -> bool:
+        """Whether the model supports system message.
+
+        Returns:
+            bool: Whether the model supports system message.
+        """
+        return ModelMessageRoleType.SYSTEM in self.prompt_roles
+
+
 @dataclass
 @PublicAPI(stability="beta")
 class ModelMetadata(BaseParameters):
@@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters):
         default_factory=dict,
         metadata={"help": "Model metadata"},
     )
+    ext_metadata: Optional[ModelExtraMedata] = field(
+        default_factory=ModelExtraMedata,
+        metadata={"help": "Model extra metadata"},
+    )
+
+    @classmethod
+    def from_dict(
+        cls, data: dict, ignore_extra_fields: bool = False
+    ) -> "ModelMetadata":
+        if "ext_metadata" in data:
+            data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
+        return cls(**data)
+
+
+class MessageConverter(ABC):
+    """An abstract class for message converter.
+
+    Different LLMs may have different message formats, this class is used to convert the messages
+    to the format of the LLM.
+
+    Examples:
+
+        >>> from typing import List
+        >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+        >>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
+        >>> class RemoveSystemMessageConverter(MessageConverter):
+        ...     def convert(
+        ...         self,
+        ...         messages: List[ModelMessage],
+        ...         model_metadata: Optional[ModelMetadata] = None,
+        ...     ) -> List[ModelMessage]:
+        ...         # Convert the messages, merge system messages to the last user message.
+        ...         system_message = None
+        ...         other_messages = []
+        ...         sep = "\\n"
+        ...         for message in messages:
+        ...             if message.role == ModelMessageRoleType.SYSTEM:
+        ...                 system_message = message
+        ...             else:
+        ...                 other_messages.append(message)
+        ...         if system_message and other_messages:
+        ...             other_messages[-1].content = (
+        ...                 system_message.content + sep + other_messages[-1].content
+        ...             )
+        ...         return other_messages
+        ...
+        >>> messages = [
+        ...     ModelMessage(
+        ...         role=ModelMessageRoleType.SYSTEM,
+        ...         content="You are a helpful assistant",
+        ...     ),
+        ...     ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
+        ... ]
+        >>> converter = RemoveSystemMessageConverter()
+        >>> converted_messages = converter.convert(messages, None)
+        >>> assert converted_messages == [
+        ...     ModelMessage(
+        ...         role=ModelMessageRoleType.HUMAN,
+        ...         content="You are a helpful assistant\\nWho are you",
+        ...     ),
+        ... ]
+    """
+
+    @abstractmethod
+    def convert(
+        self,
+        messages: List[ModelMessage],
+        model_metadata: Optional[ModelMetadata] = None,
+    ) -> List[ModelMessage]:
+        """Convert the messages.
+
+        Args:
+            messages(List[ModelMessage]): The messages.
+            model_metadata(ModelMetadata): The model metadata.
+
+        Returns:
+            List[ModelMessage]: The converted messages.
+        """
+
+
+class DefaultMessageConverter(MessageConverter):
+    """The default message converter."""
+
+    def __init__(self, prompt_sep: Optional[str] = None):
+        self._prompt_sep = prompt_sep
+
+    def convert(
+        self,
+        messages: List[ModelMessage],
+        model_metadata: Optional[ModelMetadata] = None,
+    ) -> List[ModelMessage]:
+        """Convert the messages.
+
+        There are three steps to convert the messages:
+
+        1. Just keep system, human and AI messages
+
+        2. Move the last user's message to the end of the list
+
+        3. Convert the messages to no system message if the model does not support system message
+
+        Args:
+            messages(List[ModelMessage]): The messages.
+            model_metadata(ModelMetadata): The model metadata.
+
+        Returns:
+            List[ModelMessage]: The converted messages.
+        """
+        # 1. Just keep system, human and AI messages
+        messages = list(filter(lambda m: m.pass_to_model, messages))
+        # 2. Move the last user's message to the end of the list
+        messages = self.move_last_user_message_to_end(messages)
+
+        if not model_metadata or not model_metadata.ext_metadata:
+            logger.warning("No model metadata, skip message system message conversion")
+            return messages
+        if model_metadata.ext_metadata.support_system_message:
+            # 3. Convert the messages to no system message
+            return self.convert_to_no_system_message(messages, model_metadata)
+        return messages
+
+    def convert_to_no_system_message(
+        self,
+        messages: List[ModelMessage],
+        model_metadata: Optional[ModelMetadata] = None,
+    ) -> List[ModelMessage]:
+        """Convert the messages to no system message.
+
+        Examples:
+            >>> # Convert the messages to no system message, just merge system messages to the last user message
+            >>> from typing import List
+            >>> from dbgpt.core.interface.message import (
+            ...     ModelMessage,
+            ...     ModelMessageRoleType,
+            ... )
+            >>> from dbgpt.core.interface.llm import (
+            ...     DefaultMessageConverter,
+            ...     ModelMetadata,
+            ... )
+            >>> messages = [
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.SYSTEM,
+            ...         content="You are a helpful assistant",
+            ...     ),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.HUMAN, content="Who are you"
+            ...     ),
+            ... ]
+            >>> converter = DefaultMessageConverter()
+            >>> model_metadata = ModelMetadata(model="test")
+            >>> converted_messages = converter.convert_to_no_system_message(
+            ...     messages, model_metadata
+            ... )
+            >>> assert converted_messages == [
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.HUMAN,
+            ...         content="You are a helpful assistant\\nWho are you",
+            ...     ),
+            ... ]
+        """
+        if not model_metadata or not model_metadata.ext_metadata:
+            logger.warning("No model metadata, skip message conversion")
+            return messages
+        ext_metadata = model_metadata.ext_metadata
+        system_messages = []
+        result_messages = []
+        for message in messages:
+            if message.role == ModelMessageRoleType.SYSTEM:
+                # Not support system message, append system message to the last user message
+                system_messages.append(message)
+            elif message.role in [
+                ModelMessageRoleType.HUMAN,
+                ModelMessageRoleType.AI,
+            ]:
+                result_messages.append(message)
+        prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
+        system_message_str = None
+        if len(system_messages) > 1:
+            logger.warning("Your system messages have more than one message")
+            system_message_str = prompt_sep.join([m.content for m in system_messages])
+        elif len(system_messages) == 1:
+            system_message_str = system_messages[0].content
+
+        if system_message_str and result_messages:
+            # Not support system messages, merge system messages to the last user message
+            result_messages[-1].content = (
+                system_message_str + prompt_sep + result_messages[-1].content
+            )
+        return result_messages
+
+    def move_last_user_message_to_end(
+        self, messages: List[ModelMessage]
+    ) -> List[ModelMessage]:
+        """Move the last user message to the end of the list.
+
+        Examples:
+
+            >>> from typing import List
+            >>> from dbgpt.core.interface.message import (
+            ...     ModelMessage,
+            ...     ModelMessageRoleType,
+            ... )
+            >>> from dbgpt.core.interface.llm import DefaultMessageConverter
+            >>> messages = [
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.SYSTEM,
+            ...         content="You are a helpful assistant",
+            ...     ),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.HUMAN, content="Who are you"
+            ...     ),
+            ...     ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.HUMAN, content="What's your name"
+            ...     ),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.SYSTEM,
+            ...         content="You are a helpful assistant",
+            ...     ),
+            ... ]
+            >>> converter = DefaultMessageConverter()
+            >>> converted_messages = converter.move_last_user_message_to_end(messages)
+            >>> assert converted_messages == [
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.SYSTEM,
+            ...         content="You are a helpful assistant",
+            ...     ),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.HUMAN, content="Who are you"
+            ...     ),
+            ...     ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.SYSTEM,
+            ...         content="You are a helpful assistant",
+            ...     ),
+            ...     ModelMessage(
+            ...         role=ModelMessageRoleType.HUMAN, content="What's your name"
+            ...     ),
+            ... ]
+
+        Args:
+            messages(List[ModelMessage]): The messages.
+
+        Returns:
+            List[ModelMessage]: The converted messages.
+        """
+        last_user_input_index = None
+        for i in range(len(messages) - 1, -1, -1):
+            if messages[i].role == ModelMessageRoleType.HUMAN:
+                last_user_input_index = i
+                break
+        if last_user_input_index is not None:
+            last_user_input = messages.pop(last_user_input_index)
+            messages.append(last_user_input)
+        return messages
+
+
 @PublicAPI(stability="beta")
 class LLMClient(ABC):
     """An abstract class for LLM client."""
 
+    # Cache the model metadata for 60 seconds
+    _MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)
+
+    @property
+    def cache(self) -> collections.abc.MutableMapping:
+        """The cache object to cache the model metadata.
+
+        You can override this property to use your own cache object.
+        Returns:
+            collections.abc.MutableMapping: The cache object.
+        """
+        return self._MODEL_CACHE_
+
     @abstractmethod
-    async def generate(self, request: ModelRequest) -> ModelOutput:
+    async def generate(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> ModelOutput:
         """Generate a response for a given model request.
 
+        Sometimes, different LLMs may have different message formats,
+        you can use the message converter to convert the messages to the format of the LLM.
+
         Args:
             request(ModelRequest): The model request.
+            message_converter(MessageConverter): The message converter.
+
         Returns:
             ModelOutput: The model output.
@@ -315,12 +658,18 @@ class LLMClient(ABC):
 
     @abstractmethod
     async def generate_stream(
-        self, request: ModelRequest
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
     ) -> AsyncIterator[ModelOutput]:
         """Generate a stream of responses for a given model request.
 
+        Sometimes, different LLMs may have different message formats,
+        you can use the message converter to convert the messages to the format of the LLM.
+
         Args:
             request(ModelRequest): The model request.
+            message_converter(MessageConverter): The message converter.
 
         Returns:
             AsyncIterator[ModelOutput]: The model output stream.
@@ -345,3 +694,65 @@ class LLMClient(ABC):
         Returns:
             int: The number of tokens.
         """
+
+    async def covert_message(
+        self,
+        request: ModelRequest,
+        message_converter: Optional[MessageConverter] = None,
+    ) -> ModelRequest:
+        """Covert the message.
+        If no message converter is provided, the original request will be returned.
+
+        Args:
+            request(ModelRequest): The model request.
+            message_converter(MessageConverter): The message converter.
+
+        Returns:
+            ModelRequest: The converted model request.
+        """
+        if not message_converter:
+            return request
+        new_request = request.copy()
+        model_metadata = await self.get_model_metadata(request.model)
+        new_messages = message_converter.convert(request.messages, model_metadata)
+        new_request.messages = new_messages
+        return new_request
+
+    async def cached_models(self) -> List[ModelMetadata]:
+        """Get all the models from the cache or the llm server.
+
+        If the model metadata is not in the cache, it will be fetched from the llm server.
+
+        Returns:
+            List[ModelMetadata]: A list of model metadata.
+        """
+        key = "____$llm_client_models$____"
+        if key not in self.cache:
+            models = await self.models()
+            self.cache[key] = models
+            for model in models:
+                model_metadata_key = (
+                    f"____$llm_client_models_metadata_{model.model}$____"
+                )
+                self.cache[model_metadata_key] = model
+        return self.cache[key]
+
+    async def get_model_metadata(self, model: str) -> ModelMetadata:
+        """Get the model metadata.
+
+        Args:
+            model(str): The model name.
+
+        Returns:
+            ModelMetadata: The model metadata.
+
+        Raises:
+            ValueError: If the model is not found.
+        """
+        model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
+        if model_metadata_key not in self.cache:
+            await self.cached_models()
+        model_metadata = self.cache.get(model_metadata_key)
+        if not model_metadata:
+            raise ValueError(f"Model {model} not found")
+        return model_metadata
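Example (not part of the diff): a minimal sketch of the new request/conversion flow. The model name is a placeholder and llm_client stands for any concrete LLMClient implementation.

from dbgpt.core import DefaultMessageConverter, ModelMessage, ModelRequest

messages = [ModelMessage.build_human_message("Who are you?")]
# Build a request from plain messages plus a context dict; "stream" is filled in if missing
request = ModelRequest.build_request(
    model="vicuna-13b-v1.5",  # placeholder model name
    messages=messages,
    context={"stream": False},
)

# covert_message() looks up the model's metadata and applies the converter,
# e.g. merging system messages for models without system-message support:
# new_request = await llm_client.covert_message(request, DefaultMessageConverter())
# output = await llm_client.generate(new_request)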
@@ -5,7 +5,6 @@ from datetime import datetime
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from dbgpt._private.pydantic import BaseModel, Field
-from dbgpt.core.awel import MapOperator
 from dbgpt.core.interface.storage import (
     InMemoryStorage,
     ResourceIdentifier,
@@ -114,6 +113,50 @@ class ModelMessage(BaseModel):
     content: str
     round_index: Optional[int] = 0
 
+    @property
+    def pass_to_model(self) -> bool:
+        """Whether the message will be passed to the model
+
+        The view message will not be passed to the model
+
+        Returns:
+            bool: Whether the message will be passed to the model
+        """
+        return self.role in [
+            ModelMessageRoleType.SYSTEM,
+            ModelMessageRoleType.HUMAN,
+            ModelMessageRoleType.AI,
+        ]
+
+    @staticmethod
+    def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
+        result = []
+        for message in messages:
+            content, round_index = message.content, message.round_index
+            if isinstance(message, HumanMessage):
+                result.append(
+                    ModelMessage(
+                        role=ModelMessageRoleType.HUMAN,
+                        content=content,
+                        round_index=round_index,
+                    )
+                )
+            elif isinstance(message, AIMessage):
+                result.append(
+                    ModelMessage(
+                        role=ModelMessageRoleType.AI,
+                        content=content,
+                        round_index=round_index,
+                    )
+                )
+            elif isinstance(message, SystemMessage):
+                result.append(
+                    ModelMessage(
+                        role=ModelMessageRoleType.SYSTEM, content=message.content
+                    )
+                )
+        return result
+
     @staticmethod
     def from_openai_messages(
         messages: Union[str, List[Dict[str, str]]]
@@ -142,9 +185,15 @@ class ModelMessage(BaseModel):
         return result
 
     @staticmethod
-    def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
+    def to_openai_messages(
+        messages: List["ModelMessage"], convert_to_compatible_format: bool = False
+    ) -> List[Dict[str, str]]:
         """Convert to OpenAI message format and
         hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
+
+        Args:
+            messages (List["ModelMessage"]): The model messages
+            convert_to_compatible_format (bool): Whether to convert to compatible format
         """
         history = []
         # Add history conversation

@@ -157,15 +206,16 @@ class ModelMessage(BaseModel):
                 history.append({"role": "assistant", "content": message.content})
             else:
                 pass
-        # Move the last user's information to the end
-        last_user_input_index = None
-        for i in range(len(history) - 1, -1, -1):
-            if history[i]["role"] == "user":
-                last_user_input_index = i
-                break
-        if last_user_input_index:
-            last_user_input = history.pop(last_user_input_index)
-            history.append(last_user_input)
+        if convert_to_compatible_format:
+            # Move the last user's information to the end
+            last_user_input_index = None
+            for i in range(len(history) - 1, -1, -1):
+                if history[i]["role"] == "user":
+                    last_user_input_index = i
+                    break
+            if last_user_input_index:
+                last_user_input = history.pop(last_user_input_index)
+                history.append(last_user_input)
         return history
 
     @staticmethod
@@ -189,8 +239,8 @@ class ModelMessage(BaseModel):
         return str_msg
 
 
-_SingleRoundMessage = List[ModelMessage]
-_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
+_SingleRoundMessage = List[BaseMessage]
+_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
 
 
 def _message_to_dict(message: BaseMessage) -> Dict:
@@ -338,7 +388,8 @@ class OnceConversation:
         """Start a new round of conversation
 
         Example:
-            >>> conversation = OnceConversation()
+
+            >>> conversation = OnceConversation("chat_normal")
             >>> # The chat order will be 0, then we start a new round of conversation
             >>> assert conversation.chat_order == 0
             >>> conversation.start_new_round()
@@ -585,6 +636,28 @@ class OnceConversation:
         )
         return messages
 
+    def get_history_message(
+        self, include_system_message: bool = False
+    ) -> List[BaseMessage]:
+        """Get the history message
+
+        Not include the system messages.
+
+        Args:
+            include_system_message (bool): Whether to include the system message
+
+        Returns:
+            List[BaseMessage]: The history messages
+        """
+        messages = []
+        for message in self.messages:
+            if message.pass_to_model:
+                if include_system_message:
+                    messages.append(message)
+                elif message.type != "system":
+                    messages.append(message)
+        return messages
+
 
 class ConversationIdentifier(ResourceIdentifier):
     """Conversation identifier"""
dbgpt/core/interface/operator/composer_operator.py (new file) | 114

@@ -0,0 +1,114 @@
+import dataclasses
+from typing import Any, Dict, List, Optional
+
+from dbgpt.core import (
+    ChatPromptTemplate,
+    MessageStorageItem,
+    ModelMessage,
+    ModelRequest,
+    StorageConversation,
+    StorageInterface,
+)
+from dbgpt.core.awel import (
+    DAG,
+    BaseOperator,
+    InputOperator,
+    JoinOperator,
+    MapOperator,
+    SimpleCallDataInputSource,
+)
+from dbgpt.core.interface.operator.prompt_operator import HistoryPromptBuilderOperator
+
+from .message_operator import (
+    BufferedConversationMapperOperator,
+    ChatHistoryLoadType,
+    PreChatHistoryLoadOperator,
+)
+
+
+@dataclasses.dataclass
+class ChatComposerInput:
+    """The composer input."""
+
+    prompt_dict: Dict[str, Any]
+    model_dict: Dict[str, Any]
+    context: ChatHistoryLoadType
+
+
+class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
+    """The chat history prompt composer operator.
+
+    For simple use, you can use this operator to compose the chat history prompt.
+    """
+
+    def __init__(
+        self,
+        prompt_template: ChatPromptTemplate,
+        history_key: str = "chat_history",
+        last_k_round: int = 2,
+        storage: Optional[StorageInterface[StorageConversation, Any]] = None,
+        message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._prompt_template = prompt_template
+        self._history_key = history_key
+        self._last_k_round = last_k_round
+        self._storage = storage
+        self._message_storage = message_storage
+        self._sub_compose_dag = self._build_composer_dag()
+
+    async def map(self, input_value: ChatComposerInput) -> ModelRequest:
+        end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
+        # Sub dag, use the same dag context in the parent dag
+        return await end_node.call(
+            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+        )
+
+    def _build_composer_dag(self) -> DAG:
+        with DAG("dbgpt_awel_chat_history_prompt_composer") as composer_dag:
+            input_task = InputOperator(input_source=SimpleCallDataInputSource())
+            # Load and store chat history, default use InMemoryStorage.
+            chat_history_load_task = PreChatHistoryLoadOperator(
+                storage=self._storage, message_storage=self._message_storage
+            )
+            # History transform task, here we keep last 5 round messages
+            history_transform_task = BufferedConversationMapperOperator(
+                last_k_round=self._last_k_round
+            )
+            history_prompt_build_task = HistoryPromptBuilderOperator(
+                prompt=self._prompt_template, history_key=self._history_key
+            )
+            model_request_build_task = JoinOperator(self._build_model_request)
+
+            # Build composer dag
+            (
+                input_task
+                >> MapOperator(lambda x: x.context)
+                >> chat_history_load_task
+                >> history_transform_task
+                >> history_prompt_build_task
+            )
+            (
+                input_task
+                >> MapOperator(lambda x: x.prompt_dict)
+                >> history_prompt_build_task
+            )
+
+            history_prompt_build_task >> model_request_build_task
+            (
+                input_task
+                >> MapOperator(lambda x: x.model_dict)
+                >> model_request_build_task
+            )
+
+        return composer_dag
+
+    def _build_model_request(
+        self, messages: List[ModelMessage], model_dict: Dict[str, Any]
+    ) -> ModelRequest:
+        return ModelRequest.build_request(messages=messages, **model_dict)
+
+    async def after_dag_end(self):
+        # Should call after_dag_end() of sub dag
+        await self._sub_compose_dag._after_dag_end()
@@ -1,11 +1,12 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Any, AsyncIterator, Dict, Optional, Union
|
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel
|
from dbgpt._private.pydantic import BaseModel
|
||||||
from dbgpt.core.awel import (
|
from dbgpt.core.awel import (
|
||||||
BranchFunc,
|
BranchFunc,
|
||||||
BranchOperator,
|
BranchOperator,
|
||||||
|
DAGContext,
|
||||||
MapOperator,
|
MapOperator,
|
||||||
StreamifyAbsOperator,
|
StreamifyAbsOperator,
|
||||||
)
|
)
|
||||||
@@ -22,20 +23,30 @@ RequestInput = Union[
|
|||||||
str,
|
str,
|
||||||
Dict[str, Any],
|
Dict[str, Any],
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
ModelMessage,
|
||||||
|
List[ModelMessage],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||||
|
"""Build the model request from the input value."""
|
||||||
|
|
||||||
def __init__(self, model: Optional[str] = None, **kwargs):
|
def __init__(self, model: Optional[str] = None, **kwargs):
|
||||||
self._model = model
|
self._model = model
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
async def map(self, input_value: RequestInput) -> ModelRequest:
|
async def map(self, input_value: RequestInput) -> ModelRequest:
|
||||||
req_dict = {}
|
req_dict = {}
|
||||||
|
if not input_value:
|
||||||
|
raise ValueError("input_value is not set")
|
||||||
if isinstance(input_value, str):
|
if isinstance(input_value, str):
|
||||||
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
|
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
|
||||||
elif isinstance(input_value, dict):
|
elif isinstance(input_value, dict):
|
||||||
req_dict = input_value
|
req_dict = input_value
|
||||||
|
elif isinstance(input_value, ModelMessage):
|
||||||
|
req_dict = {"messages": [input_value]}
|
||||||
|
elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage):
|
||||||
|
req_dict = {"messages": input_value}
|
||||||
elif dataclasses.is_dataclass(input_value):
|
elif dataclasses.is_dataclass(input_value):
|
||||||
req_dict = dataclasses.asdict(input_value)
|
req_dict = dataclasses.asdict(input_value)
|
||||||
elif isinstance(input_value, BaseModel):
|
elif isinstance(input_value, BaseModel):
|
||||||
@@ -76,6 +87,7 @@ class BaseLLM:
|
|||||||
"""The abstract operator for a LLM."""
|
"""The abstract operator for a LLM."""
|
||||||
|
|
||||||
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
|
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
|
||||||
|
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
|
||||||
|
|
||||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||||
self._llm_client = llm_client
|
self._llm_client = llm_client
|
||||||
@@ -87,8 +99,16 @@ class BaseLLM:
|
|||||||
raise ValueError("llm_client is not set")
|
raise ValueError("llm_client is not set")
|
||||||
return self._llm_client
|
return self._llm_client
|
||||||
|
|
||||||
|
async def save_model_output(
|
||||||
|
self, current_dag_context: DAGContext, model_output: ModelOutput
|
||||||
|
) -> None:
|
||||||
|
"""Save the model output to the share data."""
|
||||||
|
await current_dag_context.save_to_share_data(
|
||||||
|
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
|
||||||
|
)
|
||||||
|
|
||||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
|
||||||
|
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||||
"""The operator for a LLM.
|
"""The operator for a LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -105,10 +125,12 @@ class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
|||||||
await self.current_dag_context.save_to_share_data(
|
await self.current_dag_context.save_to_share_data(
|
||||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||||
)
|
)
|
||||||
return await self.llm_client.generate(request)
|
model_output = await self.llm_client.generate(request)
|
||||||
|
await self.save_model_output(self.current_dag_context, model_output)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
|
||||||
class StreamingLLMOperator(
|
class BaseStreamingLLMOperator(
|
||||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||||
):
|
):
|
||||||
"""The streaming operator for a LLM.
|
"""The streaming operator for a LLM.
|
||||||
@@ -127,8 +149,12 @@ class StreamingLLMOperator(
|
|||||||
await self.current_dag_context.save_to_share_data(
|
await self.current_dag_context.save_to_share_data(
|
||||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||||
)
|
)
|
||||||
|
model_output = None
|
||||||
async for output in self.llm_client.generate_stream(request):
|
async for output in self.llm_client.generate_stream(request):
|
||||||
|
model_output = output
|
||||||
yield output
|
yield output
|
||||||
|
if model_output:
|
||||||
|
await self.save_model_output(self.current_dag_context, model_output)
|
||||||
|
|
||||||
|
|
||||||
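An illustrative aside (not from the diff): the streaming operator above re-yields every chunk while remembering the last one, so the final output can be saved once the stream is exhausted. The same pattern in a self-contained form:

import asyncio
from typing import AsyncIterator, Optional


async def fake_generate_stream() -> AsyncIterator[str]:
    # Stand-in for llm_client.generate_stream(request).
    for chunk in ["Hel", "Hello", "Hello, world"]:
        yield chunk


async def stream_and_capture(source: AsyncIterator[str]) -> AsyncIterator[str]:
    last: Optional[str] = None
    async for chunk in source:
        last = chunk
        yield chunk
    if last is not None:
        # The real operator would call save_model_output(...) with this final chunk.
        print(f"final output: {last}")


async def main() -> None:
    async for _chunk in stream_and_capture(fake_generate_stream()):
        pass


asyncio.run(main())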
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||||
|
@@ -1,19 +1,17 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
from typing import Any, AsyncIterator, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from dbgpt.core import (
|
from dbgpt.core import (
|
||||||
MessageStorageItem,
|
MessageStorageItem,
|
||||||
ModelMessage,
|
ModelMessage,
|
||||||
ModelMessageRoleType,
|
ModelMessageRoleType,
|
||||||
ModelOutput,
|
|
||||||
ModelRequest,
|
|
||||||
ModelRequestContext,
|
ModelRequestContext,
|
||||||
StorageConversation,
|
StorageConversation,
|
||||||
StorageInterface,
|
StorageInterface,
|
||||||
)
|
)
|
||||||
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
|
from dbgpt.core.awel import BaseOperator, MapOperator
|
||||||
from dbgpt.core.interface.message import _MultiRoundMessageMapper
|
from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
|
||||||
|
|
||||||
|
|
||||||
class BaseConversationOperator(BaseOperator, ABC):
|
class BaseConversationOperator(BaseOperator, ABC):
|
||||||
@@ -21,32 +19,41 @@ class BaseConversationOperator(BaseOperator, ABC):
|
|||||||
|
|
||||||
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
|
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
|
||||||
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
|
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
|
||||||
|
SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT = "share_data_key_model_request_context"
|
||||||
|
|
||||||
|
_check_storage: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||||
|
check_storage: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
self._check_storage = check_storage
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._storage = storage
|
self._storage = storage
|
||||||
self._message_storage = message_storage
|
self._message_storage = message_storage
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def storage(self) -> StorageInterface[StorageConversation, Any]:
|
def storage(self) -> Optional[StorageInterface[StorageConversation, Any]]:
|
||||||
"""Return the LLM client."""
|
"""Return the LLM client."""
|
||||||
if not self._storage:
|
if not self._storage:
|
||||||
raise ValueError("Storage is not set")
|
if self._check_storage:
|
||||||
|
raise ValueError("Storage is not set")
|
||||||
|
return None
|
||||||
return self._storage
|
return self._storage
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]:
|
def message_storage(self) -> Optional[StorageInterface[MessageStorageItem, Any]]:
|
||||||
"""Return the LLM client."""
|
"""Return the LLM client."""
|
||||||
if not self._message_storage:
|
if not self._message_storage:
|
||||||
raise ValueError("Message storage is not set")
|
if self._check_storage:
|
||||||
|
raise ValueError("Message storage is not set")
|
||||||
|
return None
|
||||||
return self._message_storage
|
return self._message_storage
|
||||||
|
|
||||||
async def get_storage_conversation(self) -> StorageConversation:
|
async def get_storage_conversation(self) -> Optional[StorageConversation]:
|
||||||
"""Get the storage conversation from share data.
|
"""Get the storage conversation from share data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -58,104 +65,11 @@ class BaseConversationOperator(BaseOperator, ABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if not storage_conv:
|
if not storage_conv:
|
||||||
raise ValueError("Storage conversation is not set")
|
if self._check_storage:
|
||||||
|
raise ValueError("Storage conversation is not set")
|
||||||
|
return None
|
||||||
return storage_conv
|
return storage_conv
|
||||||
|
|
||||||
async def get_model_request(self) -> ModelRequest:
|
|
||||||
"""Get the model request from share data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelRequest: The model request.
|
|
||||||
"""
|
|
||||||
model_request: ModelRequest = (
|
|
||||||
await self.current_dag_context.get_from_share_data(
|
|
||||||
self.SHARE_DATA_KEY_MODEL_REQUEST
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not model_request:
|
|
||||||
raise ValueError("Model request is not set")
|
|
||||||
return model_request
|
|
||||||
|
|
||||||
|
|
||||||
class PreConversationOperator(
|
|
||||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
|
||||||
):
|
|
||||||
"""The operator to prepare the storage conversation.
|
|
||||||
|
|
||||||
In DB-GPT, conversation record and the messages in the conversation are stored in the storage,
|
|
||||||
and they can store in different storage(for high performance).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
|
||||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(storage=storage, message_storage=message_storage)
|
|
||||||
MapOperator.__init__(self, **kwargs)
|
|
||||||
|
|
||||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
|
||||||
"""Map the input value to a ModelRequest.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_value (ModelRequest): The input value.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelRequest: The mapped ModelRequest.
|
|
||||||
"""
|
|
||||||
if input_value.context is None:
|
|
||||||
input_value.context = ModelRequestContext()
|
|
||||||
if not input_value.context.conv_uid:
|
|
||||||
input_value.context.conv_uid = str(uuid.uuid4())
|
|
||||||
if not input_value.context.extra:
|
|
||||||
input_value.context.extra = {}
|
|
||||||
|
|
||||||
chat_mode = input_value.context.chat_mode
|
|
||||||
|
|
||||||
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
|
|
||||||
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
|
||||||
StorageConversation,
|
|
||||||
conv_uid=input_value.context.conv_uid,
|
|
||||||
chat_mode=chat_mode,
|
|
||||||
user_name=input_value.context.user_name,
|
|
||||||
sys_code=input_value.context.sys_code,
|
|
||||||
conv_storage=self.storage,
|
|
||||||
message_storage=self.message_storage,
|
|
||||||
)
|
|
||||||
input_messages = input_value.get_messages()
|
|
||||||
await self.save_to_storage(storage_conv, input_messages)
|
|
||||||
# Get all messages from current storage conversation, and overwrite the input value
|
|
||||||
messages: List[ModelMessage] = storage_conv.get_model_messages()
|
|
||||||
input_value.messages = messages
|
|
||||||
|
|
||||||
# Save the storage conversation to share data, for the child operators
|
|
||||||
await self.current_dag_context.save_to_share_data(
|
|
||||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
|
|
||||||
)
|
|
||||||
await self.current_dag_context.save_to_share_data(
|
|
||||||
self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
|
|
||||||
)
|
|
||||||
return input_value
|
|
||||||
|
|
||||||
async def save_to_storage(
|
|
||||||
self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
|
|
||||||
) -> None:
|
|
||||||
"""Save the messages to storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_conv (StorageConversation): The storage conversation.
|
|
||||||
input_messages (List[ModelMessage]): The input messages.
|
|
||||||
"""
|
|
||||||
# check first
|
|
||||||
self.check_messages(input_messages)
|
|
||||||
storage_conv.start_new_round()
|
|
||||||
for message in input_messages:
|
|
||||||
if message.role == ModelMessageRoleType.HUMAN:
|
|
||||||
storage_conv.add_user_message(message.content)
|
|
||||||
else:
|
|
||||||
storage_conv.add_system_message(message.content)
|
|
||||||
|
|
||||||
def check_messages(self, messages: List[ModelMessage]) -> None:
|
def check_messages(self, messages: List[ModelMessage]) -> None:
|
||||||
"""Check the messages.
|
"""Check the messages.
|
||||||
|
|
||||||
@@ -174,164 +88,147 @@ class PreConversationOperator(
|
|||||||
]:
|
]:
|
||||||
raise ValueError(f"Message role {message.role} is not supported")
|
raise ValueError(f"Message role {message.role} is not supported")
|
||||||
|
|
||||||
async def after_dag_end(self):
|
|
||||||
"""The callback after DAG end"""
|
ChatHistoryLoadType = Union[ModelRequestContext, Dict[str, Any]]
|
||||||
# Save the storage conversation to storage after the whole DAG finished
|
|
||||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
|
||||||
# TODO dont save if the conversation has some internal error
|
|
||||||
storage_conv.end_current_round()
|
|
||||||
|
|
||||||
|
|
||||||
class PostConversationOperator(
|
class PreChatHistoryLoadOperator(
|
||||||
BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
|
BaseConversationOperator, MapOperator[ChatHistoryLoadType, List[BaseMessage]]
|
||||||
):
|
):
|
||||||
def __init__(self, **kwargs):
|
"""The operator to prepare the storage conversation.
|
||||||
|
|
||||||
|
In DB-GPT, the conversation record and the messages in the conversation are persisted,
|
||||||
|
and they can be stored in different storage backends (for high performance).
|
||||||
|
|
||||||
|
This operator just loads the conversation and messages from storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||||
|
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||||
|
include_system_message: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(storage=storage, message_storage=message_storage)
|
||||||
MapOperator.__init__(self, **kwargs)
|
MapOperator.__init__(self, **kwargs)
|
||||||
|
self._include_system_message = include_system_message
|
||||||
|
|
||||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]:
|
||||||
"""Map the input value to a ModelOutput.
|
"""Map the input value to a ModelRequest.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_value (ModelOutput): The input value.
|
input_value (ChatHistoryLoadType): The input value.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelOutput: The mapped ModelOutput.
|
List[BaseMessage]: The messages stored in the storage.
|
||||||
"""
|
"""
|
||||||
# Get the storage conversation from share data
|
if not input_value:
|
||||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
raise ValueError("Model request context can't be None")
|
||||||
storage_conv.add_ai_message(input_value.text)
|
if isinstance(input_value, dict):
|
||||||
return input_value
|
input_value = ModelRequestContext(**input_value)
|
||||||
|
if not input_value.conv_uid:
|
||||||
|
input_value.conv_uid = str(uuid.uuid4())
|
||||||
|
if not input_value.extra:
|
||||||
|
input_value.extra = {}
|
||||||
|
|
||||||
|
chat_mode = input_value.chat_mode
|
||||||
|
|
||||||
class PostStreamingConversationOperator(
|
# Create a new storage conversation; this loads the conversation from storage, so we must do it asynchronously
|
||||||
BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
|
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
||||||
):
|
StorageConversation,
|
||||||
def __init__(self, **kwargs):
|
conv_uid=input_value.conv_uid,
|
||||||
TransformStreamAbsOperator.__init__(self, **kwargs)
|
chat_mode=chat_mode,
|
||||||
|
user_name=input_value.user_name,
|
||||||
|
sys_code=input_value.sys_code,
|
||||||
|
conv_storage=self.storage,
|
||||||
|
message_storage=self.message_storage,
|
||||||
|
)
|
||||||
|
|
||||||
async def transform_stream(
|
# Save the storage conversation to share data, for the child operators
|
||||||
self, input_value: AsyncIterator[ModelOutput]
|
await self.current_dag_context.save_to_share_data(
|
||||||
) -> ModelOutput:
|
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
|
||||||
"""Transform the input value to a ModelOutput.
|
)
|
||||||
|
await self.current_dag_context.save_to_share_data(
|
||||||
Args:
|
self.SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT, input_value
|
||||||
input_value (ModelOutput): The input value.
|
)
|
||||||
|
# Get history messages from storage
|
||||||
Returns:
|
history_messages: List[BaseMessage] = storage_conv.get_history_message(
|
||||||
ModelOutput: The transformed ModelOutput.
|
include_system_message=self._include_system_message
|
||||||
"""
|
)
|
||||||
full_text = ""
|
return history_messages
|
||||||
async for model_output in input_value:
|
|
||||||
# Now model_output.text if full text, if it is a delta text, we should merge all delta text to a full text
|
|
||||||
full_text = model_output.text
|
|
||||||
yield model_output
|
|
||||||
# Get the storage conversation from share data
|
|
||||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
|
||||||
storage_conv.add_ai_message(full_text)
|
|
||||||
|
|
||||||
|
|
||||||
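A rough usage sketch (not from the diff) of the new PreChatHistoryLoadOperator: it takes a conversation context (a ModelRequestContext or a plain dict) and returns the stored history as BaseMessage objects. It assumes InMemoryStorage can be constructed with defaults and that call wraps its input under "data", following the PromptBuilderOperator docstring convention shown later in this change; the conv_uid and chat_mode values are made up:

import asyncio

from dbgpt.core.awel import DAG
from dbgpt.core.interface.operator.message_operator import PreChatHistoryLoadOperator
from dbgpt.core.interface.storage import InMemoryStorage

with DAG("load_history_dag"):
    history_loader = PreChatHistoryLoadOperator(
        storage=InMemoryStorage(), message_storage=InMemoryStorage()
    )

history = asyncio.run(
    history_loader.call(
        call_data={"data": {"conv_uid": "conv-1", "chat_mode": "chat_normal"}}
    )
)
# history is a List[BaseMessage]; it is empty for a brand-new conversation.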
class ConversationMapperOperator(
|
class ConversationMapperOperator(
|
||||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]]
|
||||||
):
|
):
|
||||||
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
|
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
|
||||||
MapOperator.__init__(self, **kwargs)
|
MapOperator.__init__(self, **kwargs)
|
||||||
self._message_mapper = message_mapper
|
self._message_mapper = message_mapper
|
||||||
|
|
||||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
|
||||||
"""Map the input value to a ModelRequest.
|
return self.map_messages(input_value)
|
||||||
|
|
||||||
Args:
|
def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||||
input_value (ModelRequest): The input value.
|
messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelRequest: The mapped ModelRequest.
|
|
||||||
"""
|
|
||||||
input_value = input_value.copy()
|
|
||||||
messages: List[ModelMessage] = self.map_messages(input_value.messages)
|
|
||||||
# Overwrite the input value
|
|
||||||
input_value.messages = messages
|
|
||||||
return input_value
|
|
||||||
|
|
||||||
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
|
||||||
"""Map the input messages to a list of ModelMessage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages (List[ModelMessage]): The input messages.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[ModelMessage]: The mapped ModelMessage.
|
|
||||||
"""
|
|
||||||
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
|
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
message_mapper = self._message_mapper or self.map_multi_round_messages
|
message_mapper = self._message_mapper or self.map_multi_round_messages
|
||||||
return message_mapper(messages_by_round)
|
return message_mapper(messages_by_round)
|
||||||
|
|
||||||
def map_multi_round_messages(
|
def map_multi_round_messages(
|
||||||
self, messages_by_round: List[List[ModelMessage]]
|
self, messages_by_round: List[List[BaseMessage]]
|
||||||
) -> List[ModelMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""Map multi round messages to a list of ModelMessage
|
"""Map multi round messages to a list of BaseMessage.
|
||||||
|
|
||||||
By default, just merge all multi round messages to a list of ModelMessage according to the original order.
|
By default, just merge all multi round messages to a list of BaseMessage according to the original order.
|
||||||
And you can override this method to implement your own logic.
|
And you can override this method to implement your own logic.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
Merge multi round messages to a list of ModelMessage according to the original order.
|
Merge multi round messages to a list of BaseMessage according to the original order.
|
||||||
|
|
||||||
.. code-block:: python
|
>>> from dbgpt.core.interface.message import (
|
||||||
|
... AIMessage,
|
||||||
|
... HumanMessage,
|
||||||
|
... SystemMessage,
|
||||||
|
... )
|
||||||
|
>>> messages_by_round = [
|
||||||
|
... [
|
||||||
|
... HumanMessage(content="Hi", round_index=1),
|
||||||
|
... AIMessage(content="Hello!", round_index=1),
|
||||||
|
... ],
|
||||||
|
... [
|
||||||
|
... HumanMessage(content="What's the error?", round_index=2),
|
||||||
|
... AIMessage(content="Just a joke.", round_index=2),
|
||||||
|
... ],
|
||||||
|
... ]
|
||||||
|
>>> operator = ConversationMapperOperator()
|
||||||
|
>>> messages = operator.map_multi_round_messages(messages_by_round)
|
||||||
|
>>> assert messages == [
|
||||||
|
... HumanMessage(content="Hi", round_index=1),
|
||||||
|
... AIMessage(content="Hello!", round_index=1),
|
||||||
|
... HumanMessage(content="What's the error?", round_index=2),
|
||||||
|
... AIMessage(content="Just a joke.", round_index=2),
|
||||||
|
... ]
|
||||||
|
|
||||||
import asyncio
|
Map multi round messages to a list of BaseMessage, keeping only the last round.
|
||||||
from dbgpt.core.operator import ConversationMapperOperator
|
|
||||||
|
|
||||||
messages_by_round = [
|
>>> class MyMapper(ConversationMapperOperator):
|
||||||
[
|
... def __init__(self, **kwargs):
|
||||||
ModelMessage(role="human", content="Hi", round_index=1),
|
... super().__init__(**kwargs)
|
||||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
...
|
||||||
],
|
... def map_multi_round_messages(
|
||||||
[
|
... self, messages_by_round: List[List[BaseMessage]]
|
||||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
... ) -> List[BaseMessage]:
|
||||||
ModelMessage(
|
... return messages_by_round[-1]
|
||||||
role="human", content="What's the error?", round_index=2
|
...
|
||||||
),
|
>>> operator = MyMapper()
|
||||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
>>> messages = operator.map_multi_round_messages(messages_by_round)
|
||||||
],
|
>>> assert messages == [
|
||||||
[
|
... HumanMessage(content="What's the error?", round_index=2),
|
||||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
... AIMessage(content="Just a joke.", round_index=2),
|
||||||
],
|
... ]
|
||||||
]
|
|
||||||
operator = ConversationMapperOperator()
|
|
||||||
messages = operator.map_multi_round_messages(messages_by_round)
|
|
||||||
assert messages == [
|
|
||||||
ModelMessage(role="human", content="Hi", round_index=1),
|
|
||||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
|
||||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
|
||||||
ModelMessage(
|
|
||||||
role="human", content="What's the error?", round_index=2
|
|
||||||
),
|
|
||||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
|
||||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
|
||||||
]
|
|
||||||
|
|
||||||
Map multi round messages to a list of ModelMessage just keep the last one round.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class MyMapper(ConversationMapperOperator):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def map_multi_round_messages(
|
|
||||||
self, messages_by_round: List[List[ModelMessage]]
|
|
||||||
) -> List[ModelMessage]:
|
|
||||||
return messages_by_round[-1]
|
|
||||||
|
|
||||||
|
|
||||||
operator = MyMapper()
|
|
||||||
messages = operator.map_multi_round_messages(messages_by_round)
|
|
||||||
assert messages == [
|
|
||||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
|
||||||
]
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
"""
|
"""
|
||||||
@@ -340,17 +237,17 @@ class ConversationMapperOperator(
|
|||||||
return sum(messages_by_round, [])
|
return sum(messages_by_round, [])
|
||||||
|
|
||||||
def _split_messages_by_round(
|
def _split_messages_by_round(
|
||||||
self, messages: List[ModelMessage]
|
self, messages: List[BaseMessage]
|
||||||
) -> List[List[ModelMessage]]:
|
) -> List[List[BaseMessage]]:
|
||||||
"""Split the messages by round index.
|
"""Split the messages by round index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages (List[ModelMessage]): The input messages.
|
messages (List[BaseMessage]): The messages.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[ModelMessage]]: The split messages.
|
List[List[BaseMessage]]: The messages split by round.
|
||||||
"""
|
"""
|
||||||
messages_by_round: List[List[ModelMessage]] = []
|
messages_by_round: List[List[BaseMessage]] = []
|
||||||
last_round_index = 0
|
last_round_index = 0
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if not message.round_index:
|
if not message.round_index:
|
||||||
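A hedged sketch (the remainder of the method body is not shown in this hunk) of the round-splitting idea: messages that share a round_index end up in the same sub-list, in their original order:

from typing import List


def split_by_round(round_indexes: List[int]) -> List[List[int]]:
    grouped: List[List[int]] = []
    last_round = 0
    for round_index in round_indexes:
        if round_index != last_round:
            # A new round starts: open a new sub-list.
            grouped.append([])
            last_round = round_index
        grouped[-1].append(round_index)
    return grouped


assert split_by_round([1, 1, 2, 2, 2, 3]) == [[1, 1], [2, 2, 2], [3]]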
@@ -366,7 +263,7 @@ class ConversationMapperOperator(
|
|||||||
class BufferedConversationMapperOperator(ConversationMapperOperator):
|
class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||||
"""The buffered conversation mapper operator.
|
"""The buffered conversation mapper operator.
|
||||||
|
|
||||||
This Operator must be used after the PreConversationOperator,
|
This Operator must be used after the PreChatHistoryLoadOperator,
|
||||||
and it will map the messages in the storage conversation.
|
and it will map the messages in the storage conversation.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@@ -419,8 +316,8 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
|||||||
if message_mapper:
|
if message_mapper:
|
||||||
|
|
||||||
def new_message_mapper(
|
def new_message_mapper(
|
||||||
messages_by_round: List[List[ModelMessage]],
|
messages_by_round: List[List[BaseMessage]],
|
||||||
) -> List[ModelMessage]:
|
) -> List[BaseMessage]:
|
||||||
# Apply keep k round messages first, then apply the custom message mapper
|
# Apply keep k round messages first, then apply the custom message mapper
|
||||||
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
||||||
return message_mapper(messages_by_round)
|
return message_mapper(messages_by_round)
|
||||||
@@ -428,23 +325,23 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
def new_message_mapper(
|
def new_message_mapper(
|
||||||
messages_by_round: List[List[ModelMessage]],
|
messages_by_round: List[List[BaseMessage]],
|
||||||
) -> List[ModelMessage]:
|
) -> List[BaseMessage]:
|
||||||
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
||||||
return sum(messages_by_round, [])
|
return sum(messages_by_round, [])
|
||||||
|
|
||||||
super().__init__(new_message_mapper, **kwargs)
|
super().__init__(new_message_mapper, **kwargs)
|
||||||
|
|
||||||
def _keep_last_round_messages(
|
def _keep_last_round_messages(
|
||||||
self, messages_by_round: List[List[ModelMessage]]
|
self, messages_by_round: List[List[BaseMessage]]
|
||||||
) -> List[List[ModelMessage]]:
|
) -> List[List[BaseMessage]]:
|
||||||
"""Keep the last k round messages.
|
"""Keep the last k round messages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages_by_round (List[List[ModelMessage]]): The messages by round.
|
messages_by_round (List[List[BaseMessage]]): The messages by round.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[ModelMessage]]: The latest round messages.
|
List[List[BaseMessage]]: The latest round messages.
|
||||||
"""
|
"""
|
||||||
index = self._last_k_round + 1
|
index = self._last_k_round + 1
|
||||||
return messages_by_round[-index:]
|
return messages_by_round[-index:]
|
||||||
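A tiny worked example of the slicing above: with last_k_round=1 the operator keeps the previous complete round plus the current, just-started round, which appears to be why the index is last_k_round + 1:

messages_by_round = [
    ["human: Hi", "ai: Hello!"],                       # round 1
    ["human: What's the error?", "ai: Just a joke."],  # round 2
    ["human: Funny!"],                                  # round 3 (current round)
]

last_k_round = 1
index = last_k_round + 1
kept = messages_by_round[-index:]
assert kept == messages_by_round[1:]  # rounds 2 and 3 survive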
|
255
dbgpt/core/interface/operator/prompt_operator.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from dbgpt.core import (
|
||||||
|
BasePromptTemplate,
|
||||||
|
ChatPromptTemplate,
|
||||||
|
ModelMessage,
|
||||||
|
ModelMessageRoleType,
|
||||||
|
ModelOutput,
|
||||||
|
StorageConversation,
|
||||||
|
)
|
||||||
|
from dbgpt.core.awel import JoinOperator, MapOperator
|
||||||
|
from dbgpt.core.interface.message import BaseMessage
|
||||||
|
from dbgpt.core.interface.operator.llm_operator import BaseLLM
|
||||||
|
from dbgpt.core.interface.operator.message_operator import BaseConversationOperator
|
||||||
|
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType
|
||||||
|
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||||
|
|
||||||
|
|
||||||
|
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
|
||||||
|
"""The base prompt builder operator."""
|
||||||
|
|
||||||
|
def __init__(self, check_storage: bool, **kwargs):
|
||||||
|
super().__init__(check_storage=check_storage, **kwargs)
|
||||||
|
|
||||||
|
async def format_prompt(
|
||||||
|
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
|
||||||
|
) -> List[ModelMessage]:
|
||||||
|
"""Format the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (ChatPromptTemplate): The prompt.
|
||||||
|
prompt_dict (Dict[str, Any]): The prompt dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[ModelMessage]: The formatted prompt.
|
||||||
|
"""
|
||||||
|
kwargs = {}
|
||||||
|
kwargs.update(prompt_dict)
|
||||||
|
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
|
||||||
|
messages = prompt.format_messages(**pass_kwargs)
|
||||||
|
messages = ModelMessage.from_base_messages(messages)
|
||||||
|
# Start new round conversation, and save user message to storage
|
||||||
|
await self.start_new_round_conv(messages)
|
||||||
|
return messages
|
||||||
|
|
||||||
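A small self-contained sketch of the kwargs filtering in format_prompt above: only the variables the template actually declares are forwarded, and extras are silently dropped:

from typing import Any, Dict, List


def filter_prompt_kwargs(
    input_variables: List[str], prompt_dict: Dict[str, Any]
) -> Dict[str, Any]:
    return {k: v for k, v in prompt_dict.items() if k in input_variables}


assert filter_prompt_kwargs(["dialect"], {"dialect": "mysql", "unused": 42}) == {
    "dialect": "mysql"
}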
|
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
|
||||||
|
"""Start a new round conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (List[ModelMessage]): The messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_user_message = None
|
||||||
|
for message in messages[::-1]:
|
||||||
|
if message.role == ModelMessageRoleType.HUMAN:
|
||||||
|
last_user_message = message.content
|
||||||
|
break
|
||||||
|
if not last_user_message:
|
||||||
|
raise ValueError("No user message")
|
||||||
|
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||||
|
if not storage_conv:
|
||||||
|
return
|
||||||
|
# Start new round
|
||||||
|
storage_conv.start_new_round()
|
||||||
|
storage_conv.add_user_message(last_user_message)
|
||||||
|
|
||||||
|
async def after_dag_end(self):
|
||||||
|
"""The callback after DAG end"""
|
||||||
|
# TODO remove this to start_new_round()
|
||||||
|
# Save the storage conversation to storage after the whole DAG finished
|
||||||
|
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||||
|
if not storage_conv:
|
||||||
|
return
|
||||||
|
model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
|
||||||
|
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
|
||||||
|
)
|
||||||
|
if model_output:
|
||||||
|
# Save model output message to storage
|
||||||
|
storage_conv.add_ai_message(model_output.text)
|
||||||
|
# End current conversation round and flush to storage
|
||||||
|
storage_conv.end_current_round()
|
||||||
|
|
||||||
|
|
||||||
|
PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str]
|
||||||
|
|
||||||
|
|
||||||
|
class PromptBuilderOperator(
|
||||||
|
BasePromptBuilderOperator, MapOperator[Dict[str, Any], List[ModelMessage]]
|
||||||
|
):
|
||||||
|
"""The operator to build the prompt with static prompt.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dbgpt.core.awel import DAG
|
||||||
|
from dbgpt.core import (
|
||||||
|
ModelMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
HumanPromptTemplate,
|
||||||
|
SystemPromptTemplate,
|
||||||
|
ChatPromptTemplate,
|
||||||
|
)
|
||||||
|
from dbgpt.core.operator import PromptBuilderOperator
|
||||||
|
|
||||||
|
with DAG("prompt_test") as dag:
|
||||||
|
str_prompt = PromptBuilderOperator(
|
||||||
|
"Please write a {dialect} SQL count the length of a field"
|
||||||
|
)
|
||||||
|
tp_prompt = PromptBuilderOperator(
|
||||||
|
HumanPromptTemplate.from_template(
|
||||||
|
"Please write a {dialect} SQL count the length of a field"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chat_prompt = PromptBuilderOperator(
|
||||||
|
ChatPromptTemplate(
|
||||||
|
messages=[
|
||||||
|
HumanPromptTemplate.from_template(
|
||||||
|
"Please write a {dialect} SQL count the length of a field"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with_sys_prompt = PromptBuilderOperator(
|
||||||
|
ChatPromptTemplate(
|
||||||
|
messages=[
|
||||||
|
SystemPromptTemplate.from_template(
|
||||||
|
"You are a {dialect} SQL expert"
|
||||||
|
),
|
||||||
|
HumanPromptTemplate.from_template(
|
||||||
|
"Please write a {dialect} SQL count the length of a field"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
single_input = {"data": {"dialect": "mysql"}}
|
||||||
|
single_expected_messages = [
|
||||||
|
ModelMessage(
|
||||||
|
content="Please write a mysql SQL count the length of a field",
|
||||||
|
role="human",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
with_sys_expected_messages = [
|
||||||
|
ModelMessage(content="You are a mysql SQL expert", role="system"),
|
||||||
|
ModelMessage(
|
||||||
|
content="Please write a mysql SQL count the length of a field",
|
||||||
|
role="human",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
asyncio.run(str_prompt.call(call_data=single_input))
|
||||||
|
== single_expected_messages
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
asyncio.run(tp_prompt.call(call_data=single_input))
|
||||||
|
== single_expected_messages
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
asyncio.run(chat_prompt.call(call_data=single_input))
|
||||||
|
== single_expected_messages
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
asyncio.run(with_sys_prompt.call(call_data=single_input))
|
||||||
|
== with_sys_expected_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prompt: PromptTemplateType, **kwargs):
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = ChatPromptTemplate(
|
||||||
|
messages=[HumanPromptTemplate.from_template(prompt)]
|
||||||
|
)
|
||||||
|
elif isinstance(prompt, BasePromptTemplate) and not isinstance(
|
||||||
|
prompt, ChatPromptTemplate
|
||||||
|
):
|
||||||
|
prompt = ChatPromptTemplate(
|
||||||
|
messages=[HumanPromptTemplate.from_template(prompt.template)]
|
||||||
|
)
|
||||||
|
elif isinstance(prompt, MessageType):
|
||||||
|
prompt = ChatPromptTemplate(messages=[prompt])
|
||||||
|
self._prompt = prompt
|
||||||
|
|
||||||
|
super().__init__(check_storage=False, **kwargs)
|
||||||
|
MapOperator.__init__(self, map_function=self.merge_prompt, **kwargs)
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]:
|
||||||
|
return await self.format_prompt(self._prompt, prompt_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicPromptBuilderOperator(
|
||||||
|
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
|
||||||
|
):
|
||||||
|
"""The operator to build the prompt with dynamic prompt.
|
||||||
|
|
||||||
|
The prompt template is dynamic, and it is created by the parent operator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(check_storage=False, **kwargs)
|
||||||
|
JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs)
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def merge_prompt(
|
||||||
|
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
|
||||||
|
) -> List[ModelMessage]:
|
||||||
|
return await self.format_prompt(prompt, prompt_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryPromptBuilderOperator(
|
||||||
|
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
self._prompt = prompt
|
||||||
|
self._history_key = history_key
|
||||||
|
|
||||||
|
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def merge_history(
|
||||||
|
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
|
||||||
|
) -> List[ModelMessage]:
|
||||||
|
prompt_dict[self._history_key] = history
|
||||||
|
return await self.format_prompt(self._prompt, prompt_dict)
|
||||||
|
|
||||||
|
|
||||||
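A hedged usage sketch (not from the diff) of the history_key convention used by merge_history above: the loaded history is injected into the prompt variables under history_key, so the template should contain a MessagesPlaceholder whose variable_name matches that key (here "chat_history"):

from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.interface.prompt import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,
    SystemPromptTemplate,
)

prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template("You are a {dialect} SQL expert"),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)

history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
messages = prompt.format_messages(
    chat_history=history, dialect="mysql", user_input="Count the rows"
)
# messages: the system message, then the two history messages, then the new human message.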
|
class HistoryDynamicPromptBuilderOperator(
|
||||||
|
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
|
||||||
|
):
|
||||||
|
"""The operator to build the prompt with dynamic prompt.
|
||||||
|
|
||||||
|
The prompt template is dynamic, and it is created by the parent operator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, history_key: Optional[str] = None, **kwargs):
|
||||||
|
self._history_key = history_key
|
||||||
|
|
||||||
|
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def merge_history(
|
||||||
|
self,
|
||||||
|
prompt: ChatPromptTemplate,
|
||||||
|
history: List[BaseMessage],
|
||||||
|
prompt_dict: Dict[str, Any],
|
||||||
|
) -> List[ModelMessage]:
|
||||||
|
prompt_dict[self._history_key] = history
|
||||||
|
return await self.format_prompt(prompt, prompt_dict)
|
@@ -1,11 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from string import Formatter
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel
|
from dbgpt._private.pydantic import BaseModel, root_validator
|
||||||
from dbgpt.core._private.example_base import ExampleSelector
|
from dbgpt.core._private.example_base import ExampleSelector
|
||||||
from dbgpt.core.awel import MapOperator
|
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
|
||||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||||
from dbgpt.core.interface.storage import (
|
from dbgpt.core.interface.storage import (
|
||||||
InMemoryStorage,
|
InMemoryStorage,
|
||||||
@@ -38,15 +41,40 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate(BaseModel, ABC):
|
class BasePromptTemplate(BaseModel):
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
"""A list of the names of the variables the prompt template expects."""
|
"""A list of the names of the variables the prompt template expects."""
|
||||||
|
|
||||||
|
template: Optional[str]
|
||||||
|
"""The prompt template."""
|
||||||
|
|
||||||
|
template_format: Optional[str] = "f-string"
|
||||||
|
|
||||||
|
def format(self, **kwargs: Any) -> str:
|
||||||
|
"""Format the prompt with the inputs."""
|
||||||
|
if self.template:
|
||||||
|
return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
|
||||||
|
self.template, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_template(
|
||||||
|
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
|
||||||
|
) -> BasePromptTemplate:
|
||||||
|
"""Create a prompt template from a template string."""
|
||||||
|
input_variables = get_template_vars(template, template_format)
|
||||||
|
return cls(
|
||||||
|
template=template,
|
||||||
|
input_variables=input_variables,
|
||||||
|
template_format=template_format,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
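A short usage sketch of the new from_template/format API defined above: the input variables are inferred from the template string itself:

from dbgpt.core import BasePromptTemplate

tmpl = BasePromptTemplate.from_template("Hello {name}")
assert tmpl.input_variables == ["name"]
assert tmpl.format(name="World") == "Hello World"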
|
class PromptTemplate(BasePromptTemplate):
|
||||||
template_scene: Optional[str]
|
template_scene: Optional[str]
|
||||||
template_define: Optional[str]
|
template_define: Optional[str]
|
||||||
"""this template define"""
|
"""this template define"""
|
||||||
template: Optional[str]
|
|
||||||
"""The prompt template."""
|
|
||||||
template_format: str = "f-string"
|
|
||||||
"""strict template will check template args"""
|
"""strict template will check template args"""
|
||||||
template_is_strict: bool = True
|
template_is_strict: bool = True
|
||||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||||
@@ -86,12 +114,114 @@ class PromptTemplate(BaseModel, ABC):
|
|||||||
self.template_is_strict
|
self.template_is_strict
|
||||||
)(self.template, **kwargs)
|
)(self.template, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_template(template: str) -> "PromptTemplateOperator":
|
class BaseChatPromptTemplate(BaseModel, ABC):
|
||||||
|
prompt: BasePromptTemplate
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_variables(self) -> List[str]:
|
||||||
|
"""A list of the names of the variables the prompt template expects."""
|
||||||
|
return self.prompt.input_variables
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format the prompt with the inputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_template(
|
||||||
|
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
|
||||||
|
) -> BaseChatPromptTemplate:
|
||||||
"""Create a prompt template from a template string."""
|
"""Create a prompt template from a template string."""
|
||||||
return PromptTemplateOperator(
|
prompt = BasePromptTemplate.from_template(template, template_format)
|
||||||
PromptTemplate(template=template, input_variables=[])
|
return cls(prompt=prompt, **kwargs)
|
||||||
)
|
|
||||||
|
|
||||||
|
class SystemPromptTemplate(BaseChatPromptTemplate):
|
||||||
|
"""The system prompt template."""
|
||||||
|
|
||||||
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
content = self.prompt.format(**kwargs)
|
||||||
|
return [SystemMessage(content=content)]
|
||||||
|
|
||||||
|
|
||||||
|
class HumanPromptTemplate(BaseChatPromptTemplate):
|
||||||
|
"""The human prompt template."""
|
||||||
|
|
||||||
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
content = self.prompt.format(**kwargs)
|
||||||
|
return [HumanMessage(content=content)]
|
||||||
|
|
||||||
|
|
||||||
|
class MessagesPlaceholder(BaseChatPromptTemplate):
|
||||||
|
"""The messages placeholder template.
|
||||||
|
|
||||||
|
Mostly used for the chat history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
variable_name: str
|
||||||
|
prompt: BasePromptTemplate = None
|
||||||
|
|
||||||
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
messages = kwargs.get(self.variable_name, [])
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported messages type: {type(messages)}, should be list."
|
||||||
|
)
|
||||||
|
for message in messages:
|
||||||
|
if not isinstance(message, BaseMessage):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported message type: {type(message)}, should be BaseMessage."
|
||||||
|
)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_variables(self) -> List[str]:
|
||||||
|
"""A list of the names of the variables the prompt template expects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: The input variables.
|
||||||
|
"""
|
||||||
|
return [self.variable_name]
|
||||||
|
|
||||||
|
|
||||||
|
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptTemplate(BasePromptTemplate):
|
||||||
|
messages: List[MessageType]
|
||||||
|
|
||||||
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format the prompt with the inputs."""
|
||||||
|
result_messages = []
|
||||||
|
for message in self.messages:
|
||||||
|
if isinstance(message, BaseMessage):
|
||||||
|
result_messages.append(message)
|
||||||
|
elif isinstance(message, BaseChatPromptTemplate):
|
||||||
|
pass_kwargs = {
|
||||||
|
k: v for k, v in kwargs.items() if k in message.input_variables
|
||||||
|
}
|
||||||
|
result_messages.extend(message.format_messages(**pass_kwargs))
|
||||||
|
elif isinstance(message, MessagesPlaceholder):
|
||||||
|
pass_kwargs = {
|
||||||
|
k: v for k, v in kwargs.items() if k in message.input_variables
|
||||||
|
}
|
||||||
|
result_messages.extend(message.format_messages(**pass_kwargs))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||||
|
return result_messages
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Pre-fill the messages."""
|
||||||
|
input_variables = values.get("input_variables", {})
|
||||||
|
messages = values.get("messages", [])
|
||||||
|
if not input_variables:
|
||||||
|
input_variables = set()
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, BaseChatPromptTemplate):
|
||||||
|
input_variables.update(message.input_variables)
|
||||||
|
values["input_variables"] = sorted(input_variables)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -547,10 +677,36 @@ class PromptManager:
|
|||||||
self.storage.delete(identifier)
|
self.storage.delete(identifier)
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateOperator(MapOperator[Dict, str]):
|
def _get_string_template_vars(template_str: str) -> Set[str]:
|
||||||
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
|
"""Get template variables from a template string."""
|
||||||
super().__init__(**kwargs)
|
variables = set()
|
||||||
self._prompt_template = prompt_template
|
formatter = Formatter()
|
||||||
|
|
||||||
async def map(self, input_value: Dict) -> str:
|
for _, variable_name, _, _ in formatter.parse(template_str):
|
||||||
return self._prompt_template.format(**input_value)
|
if variable_name:
|
||||||
|
variables.add(variable_name)
|
||||||
|
|
||||||
|
return variables
|
||||||
|
|
||||||
|
|
||||||
|
def _get_jinja2_template_vars(template_str: str) -> Set[str]:
|
||||||
|
"""Get template variables from a template string."""
|
||||||
|
from jinja2 import Environment, meta
|
||||||
|
|
||||||
|
env = Environment()
|
||||||
|
ast = env.parse(template_str)
|
||||||
|
variables = meta.find_undeclared_variables(ast)
|
||||||
|
return variables
|
||||||
|
|
||||||
|
|
||||||
|
def get_template_vars(
|
||||||
|
template_str: str, template_format: str = "f-string"
|
||||||
|
) -> List[str]:
|
||||||
|
"""Get template variables from a template string."""
|
||||||
|
if template_format == "f-string":
|
||||||
|
result = _get_string_template_vars(template_str)
|
||||||
|
elif template_format == "jinja2":
|
||||||
|
result = _get_jinja2_template_vars(template_str)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported template format: {template_format}")
|
||||||
|
return sorted(result)
|
||||||
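A minimal demonstration of the f-string branch above, using only the standard library; the jinja2 branch works the same way via jinja2.meta.find_undeclared_variables:

from string import Formatter


def f_string_vars(template_str: str) -> set:
    return {
        field_name
        for _, field_name, _, _ in Formatter().parse(template_str)
        if field_name
    }


assert f_string_vars("Write a {dialect} SQL for table {table}") == {"dialect", "table"}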
|
@@ -413,13 +413,18 @@ def test_to_openai_messages(
|
|||||||
{"role": "user", "content": human_model_message.content},
|
{"role": "user", "content": human_model_message.content},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_openai_messages_convert_to_compatible_format(
|
||||||
|
human_model_message, ai_model_message, system_model_message
|
||||||
|
):
|
||||||
shuffle_messages = ModelMessage.to_openai_messages(
|
shuffle_messages = ModelMessage.to_openai_messages(
|
||||||
[
|
[
|
||||||
system_model_message,
|
system_model_message,
|
||||||
human_model_message,
|
human_model_message,
|
||||||
human_model_message,
|
human_model_message,
|
||||||
ai_model_message,
|
ai_model_message,
|
||||||
]
|
],
|
||||||
|
convert_to_compatible_format=True,
|
||||||
)
|
)
|
||||||
assert shuffle_messages == [
|
assert shuffle_messages == [
|
||||||
{"role": "system", "content": system_model_message.content},
|
{"role": "system", "content": system_model_message.content},
|
||||||
|
@@ -99,12 +99,6 @@ class TestPromptTemplate:
|
|||||||
formatted_output = prompt.format(response="hello")
|
formatted_output = prompt.format(response="hello")
|
||||||
assert "Response: " in formatted_output
|
assert "Response: " in formatted_output
|
||||||
|
|
||||||
def test_from_template(self):
|
|
||||||
template_str = "Hello {name}"
|
|
||||||
prompt = PromptTemplate.from_template(template_str)
|
|
||||||
assert prompt._prompt_template.template == template_str
|
|
||||||
assert prompt._prompt_template.input_variables == []
|
|
||||||
|
|
||||||
def test_format_missing_variable(self):
|
def test_format_missing_variable(self):
|
||||||
template_str = "Hello {name}"
|
template_str = "Hello {name}"
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
|
@@ -1,31 +1,41 @@
|
|||||||
|
from dbgpt.core.interface.operator.composer_operator import (
|
||||||
|
ChatComposerInput,
|
||||||
|
ChatHistoryPromptComposerOperator,
|
||||||
|
)
|
||||||
from dbgpt.core.interface.operator.llm_operator import (
|
from dbgpt.core.interface.operator.llm_operator import (
|
||||||
BaseLLM,
|
BaseLLM,
|
||||||
|
BaseLLMOperator,
|
||||||
|
BaseStreamingLLMOperator,
|
||||||
LLMBranchOperator,
|
LLMBranchOperator,
|
||||||
LLMOperator,
|
RequestBuilderOperator,
|
||||||
RequestBuildOperator,
|
|
||||||
StreamingLLMOperator,
|
|
||||||
)
|
)
|
||||||
from dbgpt.core.interface.operator.message_operator import (
|
from dbgpt.core.interface.operator.message_operator import (
|
||||||
BaseConversationOperator,
|
BaseConversationOperator,
|
||||||
BufferedConversationMapperOperator,
|
BufferedConversationMapperOperator,
|
||||||
ConversationMapperOperator,
|
ConversationMapperOperator,
|
||||||
PostConversationOperator,
|
PreChatHistoryLoadOperator,
|
||||||
PostStreamingConversationOperator,
|
)
|
||||||
PreConversationOperator,
|
from dbgpt.core.interface.operator.prompt_operator import (
|
||||||
|
DynamicPromptBuilderOperator,
|
||||||
|
HistoryDynamicPromptBuilderOperator,
|
||||||
|
HistoryPromptBuilderOperator,
|
||||||
|
PromptBuilderOperator,
|
||||||
)
|
)
|
||||||
from dbgpt.core.interface.prompt import PromptTemplateOperator
|
|
||||||
|
|
||||||
__ALL__ = [
|
__ALL__ = [
|
||||||
"BaseLLM",
|
"BaseLLM",
|
||||||
"LLMBranchOperator",
|
"LLMBranchOperator",
|
||||||
"LLMOperator",
|
"BaseLLMOperator",
|
||||||
"RequestBuildOperator",
|
"RequestBuilderOperator",
|
||||||
"StreamingLLMOperator",
|
"BaseStreamingLLMOperator",
|
||||||
"BaseConversationOperator",
|
"BaseConversationOperator",
|
||||||
"BufferedConversationMapperOperator",
|
"BufferedConversationMapperOperator",
|
||||||
"ConversationMapperOperator",
|
"ConversationMapperOperator",
|
||||||
"PostConversationOperator",
|
"PreChatHistoryLoadOperator",
|
||||||
"PostStreamingConversationOperator",
|
"PromptBuilderOperator",
|
||||||
"PreConversationOperator",
|
"DynamicPromptBuilderOperator",
|
||||||
"PromptTemplateOperator",
|
"HistoryPromptBuilderOperator",
|
||||||
|
"HistoryDynamicPromptBuilderOperator",
|
||||||
|
"ChatComposerInput",
|
||||||
|
"ChatHistoryPromptComposerOperator",
|
||||||
]
|
]
|
||||||
|
@@ -1,13 +1,7 @@
|
|||||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||||
from dbgpt.model.utils.chatgpt_utils import (
|
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
|
||||||
OpenAILLMClient,
|
|
||||||
OpenAIStreamingOperator,
|
|
||||||
MixinLLMOperator,
|
|
||||||
)
|
|
||||||
|
|
||||||
__ALL__ = [
|
__ALL__ = [
|
||||||
"DefaultLLMClient",
|
"DefaultLLMClient",
|
||||||
"OpenAILLMClient",
|
"OpenAILLMClient",
|
||||||
"OpenAIStreamingOperator",
|
|
||||||
"MixinLLMOperator",
|
|
||||||
]
|
]
|
||||||
|
@@ -152,7 +152,7 @@ class LLMModelAdapter(ABC):
|
|||||||
return "\n"
|
return "\n"
|
||||||
|
|
||||||
def transform_model_messages(
|
def transform_model_messages(
|
||||||
self, messages: List[ModelMessage]
|
self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""Transform the model messages
|
"""Transform the model messages
|
||||||
|
|
||||||
@@ -174,15 +174,19 @@ class LLMModelAdapter(ABC):
|
|||||||
]
|
]
|
||||||
Args:
|
Args:
|
||||||
messages (List[ModelMessage]): The model messages
|
messages (List[ModelMessage]): The model messages
|
||||||
|
convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, str]]: The transformed model messages
|
List[Dict[str, str]]: The transformed model messages
|
||||||
"""
|
"""
|
||||||
logger.info(f"support_system_message: {self.support_system_message}")
|
logger.info(f"support_system_message: {self.support_system_message}")
|
||||||
if not self.support_system_message:
|
if not self.support_system_message and convert_to_compatible_format:
|
||||||
|
# This transformation will be removed in the future
|
||||||
return self._transform_to_no_system_messages(messages)
|
return self._transform_to_no_system_messages(messages)
|
||||||
else:
|
else:
|
||||||
return ModelMessage.to_openai_messages(messages)
|
return ModelMessage.to_openai_messages(
|
||||||
|
messages, convert_to_compatible_format=convert_to_compatible_format
|
||||||
|
)
|
||||||
|
|
||||||
def _transform_to_no_system_messages(
|
def _transform_to_no_system_messages(
|
||||||
self, messages: List[ModelMessage]
|
self, messages: List[ModelMessage]
|
||||||
@@ -237,6 +241,7 @@ class LLMModelAdapter(ABC):
|
|||||||
messages: List[ModelMessage],
|
messages: List[ModelMessage],
|
||||||
tokenizer: Any,
|
tokenizer: Any,
|
||||||
prompt_template: str = None,
|
prompt_template: str = None,
|
||||||
|
convert_to_compatible_format: bool = False,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Get the string prompt from the given parameters and messages
|
"""Get the string prompt from the given parameters and messages
|
||||||
|
|
||||||
@@ -247,6 +252,7 @@ class LLMModelAdapter(ABC):
|
|||||||
messages (List[ModelMessage]): The model messages
|
messages (List[ModelMessage]): The model messages
|
||||||
tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
|
tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
|
||||||
prompt_template (str, optional): The prompt template. Defaults to None.
|
prompt_template (str, optional): The prompt template. Defaults to None.
|
||||||
|
convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: The string prompt
|
Optional[str]: The string prompt
|
||||||
@@ -262,6 +268,7 @@ class LLMModelAdapter(ABC):
|
|||||||
model_context: Dict,
|
model_context: Dict,
|
||||||
prompt_template: str = None,
|
prompt_template: str = None,
|
||||||
):
|
):
|
||||||
|
convert_to_compatible_format = params.get("convert_to_compatible_format")
|
||||||
conv: ConversationAdapter = self.get_default_conv_template(
|
conv: ConversationAdapter = self.get_default_conv_template(
|
||||||
model_name, model_path
|
model_name, model_path
|
||||||
)
|
)
|
||||||
@@ -277,6 +284,72 @@ class LLMModelAdapter(ABC):
|
|||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
conv = conv.copy()
|
conv = conv.copy()
|
||||||
|
if convert_to_compatible_format:
|
||||||
|
# In the old version, we convert the messages to the compatible format
|
||||||
|
conv = self._set_conv_converted_messages(conv, messages)
|
||||||
|
else:
|
||||||
|
# In the new version, we use the messages directly
|
||||||
|
conv = self._set_conv_messages(conv, messages)
|
||||||
|
|
||||||
|
# Add a blank message for the assistant.
|
||||||
|
conv.append_message(conv.roles[1], None)
|
||||||
|
new_prompt = conv.get_prompt()
|
||||||
|
return new_prompt, conv.stop_str, conv.stop_token_ids
|
||||||
|
|
||||||
|
def _set_conv_messages(
|
||||||
|
self, conv: ConversationAdapter, messages: List[ModelMessage]
|
||||||
|
) -> ConversationAdapter:
|
||||||
|
"""Set the messages to the conversation template
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conv (ConversationAdapter): The conversation template
|
||||||
|
messages (List[ModelMessage]): The model messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConversationAdapter: The conversation template with messages
|
||||||
|
"""
|
||||||
|
system_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, ModelMessage):
|
||||||
|
role = message.role
|
||||||
|
content = message.content
|
||||||
|
elif isinstance(message, dict):
|
||||||
|
role = message["role"]
|
||||||
|
content = message["content"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid message type: {message}")
|
||||||
|
|
||||||
|
if role == ModelMessageRoleType.SYSTEM:
|
||||||
|
system_messages.append(content)
|
||||||
|
elif role == ModelMessageRoleType.HUMAN:
|
||||||
|
conv.append_message(conv.roles[0], content)
|
||||||
|
elif role == ModelMessageRoleType.AI:
|
||||||
|
conv.append_message(conv.roles[1], content)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown role: {role}")
|
||||||
|
if len(system_messages) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Your system messages have more than one message: {system_messages}"
|
||||||
|
)
|
||||||
|
if system_messages:
|
||||||
|
conv.set_system_message(system_messages[0])
|
||||||
|
return conv
|
||||||
|
|
||||||
|
def _set_conv_converted_messages(
|
||||||
|
self, conv: ConversationAdapter, messages: List[ModelMessage]
|
||||||
|
) -> ConversationAdapter:
|
||||||
|
"""Set the messages to the conversation template
|
||||||
|
|
||||||
|
In the old version, we will convert the messages to compatible format.
|
||||||
|
This method will be deprecated in the future.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conv (ConversationAdapter): The conversation template
|
||||||
|
messages (List[ModelMessage]): The model messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConversationAdapter: The conversation template with messages
|
||||||
|
"""
|
||||||
system_messages = []
|
system_messages = []
|
||||||
user_messages = []
|
user_messages = []
|
||||||
ai_messages = []
|
ai_messages = []
|
||||||
@@ -295,10 +368,8 @@ class LLMModelAdapter(ABC):
|
|||||||
# Support for multiple system messages
|
# Support for multiple system messages
|
||||||
system_messages.append(content)
|
system_messages.append(content)
|
||||||
elif role == ModelMessageRoleType.HUMAN:
|
elif role == ModelMessageRoleType.HUMAN:
|
||||||
# conv.append_message(conv.roles[0], content)
|
|
||||||
user_messages.append(content)
|
user_messages.append(content)
|
||||||
elif role == ModelMessageRoleType.AI:
|
elif role == ModelMessageRoleType.AI:
|
||||||
# conv.append_message(conv.roles[1], content)
|
|
||||||
ai_messages.append(content)
|
ai_messages.append(content)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown role: {role}")
|
raise ValueError(f"Unknown role: {role}")
|
||||||
@@ -320,10 +391,7 @@ class LLMModelAdapter(ABC):
|
|||||||
|
|
||||||
# TODO join all system messages may not be a good idea
|
# TODO join all system messages may not be a good idea
|
||||||
conv.set_system_message("".join(can_use_systems))
|
conv.set_system_message("".join(can_use_systems))
|
||||||
# Add a blank message for the assistant.
|
return conv
|
||||||
conv.append_message(conv.roles[1], None)
|
|
||||||
new_prompt = conv.get_prompt()
|
|
||||||
return new_prompt, conv.stop_str, conv.stop_token_ids
|
|
||||||
|
|
||||||
def model_adaptation(
|
def model_adaptation(
|
||||||
self,
|
self,
|
||||||
@@ -335,6 +403,15 @@ class LLMModelAdapter(ABC):
|
|||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
"""Params adaptation"""
|
"""Params adaptation"""
|
||||||
messages = params.get("messages")
|
messages = params.get("messages")
|
||||||
|
convert_to_compatible_format = params.get("convert_to_compatible_format")
|
||||||
|
message_version = params.get("version", "v2").lower()
|
||||||
|
logger.info(f"Message version is {message_version}")
|
||||||
|
if convert_to_compatible_format is None:
|
||||||
|
# Support convert messages to compatible format when message version is v1
|
||||||
|
convert_to_compatible_format = message_version == "v1"
|
||||||
|
# Save to params
|
||||||
|
params["convert_to_compatible_format"] = convert_to_compatible_format
|
||||||
|
|
||||||
# Some model context to dbgpt server
|
# Some model context to dbgpt server
|
||||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||||
if messages:
|
if messages:
|
||||||
@@ -345,7 +422,9 @@ class LLMModelAdapter(ABC):
|
|||||||
]
|
]
|
||||||
params["messages"] = messages
|
params["messages"] = messages
|
||||||
|
|
||||||
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
|
new_prompt = self.get_str_prompt(
|
||||||
|
params, messages, tokenizer, prompt_template, convert_to_compatible_format
|
||||||
|
)
|
||||||
conv_stop_str, conv_stop_token_ids = None, None
|
conv_stop_str, conv_stop_token_ids = None, None
|
||||||
if not new_prompt:
|
if not new_prompt:
|
||||||
(
|
(
|
||||||
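The message-version handling added to `model_adaptation` above boils down to a small rule; a standalone sketch (the helper name `_resolve_compat` is hypothetical, not from this commit):

```python
# Hypothetical condensation of the version handling in model_adaptation above.
def _resolve_compat(params: dict) -> bool:
    convert = params.get("convert_to_compatible_format")
    version = params.get("version", "v2").lower()
    if convert is None:
        # Only the legacy v1 message protocol triggers the compatible-format rewrite.
        convert = version == "v1"
    params["convert_to_compatible_format"] = convert
    return convert


assert _resolve_compat({"version": "v1"}) is True
assert _resolve_compat({"version": "v2"}) is False
assert _resolve_compat({"version": "v2", "convert_to_compatible_format": True}) is True
```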
|
@@ -87,6 +87,7 @@ class NewHFChatModelAdapter(LLMModelAdapter, ABC):
|
|||||||
messages: List[ModelMessage],
|
messages: List[ModelMessage],
|
||||||
tokenizer: Any,
|
tokenizer: Any,
|
||||||
prompt_template: str = None,
|
prompt_template: str = None,
|
||||||
|
convert_to_compatible_format: bool = False,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@@ -94,7 +95,7 @@ class NewHFChatModelAdapter(LLMModelAdapter, ABC):
|
|||||||
raise ValueError("tokenizer is is None")
|
raise ValueError("tokenizer is is None")
|
||||||
tokenizer: AutoTokenizer = tokenizer
|
tokenizer: AutoTokenizer = tokenizer
|
||||||
|
|
||||||
messages = self.transform_model_messages(messages)
|
messages = self.transform_model_messages(messages, convert_to_compatible_format)
|
||||||
logger.debug(f"The messages after transform: \n{messages}")
|
logger.debug(f"The messages after transform: \n{messages}")
|
||||||
str_prompt = tokenizer.apply_chat_template(
|
str_prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
@@ -22,6 +22,8 @@ class PromptRequest(BaseModel):
|
|||||||
span_id: str = None
|
span_id: str = None
|
||||||
metrics: bool = False
|
metrics: bool = False
|
||||||
"""Whether to return metrics of inference"""
|
"""Whether to return metrics of inference"""
|
||||||
|
version: str = "v2"
|
||||||
|
"""Message version, default to v2"""
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsRequest(BaseModel):
|
class EmbeddingsRequest(BaseModel):
|
||||||
|
@@ -1,20 +1,35 @@
|
|||||||
from typing import AsyncIterator, List
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dbgpt.core.interface.llm import LLMClient, ModelRequest, ModelOutput, ModelMetadata
|
from typing import AsyncIterator, List, Optional
|
||||||
from dbgpt.model.parameter import WorkerType
|
|
||||||
|
from dbgpt.core.interface.llm import (
|
||||||
|
LLMClient,
|
||||||
|
MessageConverter,
|
||||||
|
ModelMetadata,
|
||||||
|
ModelOutput,
|
||||||
|
ModelRequest,
|
||||||
|
)
|
||||||
from dbgpt.model.cluster.manager_base import WorkerManager
|
from dbgpt.model.cluster.manager_base import WorkerManager
|
||||||
|
from dbgpt.model.parameter import WorkerType
|
||||||
|
|
||||||
|
|
||||||
class DefaultLLMClient(LLMClient):
|
class DefaultLLMClient(LLMClient):
|
||||||
def __init__(self, worker_manager: WorkerManager):
|
def __init__(self, worker_manager: WorkerManager):
|
||||||
self._worker_manager = worker_manager
|
self._worker_manager = worker_manager
|
||||||
|
|
||||||
async def generate(self, request: ModelRequest) -> ModelOutput:
|
async def generate(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
message_converter: Optional[MessageConverter] = None,
|
||||||
|
) -> ModelOutput:
|
||||||
|
request = await self.covert_message(request, message_converter)
|
||||||
return await self._worker_manager.generate(request.to_dict())
|
return await self._worker_manager.generate(request.to_dict())
|
||||||
|
|
||||||
async def generate_stream(
|
async def generate_stream(
|
||||||
self, request: ModelRequest
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
message_converter: Optional[MessageConverter] = None,
|
||||||
) -> AsyncIterator[ModelOutput]:
|
) -> AsyncIterator[ModelOutput]:
|
||||||
|
request = await self.covert_message(request, message_converter)
|
||||||
async for output in self._worker_manager.generate_stream(request.to_dict()):
|
async for output in self._worker_manager.generate_stream(request.to_dict()):
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
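A rough usage sketch for the converter-aware client API above (assumes a deployed model serving cluster; the model name is illustrative):

```python
import asyncio

from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.model.cluster.client import DefaultLLMClient


async def ask(worker_manager) -> str:
    client = DefaultLLMClient(worker_manager)
    request = ModelRequest.build_request(
        model="vicuna-13b-v1.5",  # illustrative model name
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        context=None,  # assumption: no extra request context is needed here
        stream=False,
    )
    # message_converter is optional; pass a MessageConverter to override the default.
    output = await client.generate(request, message_converter=None)
    return output.text


# asyncio.run(ask(worker_manager))  # worker_manager comes from the running cluster
```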
|
@@ -8,7 +8,12 @@ import traceback
|
|||||||
from dbgpt.configs.model_config import get_device
|
from dbgpt.configs.model_config import get_device
|
||||||
from dbgpt.model.adapter.base import LLMModelAdapter
|
from dbgpt.model.adapter.base import LLMModelAdapter
|
||||||
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
|
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
|
||||||
from dbgpt.core import ModelOutput, ModelInferenceMetrics, ModelMetadata
|
from dbgpt.core import (
|
||||||
|
ModelOutput,
|
||||||
|
ModelInferenceMetrics,
|
||||||
|
ModelMetadata,
|
||||||
|
ModelExtraMedata,
|
||||||
|
)
|
||||||
from dbgpt.model.loader import ModelLoader, _get_model_real_path
|
from dbgpt.model.loader import ModelLoader, _get_model_real_path
|
||||||
from dbgpt.model.parameter import ModelParameters
|
from dbgpt.model.parameter import ModelParameters
|
||||||
from dbgpt.model.cluster.worker_base import ModelWorker
|
from dbgpt.model.cluster.worker_base import ModelWorker
|
||||||
@@ -196,9 +201,13 @@ class DefaultModelWorker(ModelWorker):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||||
|
ext_metadata = ModelExtraMedata(
|
||||||
|
prompt_sep=self.llm_adapter.get_default_message_separator()
|
||||||
|
)
|
||||||
return ModelMetadata(
|
return ModelMetadata(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
context_length=self.context_len,
|
context_length=self.context_len,
|
||||||
|
ext_metadata=ext_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
|
async def async_get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||||
|
@@ -122,7 +122,7 @@ class RemoteModelWorker(ModelWorker):
|
|||||||
json=params,
|
json=params,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
return ModelMetadata(**response.json())
|
return ModelMetadata.from_dict(response.json())
|
||||||
|
|
||||||
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
def get_model_metadata(self, params: Dict) -> ModelMetadata:
|
||||||
"""Get model metadata"""
|
"""Get model metadata"""
|
||||||
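For reference, the extra-metadata plumbing above can be exercised directly; a small sketch (the separator value is illustrative):

```python
from dbgpt.core import ModelExtraMedata, ModelMetadata

ext = ModelExtraMedata(prompt_sep="\n")  # illustrative separator
meta = ModelMetadata(model="vicuna-13b-v1.5", context_length=4096, ext_metadata=ext)

# RemoteModelWorker now rebuilds metadata with ModelMetadata.from_dict(response.json())
# instead of ModelMetadata(**response.json()); presumably this is more tolerant of
# nested or extra fields in the payload.
```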
|
@@ -0,0 +1,13 @@
|
|||||||
|
from dbgpt.model.operator.llm_operator import (
|
||||||
|
LLMOperator,
|
||||||
|
MixinLLMOperator,
|
||||||
|
StreamingLLMOperator,
|
||||||
|
)
|
||||||
|
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
|
||||||
|
|
||||||
|
__ALL__ = [
|
||||||
|
"MixinLLMOperator",
|
||||||
|
"LLMOperator",
|
||||||
|
"StreamingLLMOperator",
|
||||||
|
"OpenAIStreamingOutputOperator",
|
||||||
|
]
|
||||||
|
75
dbgpt/model/operator/llm_operator.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import logging
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dbgpt.component import ComponentType
|
||||||
|
from dbgpt.core import LLMClient
|
||||||
|
from dbgpt.core.awel import BaseOperator
|
||||||
|
from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
|
||||||
|
from dbgpt.model.cluster import WorkerManagerFactory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
|
||||||
|
"""Mixin class for LLM operator.
|
||||||
|
|
||||||
|
This class extends BaseOperator by adding LLM capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
|
||||||
|
super().__init__(default_client)
|
||||||
|
self._default_llm_client = default_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_client(self) -> LLMClient:
|
||||||
|
if not self._llm_client:
|
||||||
|
worker_manager_factory: WorkerManagerFactory = (
|
||||||
|
self.system_app.get_component(
|
||||||
|
ComponentType.WORKER_MANAGER_FACTORY,
|
||||||
|
WorkerManagerFactory,
|
||||||
|
default_component=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if worker_manager_factory:
|
||||||
|
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||||
|
|
||||||
|
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
|
||||||
|
else:
|
||||||
|
if self._default_llm_client is None:
|
||||||
|
from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
|
||||||
|
|
||||||
|
self._default_llm_client = OpenAILLMClient()
|
||||||
|
logger.info(
|
||||||
|
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
|
||||||
|
)
|
||||||
|
self._llm_client = self._default_llm_client
|
||||||
|
return self._llm_client
|
||||||
|
|
||||||
|
|
||||||
|
class LLMOperator(MixinLLMOperator, BaseLLMOperator):
|
||||||
|
"""Default LLM operator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_client (Optional[LLMClient], optional): The LLM client. Defaults to None.
|
||||||
|
If llm_client is None, we will try to connect to the model serving cluster deployed by DB-GPT,
|
||||||
|
and if we can't connect to the model serving cluster, we will use the :class:`OpenAILLMClient` as the llm_client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||||
|
super().__init__(llm_client)
|
||||||
|
BaseLLMOperator.__init__(self, llm_client, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator):
|
||||||
|
"""Default streaming LLM operator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_client (Optional[LLMClient], optional): The LLM client. Defaults to None.
|
||||||
|
If llm_client is None, we will try to connect to the model serving cluster deployed by DB-GPT,
|
||||||
|
and if we can't connect to the model serving cluster, we will use the :class:`OpenAILLMClient` as the llm_client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
|
||||||
|
super().__init__(llm_client)
|
||||||
|
BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs)
|
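A rough AWEL wiring sketch for the new operators (DAG name, model name, and the commented trigger call are illustrative; a model serving cluster or an OpenAI key is assumed):

```python
from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.core.awel import DAG
from dbgpt.model.operator import LLMOperator

with DAG("illustrative_llm_dag"):
    # With no explicit client, the mixin resolves one from the worker manager
    # factory, falling back to OpenAILLMClient when no cluster is registered.
    llm_task = LLMOperator()

request = ModelRequest.build_request(
    model="vicuna-13b-v1.5",  # illustrative
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    context=None,
    stream=False,
)
# Run inside a deployed AWEL runtime; the exact call style depends on the AWEL version:
# output = await llm_task.call(call_data=request)
```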
@@ -1,17 +1,18 @@
|
|||||||
from typing import AsyncIterator, Dict, List, Union
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import AsyncIterator, Dict, List, Union
|
||||||
|
|
||||||
|
from dbgpt.component import ComponentType
|
||||||
|
from dbgpt.core import ModelOutput
|
||||||
from dbgpt.core.awel import (
|
from dbgpt.core.awel import (
|
||||||
BranchFunc,
|
BranchFunc,
|
||||||
StreamifyAbsOperator,
|
|
||||||
BranchOperator,
|
BranchOperator,
|
||||||
MapOperator,
|
MapOperator,
|
||||||
|
StreamifyAbsOperator,
|
||||||
TransformStreamAbsOperator,
|
TransformStreamAbsOperator,
|
||||||
)
|
)
|
||||||
from dbgpt.component import ComponentType
|
|
||||||
from dbgpt.core.awel.operator.base import BaseOperator
|
from dbgpt.core.awel.operator.base import BaseOperator
|
||||||
from dbgpt.core import ModelOutput
|
|
||||||
from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory
|
from dbgpt.model.cluster import WorkerManager, WorkerManagerFactory
|
||||||
from dbgpt.storage.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue
|
from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@@ -13,6 +13,8 @@ def bard_generate_stream(
|
|||||||
proxy_api_key = model_params.proxy_api_key
|
proxy_api_key = model_params.proxy_api_key
|
||||||
proxy_server_url = model_params.proxy_server_url
|
proxy_server_url = model_params.proxy_server_url
|
||||||
|
|
||||||
|
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -25,14 +27,15 @@ def bard_generate_stream(
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
last_user_input_index = None
|
if convert_to_compatible_format:
|
||||||
for i in range(len(history) - 1, -1, -1):
|
last_user_input_index = None
|
||||||
if history[i]["role"] == "user":
|
for i in range(len(history) - 1, -1, -1):
|
||||||
last_user_input_index = i
|
if history[i]["role"] == "user":
|
||||||
break
|
last_user_input_index = i
|
||||||
if last_user_input_index:
|
break
|
||||||
last_user_input = history.pop(last_user_input_index)
|
if last_user_input_index:
|
||||||
history.append(last_user_input)
|
last_user_input = history.pop(last_user_input_index)
|
||||||
|
history.append(last_user_input)
|
||||||
|
|
||||||
msgs = []
|
msgs = []
|
||||||
for msg in history:
|
for msg in history:
|
||||||
|
@@ -128,7 +128,10 @@ def _build_request(model: ProxyModel, params):
|
|||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
|
|
||||||
# history = __convert_2_gpt_messages(messages)
|
# history = __convert_2_gpt_messages(messages)
|
||||||
history = ModelMessage.to_openai_messages(messages)
|
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||||
|
history = ModelMessage.to_openai_messages(
|
||||||
|
messages, convert_to_compatible_format=convert_to_compatible_format
|
||||||
|
)
|
||||||
payloads = {
|
payloads = {
|
||||||
"temperature": params.get("temperature"),
|
"temperature": params.get("temperature"),
|
||||||
"max_tokens": params.get("max_new_tokens"),
|
"max_tokens": params.get("max_new_tokens"),
|
||||||
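To make the flag above concrete, a small sketch of the two call paths (the exact reordering performed by the compatible-format path is an implementation detail of `ModelMessage`):

```python
from dbgpt.core import ModelMessage, ModelMessageRoleType

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Fix my SQL, please."),
]

# v2 path: keep the OpenAI-style roles and ordering as provided.
openai_msgs = ModelMessage.to_openai_messages(messages)
# v1 path: apply the legacy compatibility rewrite before sending.
legacy_msgs = ModelMessage.to_openai_messages(messages, convert_to_compatible_format=True)
```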
|
@@ -12,7 +12,6 @@ def gemini_generate_stream(
|
|||||||
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
|
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
|
||||||
model_params = model.get_params()
|
model_params = model.get_params()
|
||||||
print(f"Model: {model}, model_params: {model_params}")
|
print(f"Model: {model}, model_params: {model_params}")
|
||||||
global history
|
|
||||||
|
|
||||||
# TODO proxy model use unified config?
|
# TODO proxy model use unified config?
|
||||||
proxy_api_key = model_params.proxy_api_key
|
proxy_api_key = model_params.proxy_api_key
|
||||||
|
@@ -56,6 +56,9 @@ def spark_generate_stream(
|
|||||||
del messages[index]
|
del messages[index]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# TODO: Support convert_to_compatible_format config
|
||||||
|
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
# Add history conversation
|
# Add history conversation
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
@@ -53,8 +53,12 @@ def tongyi_generate_stream(
|
|||||||
proxyllm_backend = Generation.Models.qwen_turbo # By Default qwen_turbo
|
proxyllm_backend = Generation.Models.qwen_turbo # By Default qwen_turbo
|
||||||
|
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
|
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||||
|
|
||||||
history = __convert_2_tongyi_messages(messages)
|
if convert_to_compatible_format:
|
||||||
|
history = __convert_2_tongyi_messages(messages)
|
||||||
|
else:
|
||||||
|
history = ModelMessage.to_openai_messages(messages)
|
||||||
gen = Generation()
|
gen = Generation()
|
||||||
res = gen.call(
|
res = gen.call(
|
||||||
proxyllm_backend,
|
proxyllm_backend,
|
||||||
|
@@ -25,8 +25,29 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
|
|||||||
return res.json().get("access_token")
|
return res.json().get("access_token")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_wenxin_messages(messages: List[ModelMessage]):
|
||||||
|
"""Convert messages to wenxin compatible format
|
||||||
|
|
||||||
|
See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
|
||||||
|
"""
|
||||||
|
wenxin_messages = []
|
||||||
|
system_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message.role == ModelMessageRoleType.HUMAN:
|
||||||
|
wenxin_messages.append({"role": "user", "content": message.content})
|
||||||
|
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||||
|
system_messages.append(message.content)
|
||||||
|
elif message.role == ModelMessageRoleType.AI:
|
||||||
|
wenxin_messages.append({"role": "assistant", "content": message.content})
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if len(system_messages) > 1:
|
||||||
|
raise ValueError("Wenxin only support one system message")
|
||||||
|
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
|
||||||
|
return wenxin_messages, str_system_message
|
||||||
|
|
||||||
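A quick sketch of what the new `_to_wenxin_messages` mapping produces (the module path in the import is assumed; values are illustrative):

```python
from dbgpt.core import ModelMessage, ModelMessageRoleType
from dbgpt.model.proxy.llms.wenxin import _to_wenxin_messages  # path assumed

msgs = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Answer in SQL."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Count the users table."),
]
history, system = _to_wenxin_messages(msgs)
# history -> [{"role": "user", "content": "Count the users table."}]
# system  -> "Answer in SQL."
```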
|
|
||||||
def __convert_2_wenxin_messages(messages: List[ModelMessage]):
|
def __convert_2_wenxin_messages(messages: List[ModelMessage]):
|
||||||
chat_round = 0
|
|
||||||
wenxin_messages = []
|
wenxin_messages = []
|
||||||
|
|
||||||
last_usr_message = ""
|
last_usr_message = ""
|
||||||
@@ -57,7 +78,8 @@ def __convert_2_wenxin_messages(messages: List[ModelMessage]):
|
|||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
end_message = last_message.content
|
end_message = last_message.content
|
||||||
wenxin_messages.append({"role": "user", "content": end_message})
|
wenxin_messages.append({"role": "user", "content": end_message})
|
||||||
return wenxin_messages, system_messages
|
str_system_message = system_messages[0] if len(system_messages) > 0 else ""
|
||||||
|
return wenxin_messages, str_system_message
|
||||||
|
|
||||||
|
|
||||||
def wenxin_generate_stream(
|
def wenxin_generate_stream(
|
||||||
@@ -87,13 +109,14 @@ def wenxin_generate_stream(
|
|||||||
|
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
|
|
||||||
history, systems = __convert_2_wenxin_messages(messages)
|
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||||
system = ""
|
if convert_to_compatible_format:
|
||||||
if systems and len(systems) > 0:
|
history, system_message = __convert_2_wenxin_messages(messages)
|
||||||
system = systems[0]
|
else:
|
||||||
|
history, system_message = _to_wenxin_messages(messages)
|
||||||
payload = {
|
payload = {
|
||||||
"messages": history,
|
"messages": history,
|
||||||
"system": system,
|
"system": system_message,
|
||||||
"temperature": params.get("temperature"),
|
"temperature": params.get("temperature"),
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
@@ -57,6 +57,10 @@ def zhipu_generate_stream(
|
|||||||
zhipuai.api_key = proxy_api_key
|
zhipuai.api_key = proxy_api_key
|
||||||
|
|
||||||
messages: List[ModelMessage] = params["messages"]
|
messages: List[ModelMessage] = params["messages"]
|
||||||
|
|
||||||
|
# TODO: Support convert_to_compatible_format config, zhipu not support system message
|
||||||
|
convert_to_compatible_format = params.get("convert_to_compatible_format", False)
|
||||||
|
|
||||||
history, systems = __convert_2_zhipu_messages(messages)
|
history, systems = __convert_2_zhipu_messages(messages)
|
||||||
res = zhipuai.model_api.sse_invoke(
|
res = zhipuai.model_api.sse_invoke(
|
||||||
model=proxyllm_backend,
|
model=proxyllm_backend,
|
||||||
|
@@ -20,8 +20,13 @@ from typing import (
|
|||||||
from dbgpt.component import ComponentType
|
from dbgpt.component import ComponentType
|
||||||
from dbgpt.core.operator import BaseLLM
|
from dbgpt.core.operator import BaseLLM
|
||||||
from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
|
from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator
|
||||||
from dbgpt.core.interface.llm import ModelMetadata, LLMClient
|
from dbgpt.core.interface.llm import (
|
||||||
from dbgpt.core.interface.llm import ModelOutput, ModelRequest
|
ModelOutput,
|
||||||
|
ModelRequest,
|
||||||
|
ModelMetadata,
|
||||||
|
LLMClient,
|
||||||
|
MessageConverter,
|
||||||
|
)
|
||||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||||
from dbgpt.model.cluster import WorkerManagerFactory
|
from dbgpt.model.cluster import WorkerManagerFactory
|
||||||
from dbgpt._private.pydantic import model_to_json
|
from dbgpt._private.pydantic import model_to_json
|
||||||
@@ -175,7 +180,13 @@ class OpenAILLMClient(LLMClient):
|
|||||||
payload["max_tokens"] = request.max_new_tokens
|
payload["max_tokens"] = request.max_new_tokens
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def generate(self, request: ModelRequest) -> ModelOutput:
|
async def generate(
|
||||||
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
message_converter: Optional[MessageConverter] = None,
|
||||||
|
) -> ModelOutput:
|
||||||
|
request = await self.covert_message(request, message_converter)
|
||||||
|
|
||||||
messages = request.to_openai_messages()
|
messages = request.to_openai_messages()
|
||||||
payload = self._build_request(request)
|
payload = self._build_request(request)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -195,8 +206,11 @@ class OpenAILLMClient(LLMClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def generate_stream(
|
async def generate_stream(
|
||||||
self, request: ModelRequest
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
message_converter: Optional[MessageConverter] = None,
|
||||||
) -> AsyncIterator[ModelOutput]:
|
) -> AsyncIterator[ModelOutput]:
|
||||||
|
request = await self.covert_message(request, message_converter)
|
||||||
messages = request.to_openai_messages()
|
messages = request.to_openai_messages()
|
||||||
payload = self._build_request(request, True)
|
payload = self._build_request(request, True)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -247,7 +261,7 @@ class OpenAILLMClient(LLMClient):
|
|||||||
return self._tokenizer.count_token(prompt, model)
|
return self._tokenizer.count_token(prompt, model)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
||||||
"""Transform ModelOutput to openai stream format."""
|
"""Transform ModelOutput to openai stream format."""
|
||||||
|
|
||||||
async def transform_stream(
|
async def transform_stream(
|
||||||
@@ -266,40 +280,6 @@ class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
|||||||
yield output
|
yield output
|
||||||
|
|
||||||
|
|
||||||
class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
|
|
||||||
"""Mixin class for LLM operator.
|
|
||||||
|
|
||||||
This class extends BaseOperator by adding LLM capabilities.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
|
|
||||||
super().__init__(default_client)
|
|
||||||
self._default_llm_client = default_client
|
|
||||||
|
|
||||||
@property
|
|
||||||
def llm_client(self) -> LLMClient:
|
|
||||||
if not self._llm_client:
|
|
||||||
worker_manager_factory: WorkerManagerFactory = (
|
|
||||||
self.system_app.get_component(
|
|
||||||
ComponentType.WORKER_MANAGER_FACTORY,
|
|
||||||
WorkerManagerFactory,
|
|
||||||
default_component=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if worker_manager_factory:
|
|
||||||
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
|
|
||||||
else:
|
|
||||||
if self._default_llm_client is None:
|
|
||||||
from dbgpt.model import OpenAILLMClient
|
|
||||||
|
|
||||||
self._default_llm_client = OpenAILLMClient()
|
|
||||||
logger.info(
|
|
||||||
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
|
|
||||||
)
|
|
||||||
self._llm_client = self._default_llm_client
|
|
||||||
return self._llm_client
|
|
||||||
|
|
||||||
|
|
||||||
async def _to_openai_stream(
|
async def _to_openai_stream(
|
||||||
output_iter: AsyncIterator[ModelOutput],
|
output_iter: AsyncIterator[ModelOutput],
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
|
71
dbgpt/serve/conversation/operator.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from dbgpt.core import (
|
||||||
|
InMemoryStorage,
|
||||||
|
MessageStorageItem,
|
||||||
|
StorageConversation,
|
||||||
|
StorageInterface,
|
||||||
|
)
|
||||||
|
from dbgpt.core.operator import PreChatHistoryLoadOperator
|
||||||
|
|
||||||
|
from .serve import Serve
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ServePreChatHistoryLoadOperator(PreChatHistoryLoadOperator):
|
||||||
|
"""Pre-chat history load operator for DB-GPT serve component
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage (Optional[StorageInterface[StorageConversation, Any]], optional):
|
||||||
|
The conversation storage, store the conversation items. Defaults to None.
|
||||||
|
message_storage (Optional[StorageInterface[MessageStorageItem, Any]], optional):
|
||||||
|
The message storage, store the messages of one conversation. Defaults to None.
|
||||||
|
|
||||||
|
If the storage or message_storage is not None, the storage or message_storage will be used first.
|
||||||
|
Otherwise, we will try get current serve component from system app,
|
||||||
|
and use the storage or message_storage of the serve component.
|
||||||
|
If we can't get the storage, we will use the InMemoryStorage as the storage or message_storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||||
|
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(storage, message_storage, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def storage(self):
|
||||||
|
if self._storage:
|
||||||
|
return self._storage
|
||||||
|
storage = Serve.call_on_current_serve(
|
||||||
|
self.system_app, lambda serve: serve.conv_storage
|
||||||
|
)
|
||||||
|
if not storage:
|
||||||
|
logger.warning(
|
||||||
|
"Can't get the conversation storage from current serve component, "
|
||||||
|
"use the InMemoryStorage as the conversation storage."
|
||||||
|
)
|
||||||
|
self._storage = InMemoryStorage()
|
||||||
|
return self._storage
|
||||||
|
return storage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message_storage(self):
|
||||||
|
if self._message_storage:
|
||||||
|
return self._message_storage
|
||||||
|
storage = Serve.call_on_current_serve(
|
||||||
|
self.system_app,
|
||||||
|
lambda serve: serve.message_storage,
|
||||||
|
)
|
||||||
|
if not storage:
|
||||||
|
logger.warning(
|
||||||
|
"Can't get the message storage from current serve component, "
|
||||||
|
"use the InMemoryStorage as the message storage."
|
||||||
|
)
|
||||||
|
self._message_storage = InMemoryStorage()
|
||||||
|
return self._message_storage
|
||||||
|
return storage
|
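A short sketch of the storage fallback described in the docstring above (the DAG name is illustrative):

```python
from dbgpt.core.awel import DAG
from dbgpt.serve.conversation.operator import ServePreChatHistoryLoadOperator

with DAG("illustrative_history_dag"):
    # No explicit storages: each property first asks the conversation Serve
    # component on the system app, then falls back to InMemoryStorage.
    chat_history_load_task = ServePreChatHistoryLoadOperator()
```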
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
from sqlalchemy import URL
|
from sqlalchemy import URL
|
||||||
|
|
||||||
@@ -60,3 +60,44 @@ class BaseServe(BaseComponent, ABC):
|
|||||||
finally:
|
finally:
|
||||||
self._not_create_table = False
|
self._not_create_table = False
|
||||||
return init_db
|
return init_db
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_current_serve(cls, system_app: SystemApp) -> Optional["BaseServe"]:
|
||||||
|
"""Get the current serve component.
|
||||||
|
|
||||||
|
None if the serve component is not exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_app (SystemApp): The system app
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[BaseServe]: The current serve component.
|
||||||
|
"""
|
||||||
|
return system_app.get_component(cls.name, cls, default_component=None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def call_on_current_serve(
|
||||||
|
cls,
|
||||||
|
system_app: SystemApp,
|
||||||
|
func: Callable[["BaseServe"], Optional[Any]],
|
||||||
|
default_value: Optional[Any] = None,
|
||||||
|
) -> Optional[Any]:
|
||||||
|
"""Call the function on the current serve component.
|
||||||
|
|
||||||
|
Return default_value if the serve component is not exist or the function return None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_app (SystemApp): The system app
|
||||||
|
func (Callable[[BaseServe], Any]): The function to call
|
||||||
|
default_value (Optional[Any], optional): The default value. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Any]: The result of the function
|
||||||
|
"""
|
||||||
|
serve = cls.get_current_serve(system_app)
|
||||||
|
if not serve:
|
||||||
|
return default_value
|
||||||
|
result = func(serve)
|
||||||
|
if not result:
|
||||||
|
result = default_value
|
||||||
|
return result
|
||||||
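Usage sketch for the new class methods (the `BaseServe` import path and the `_DemoServe` stand-in are assumptions, not part of the commit):

```python
from dbgpt.component import SystemApp
from dbgpt.serve.core import BaseServe  # import path assumed

system_app = SystemApp()


class _DemoServe(BaseServe):
    """Minimal stand-in; a real serve app implements the full component interface."""

    name = "dbgpt_demo_serve"


# Nothing is registered on the system app, so the helper returns the default value.
value = _DemoServe.call_on_current_serve(
    system_app, lambda serve: serve.config, default_value="fallback"
)
assert value == "fallback"
```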
|
87
dbgpt/util/function_utils.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
from typing import Any, get_type_hints, get_origin, get_args
|
||||||
|
from functools import wraps
|
||||||
|
import inspect
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def _is_instance_of_generic_type(obj, generic_type):
|
||||||
|
"""Check if an object is an instance of a generic type."""
|
||||||
|
if generic_type is Any:
|
||||||
|
return True # Any type is compatible with any object
|
||||||
|
|
||||||
|
origin = get_origin(generic_type)
|
||||||
|
if origin is None:
|
||||||
|
return isinstance(obj, generic_type) # Handle non-generic types
|
||||||
|
|
||||||
|
args = get_args(generic_type)
|
||||||
|
if not args:
|
||||||
|
return isinstance(obj, origin)
|
||||||
|
|
||||||
|
# Check if object matches the generic origin (like list, dict)
|
||||||
|
if not isinstance(obj, origin):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# For each item in the object, check if it matches the corresponding type argument
|
||||||
|
for sub_obj, arg in zip(obj, args):
|
||||||
|
# Skip check if the type argument is Any
|
||||||
|
if arg is not Any and not isinstance(sub_obj, arg):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_args(func, args, kwargs):
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
type_hints = get_type_hints(func)
|
||||||
|
|
||||||
|
arg_types = [
|
||||||
|
type_hints[param_name]
|
||||||
|
for param_name in sig.parameters
|
||||||
|
if param_name != "return" and param_name != "self"
|
||||||
|
]
|
||||||
|
|
||||||
|
if "self" in sig.parameters:
|
||||||
|
self_arg = [args[0]]
|
||||||
|
other_args = args[1:]
|
||||||
|
else:
|
||||||
|
self_arg = []
|
||||||
|
other_args = args
|
||||||
|
|
||||||
|
sorted_args = sorted(
|
||||||
|
other_args,
|
||||||
|
key=lambda x: next(
|
||||||
|
i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return (*self_arg, *sorted_args), kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def rearrange_args_by_type(func):
|
||||||
|
"""Decorator to rearrange the arguments of a function by type.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
def sync_regular_function(a: int, b: str, c: float):
|
||||||
|
return a, b, c
|
||||||
|
|
||||||
|
assert sync_regular_function(1, "b", 3.0) == (1, "b", 3.0)
|
||||||
|
assert sync_regular_function("b", 3.0, 1) == (1, "b", 3.0)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
sorted_args, sorted_kwargs = _sort_args(func, args, kwargs)
|
||||||
|
return func(*sorted_args, **sorted_kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args, **kwargs):
|
||||||
|
sorted_args, sorted_kwargs = _sort_args(func, args, kwargs)
|
||||||
|
return await func(*sorted_args, **sorted_kwargs)
|
||||||
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
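A self-contained check of the decorator's behaviour (it mirrors the unit tests added below; the function name and values are illustrative):

```python
from typing import Dict, List

from dbgpt.util.function_utils import rearrange_args_by_type


@rearrange_args_by_type
def report(title: str, rows: List[dict], options: Dict[str, int]):
    return title, rows, options


# Positional arguments may arrive in any order; they are re-sorted by type hints.
assert report([{"a": 1}], {"limit": 10}, "daily") == ("daily", [{"a": 1}], {"limit": 10})
```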
@@ -10,11 +10,12 @@ needed), or truncating them so that they fit in a single LLM call.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Callable, List, Optional, Sequence
|
from typing import Callable, List, Optional, Sequence, Set
|
||||||
|
|
||||||
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
|
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
|
||||||
|
|
||||||
from dbgpt.util.global_helper import globals_helper
|
from dbgpt.util.global_helper import globals_helper
|
||||||
|
from dbgpt.core.interface.prompt import get_template_vars
|
||||||
from dbgpt._private.llm_metadata import LLMMetadata
|
from dbgpt._private.llm_metadata import LLMMetadata
|
||||||
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
|
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
|
||||||
|
|
||||||
@@ -230,15 +231,3 @@ def get_empty_prompt_txt(template: str) -> str:
|
|||||||
all_kwargs = {**partial_kargs, **empty_kwargs}
|
all_kwargs = {**partial_kargs, **empty_kwargs}
|
||||||
prompt = template.format(**all_kwargs)
|
prompt = template.format(**all_kwargs)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def get_template_vars(template_str: str) -> List[str]:
|
|
||||||
"""Get template variables from a template string."""
|
|
||||||
variables = []
|
|
||||||
formatter = Formatter()
|
|
||||||
|
|
||||||
for _, variable_name, _, _ in formatter.parse(template_str):
|
|
||||||
if variable_name:
|
|
||||||
variables.append(variable_name)
|
|
||||||
|
|
||||||
return variables
|
|
||||||
|
120
dbgpt/util/tests/test_function_utils.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptTemplate:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessage:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMessage:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClass:
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def class_method(self, a: int, b: str, c: float):
|
||||||
|
return a, b, c
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def merge_history(
|
||||||
|
self,
|
||||||
|
prompt: ChatPromptTemplate,
|
||||||
|
history: List[BaseMessage],
|
||||||
|
prompt_dict: Dict[str, Any],
|
||||||
|
) -> List[ModelMessage]:
|
||||||
|
return [type(prompt), type(history), type(prompt_dict)]
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
def sync_class_method(self, a: int, b: str, c: float):
|
||||||
|
return a, b, c
|
||||||
|
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
def sync_regular_function(a: int, b: str, c: float):
|
||||||
|
return a, b, c
|
||||||
|
|
||||||
|
|
||||||
|
@rearrange_args_by_type
|
||||||
|
async def regular_function(a: int, b: str, c: float):
|
||||||
|
return a, b, c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_class_method_correct_order():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = await instance.class_method(1, "b", 3.0)
|
||||||
|
assert result == (1, "b", 3.0), "Class method failed with correct order"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_class_method_incorrect_order():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = await instance.class_method("b", 3.0, 1)
|
||||||
|
assert result == (1, "b", 3.0), "Class method failed with incorrect order"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_regular_function_correct_order():
|
||||||
|
result = await regular_function(1, "b", 3.0)
|
||||||
|
assert result == (1, "b", 3.0), "Regular function failed with correct order"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_regular_function_incorrect_order():
|
||||||
|
result = await regular_function("b", 3.0, 1)
|
||||||
|
assert result == (1, "b", 3.0), "Regular function failed with incorrect order"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_merge_history_correct_order():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = await instance.merge_history(
|
||||||
|
ChatPromptTemplate(), [BaseMessage()], {"key": "value"}
|
||||||
|
)
|
||||||
|
assert result == [ChatPromptTemplate, list, dict], "Failed with correct order"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_merge_history_incorrect_order_1():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = await instance.merge_history(
|
||||||
|
[BaseMessage()], ChatPromptTemplate(), {"key": "value"}
|
||||||
|
)
|
||||||
|
assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_merge_history_incorrect_order_2():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = await instance.merge_history(
|
||||||
|
{"key": "value"}, [BaseMessage()], ChatPromptTemplate()
|
||||||
|
)
|
||||||
|
assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_class_method_correct_order():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = instance.sync_class_method(1, "b", 3.0)
|
||||||
|
assert result == (1, "b", 3.0), "Sync class method failed with correct order"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_class_method_incorrect_order():
|
||||||
|
instance = DummyClass()
|
||||||
|
result = instance.sync_class_method("b", 3.0, 1)
|
||||||
|
assert result == (1, "b", 3.0), "Sync class method failed with incorrect order"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_regular_function_correct_order():
|
||||||
|
result = sync_regular_function(1, "b", 3.0)
|
||||||
|
assert result == (1, "b", 3.0), "Sync regular function failed with correct order"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_regular_function_incorrect_order():
|
||||||
|
result = sync_regular_function("b", 3.0, 1)
|
||||||
|
assert result == (1, "b", 3.0), "Sync regular function failed with incorrect order"
|
@@ -26,7 +26,7 @@
|
|||||||
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
|
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
|
||||||
-H "Content-Type: application/json" -d '{
|
-H "Content-Type: application/json" -d '{
|
||||||
"command": "dbgpt_awel_data_analyst_code_fix",
|
"command": "dbgpt_awel_data_analyst_code_fix",
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "'"$MODEL"'",
|
||||||
"stream": false,
|
"stream": false,
|
||||||
"context": {
|
"context": {
|
||||||
"conv_uid": "uuid_conv_copilot_1234",
|
"conv_uid": "uuid_conv_copilot_1234",
|
||||||
@@ -37,43 +37,55 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel, Field
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
from dbgpt.core import (
|
from dbgpt.core import (
|
||||||
InMemoryStorage,
|
ChatPromptTemplate,
|
||||||
LLMClient,
|
HumanPromptTemplate,
|
||||||
MessageStorageItem,
|
MessagesPlaceholder,
|
||||||
ModelMessage,
|
ModelMessage,
|
||||||
ModelMessageRoleType,
|
ModelRequest,
|
||||||
|
ModelRequestContext,
|
||||||
PromptManager,
|
PromptManager,
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
StorageConversation,
|
SystemPromptTemplate,
|
||||||
StorageInterface,
|
|
||||||
)
|
)
|
||||||
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||||
from dbgpt.core.operator import (
|
from dbgpt.core.operator import (
|
||||||
BufferedConversationMapperOperator,
|
BufferedConversationMapperOperator,
|
||||||
|
HistoryDynamicPromptBuilderOperator,
|
||||||
LLMBranchOperator,
|
LLMBranchOperator,
|
||||||
|
RequestBuilderOperator,
|
||||||
|
)
|
||||||
|
from dbgpt.model.operator import (
|
||||||
LLMOperator,
|
LLMOperator,
|
||||||
PostConversationOperator,
|
OpenAIStreamingOutputOperator,
|
||||||
PostStreamingConversationOperator,
|
|
||||||
PreConversationOperator,
|
|
||||||
RequestBuildOperator,
|
|
||||||
StreamingLLMOperator,
|
StreamingLLMOperator,
|
||||||
)
|
)
|
||||||
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
|
from dbgpt.serve.conversation.operator import ServePreChatHistoryLoadOperator
|
||||||
from dbgpt.util.utils import colored
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PROMPT_LANG_ZH = "zh"
|
||||||
|
PROMPT_LANG_EN = "en"
|
||||||
|
|
||||||
|
CODE_DEFAULT = "dbgpt_awel_data_analyst_code_default"
|
||||||
CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
|
CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
|
||||||
CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
|
CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
|
||||||
CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
|
CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
|
||||||
CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
|
CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
|
||||||
CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"
|
CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"
|
||||||
|
|
||||||
|
CODE_DEFAULT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师。
|
||||||
|
你可以根据最佳实践来优化代码, 也可以对代码进行修复, 解释, 添加注释, 以及将代码翻译成其他语言。"""
|
||||||
|
CODE_DEFAULT_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst.
|
||||||
|
You can optimize the code according to best practices, or fix, explain, add comments to the code,
|
||||||
|
and you can also translate the code into other languages.
|
||||||
|
"""
|
||||||
|
|
||||||
CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
|
CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
|
||||||
这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
|
这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
|
||||||
CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
|
CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
|
||||||
@@ -126,7 +138,9 @@ class ReqContext(BaseModel):
|
|||||||
|
|
||||||
class TriggerReqBody(BaseModel):
|
class TriggerReqBody(BaseModel):
|
||||||
messages: str = Field(..., description="User input messages")
|
messages: str = Field(..., description="User input messages")
|
||||||
command: Optional[str] = Field(default="fix", description="Command name")
|
command: Optional[str] = Field(
|
||||||
|
default=None, description="Command name, None if common chat"
|
||||||
|
)
|
||||||
model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
|
model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
|
||||||
stream: Optional[bool] = Field(default=False, description="Whether return stream")
|
stream: Optional[bool] = Field(default=False, description="Whether return stream")
|
||||||
language: Optional[str] = Field(default="hive", description="Language")
|
language: Optional[str] = Field(default="hive", description="Language")
|
||||||
@@ -140,109 +154,89 @@ class TriggerReqBody(BaseModel):
|
|||||||
|
|
||||||
@cache
|
@cache
|
||||||
def load_or_save_prompt_template(pm: PromptManager):
|
def load_or_save_prompt_template(pm: PromptManager):
|
||||||
ext_params = {
|
zh_ext_params = {
|
||||||
"chat_scene": "chat_with_code",
|
"chat_scene": "chat_with_code",
|
||||||
"sub_chat_scene": "data_analyst",
|
"sub_chat_scene": "data_analyst",
|
||||||
"prompt_type": "common",
|
"prompt_type": "common",
|
||||||
|
"prompt_language": PROMPT_LANG_ZH,
|
||||||
}
|
}
|
||||||
|
en_ext_params = {
|
||||||
|
"chat_scene": "chat_with_code",
|
||||||
|
"sub_chat_scene": "data_analyst",
|
||||||
|
"prompt_type": "common",
|
||||||
|
"prompt_language": PROMPT_LANG_EN,
|
||||||
|
}
|
||||||
|
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_ZH),
|
||||||
input_variables=["language"],
|
prompt_name=CODE_DEFAULT,
|
||||||
template=CODE_FIX_TEMPLATE_ZH,
|
**zh_ext_params,
|
||||||
),
|
)
|
||||||
|
pm.query_or_save(
|
||||||
|
PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_EN),
|
||||||
|
prompt_name=CODE_DEFAULT,
|
||||||
|
**en_ext_params,
|
||||||
|
)
|
||||||
|
pm.query_or_save(
|
||||||
|
PromptTemplate.from_template(CODE_FIX_TEMPLATE_ZH),
|
||||||
prompt_name=CODE_FIX,
|
prompt_name=CODE_FIX,
|
||||||
prompt_language="zh",
|
**zh_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_FIX_TEMPLATE_EN),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_FIX_TEMPLATE_EN,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_FIX,
|
prompt_name=CODE_FIX,
|
||||||
prompt_language="en",
|
**en_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_PERF_TEMPLATE_ZH),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_PERF_TEMPLATE_ZH,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_PERF,
|
prompt_name=CODE_PERF,
|
||||||
prompt_language="zh",
|
**zh_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_PERF_TEMPLATE_EN),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_PERF_TEMPLATE_EN,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_PERF,
|
prompt_name=CODE_PERF,
|
||||||
prompt_language="en",
|
**en_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_ZH),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_EXPLAIN_TEMPLATE_ZH,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_EXPLAIN,
|
prompt_name=CODE_EXPLAIN,
|
||||||
prompt_language="zh",
|
**zh_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_EN),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_EXPLAIN_TEMPLATE_EN,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_EXPLAIN,
|
prompt_name=CODE_EXPLAIN,
|
||||||
prompt_language="en",
|
**en_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_ZH),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_COMMENT_TEMPLATE_ZH,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_COMMENT,
|
prompt_name=CODE_COMMENT,
|
||||||
prompt_language="zh",
|
**zh_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_EN),
|
||||||
input_variables=["language"],
|
|
||||||
template=CODE_COMMENT_TEMPLATE_EN,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_COMMENT,
|
prompt_name=CODE_COMMENT,
|
||||||
prompt_language="en",
|
**en_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_ZH),
|
||||||
input_variables=["source_language", "target_language"],
|
|
||||||
template=CODE_TRANSLATE_TEMPLATE_ZH,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_TRANSLATE,
|
prompt_name=CODE_TRANSLATE,
|
||||||
prompt_language="zh",
|
**zh_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
pm.query_or_save(
|
pm.query_or_save(
|
||||||
PromptTemplate(
|
PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_EN),
|
||||||
input_variables=["source_language", "target_language"],
|
|
||||||
template=CODE_TRANSLATE_TEMPLATE_EN,
|
|
||||||
),
|
|
||||||
prompt_name=CODE_TRANSLATE,
|
prompt_name=CODE_TRANSLATE,
|
||||||
prompt_language="en",
|
**en_ext_params,
|
||||||
**ext_params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
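A condensed sketch of the registration pattern used in `load_or_save_prompt_template` above (prompt text and name are illustrative):

```python
from dbgpt.core import PromptManager, PromptTemplate

pm = PromptManager()  # the DAG below prefers the prompt Serve component's manager
pm.query_or_save(
    PromptTemplate.from_template("You are a {language} reviewer."),
    prompt_name="demo_prompt",
    chat_scene="chat_with_code",
    sub_chat_scene="data_analyst",
    prompt_type="common",
    prompt_language="en",
)
# Later, pick the variant that best matches the user's language.
stored = pm.prefer_query("demo_prompt", prefer_prompt_language="en")
```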
|
|
||||||
class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
|
class PromptTemplateBuilderOperator(MapOperator[TriggerReqBody, ChatPromptTemplate]):
|
||||||
|
"""Build prompt template for chat with code."""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._default_prompt_manager = PromptManager()
|
self._default_prompt_manager = PromptManager()
|
||||||
|
|
||||||
async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
|
async def map(self, input_value: TriggerReqBody) -> ChatPromptTemplate:
|
||||||
from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
|
from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
|
||||||
from dbgpt.serve.prompt.serve import Serve as PromptServe
|
from dbgpt.serve.prompt.serve import Serve as PromptServe
|
||||||
|
|
||||||
@@ -256,7 +250,24 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
|
|||||||
load_or_save_prompt_template(pm)
|
load_or_save_prompt_template(pm)
|
||||||
|
|
||||||
user_language = self.system_app.config.get_current_lang(default="en")
|
user_language = self.system_app.config.get_current_lang(default="en")
|
||||||
|
if not input_value.command:
|
||||||
|
# No command, just chat, not include system prompt.
|
||||||
|
default_prompt_list = pm.prefer_query(
|
||||||
|
CODE_DEFAULT, prefer_prompt_language=user_language
|
||||||
|
)
|
||||||
|
default_prompt_template = (
|
||||||
|
default_prompt_list[0].to_prompt_template().template
|
||||||
|
)
|
||||||
|
prompt = ChatPromptTemplate(
|
||||||
|
messages=[
|
||||||
|
SystemPromptTemplate.from_template(default_prompt_template),
|
||||||
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
|
HumanPromptTemplate.from_template("{user_input}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
# Query prompt template from prompt manager by command name
|
||||||
prompt_list = pm.prefer_query(
|
        prompt_list = pm.prefer_query(
            input_value.command, prefer_prompt_language=user_language
        )
@@ -264,109 +275,38 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
             error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}"
             logger.error(error_msg)
             raise ValueError(error_msg)
-        prompt = prompt_list[0].to_prompt_template()
-        if input_value.command == CODE_TRANSLATE:
-            format_params = {
-                "source_language": input_value.language,
-                "target_language": input_value.target_language,
-            }
-        else:
-            format_params = {"language": input_value.language}
-
-        system_message = prompt.format(**format_params)
-        messages = [
-            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message),
-            ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages),
-        ]
-        context = input_value.context.dict() if input_value.context else {}
-        return {
-            "messages": messages,
-            "stream": input_value.stream,
-            "model": input_value.model,
-            "context": context,
-        }
+        prompt_template = prompt_list[0].to_prompt_template()
+        return ChatPromptTemplate(
+            messages=[
+                SystemPromptTemplate.from_template(prompt_template.template),
+                MessagesPlaceholder(variable_name="chat_history"),
+                HumanPromptTemplate.from_template("{user_input}"),
+            ]
+        )


-class MyConversationOperator(PreConversationOperator):
-    def __init__(
-        self,
-        storage: Optional[StorageInterface[StorageConversation, Any]] = None,
-        message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
-        **kwargs,
-    ):
-        super().__init__(storage, message_storage, **kwargs)
-
-    def _get_conversion_serve(self):
-        from dbgpt.serve.conversation.serve import (
-            SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME,
-        )
-        from dbgpt.serve.conversation.serve import Serve as ConversationServe
-
-        conversation_serve: ConversationServe = self.system_app.get_component(
-            CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None
-        )
-        return conversation_serve
-
-    @property
-    def storage(self):
-        if self._storage:
-            return self._storage
-        conversation_serve = self._get_conversion_serve()
-        if conversation_serve:
-            return conversation_serve.conv_storage
-        else:
-            logger.info("Conversation storage not found, use InMemoryStorage default")
-            self._storage = InMemoryStorage()
-            return self._storage
-
-    @property
-    def message_storage(self):
-        if self._message_storage:
-            return self._message_storage
-        conversation_serve = self._get_conversion_serve()
-        if conversation_serve:
-            return conversation_serve.message_storage
-        else:
-            logger.info("Message storage not found, use InMemoryStorage default")
-            self._message_storage = InMemoryStorage()
-            return self._message_storage
-
-
-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-
-
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
-
-
-def history_message_mapper(
-    messages_by_round: List[List[ModelMessage]],
-) -> List[ModelMessage]:
-    """Mapper for history conversation.
-
-    If there are multi system messages, just keep the first system message.
-    """
-    has_system_message = False
-    mapper_messages = []
-    for messages in messages_by_round:
-        for message in messages:
-            if message.role == ModelMessageRoleType.SYSTEM:
-                if has_system_message:
-                    continue
-                else:
-                    mapper_messages.append(message)
-                    has_system_message = True
-            else:
-                mapper_messages.append(message)
-    print("history_message_mapper start:" + "=" * 70)
-    print(colored(ModelMessage.get_printable_message(mapper_messages), "green"))
-    print("history_message_mapper end:" + "=" * 72)
-    return mapper_messages
+def parse_prompt_args(req: TriggerReqBody) -> Dict[str, Any]:
+    prompt_args = {"user_input": req.messages}
+    if not req.command:
+        return prompt_args
+    if req.command == CODE_TRANSLATE:
+        prompt_args["source_language"] = req.language
+        prompt_args["target_language"] = req.target_language
+    else:
+        prompt_args["language"] = req.language
+    return prompt_args
+
+
+async def build_model_request(
+    messages: List[ModelMessage], req_body: TriggerReqBody
+) -> ModelRequest:
+    return ModelRequest.build_request(
+        model=req_body.model,
+        messages=messages,
+        context=req_body.context,
+        stream=req_body.stream,
+    )


 with DAG("dbgpt_awel_data_analyst_assistant") as dag:
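The rewritten operator above no longer hand-builds a ModelMessage list; it returns a ChatPromptTemplate made of a system template, a history placeholder, and a human template. Below is a minimal sketch of that structure, using the prompt classes imported later in this diff; the system-prompt text is a placeholder only, since the example loads the real template from the prompt manager.

    from dbgpt.core import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        SystemPromptTemplate,
    )

    # Illustrative system prompt; the example resolves the real one via pm.prefer_query().
    copilot_prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template(
                "Translate the following {source_language} code to {target_language}."
            ),
            # Earlier rounds are injected under the "chat_history" key by the
            # history prompt builder configured below.
            MessagesPlaceholder(variable_name="chat_history"),
            # The current request fills the "user_input" variable.
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )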
@@ -377,57 +317,59 @@ with DAG("dbgpt_awel_data_analyst_assistant") as dag:
         streaming_predict_func=lambda x: x.stream,
     )

-    copilot_task = CopilotOperator()
-    request_handle_task = RequestBuildOperator()
-
-    # Pre-process conversation
-    pre_conversation_task = MyConversationOperator()
-    # Keep last k round conversation.
-    history_conversation_task = BufferedConversationMapperOperator(
-        last_k_round=5, message_mapper=history_message_mapper
-    )
-
-    # Save conversation to storage.
-    post_conversation_task = PostConversationOperator()
-    # Save streaming conversation to storage.
-    post_streaming_conversation_task = PostStreamingConversationOperator()
-
-    # Use LLMOperator to generate response.
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    prompt_template_load_task = PromptTemplateBuilderOperator()
+    request_handle_task = RequestBuilderOperator()
+
+    # Load and store chat history
+    chat_history_load_task = ServePreChatHistoryLoadOperator()
+    last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5))
+    # History transform task, here we keep last k round messages
+    history_transform_task = BufferedConversationMapperOperator(
+        last_k_round=last_k_round
+    )
+    history_prompt_build_task = HistoryDynamicPromptBuilderOperator(
+        history_key="chat_history"
+    )
+
+    model_request_build_task = JoinOperator(build_model_request)
+
+    # Use BaseLLMOperator to generate response.
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )

+    trigger >> prompt_template_load_task >> history_prompt_build_task
+
     (
         trigger
-        >> copilot_task
-        >> request_handle_task
-        >> pre_conversation_task
-        >> history_conversation_task
-        >> branch_task
+        >> MapOperator(
+            lambda req: ModelRequestContext(
+                conv_uid=req.context.conv_uid,
+                stream=req.stream,
+                chat_mode=req.context.chat_mode,
+            )
+        )
+        >> chat_history_load_task
+        >> history_transform_task
+        >> history_prompt_build_task
     )

+    trigger >> MapOperator(parse_prompt_args) >> history_prompt_build_task
+
+    history_prompt_build_task >> model_request_build_task
+    trigger >> model_request_build_task
+
+    model_request_build_task >> branch_task
     # The branch of no streaming response.
-    (
-        branch_task
-        >> llm_task
-        >> post_conversation_task
-        >> model_parse_task
-        >> result_join_task
-    )
+    (branch_task >> llm_task >> model_parse_task >> result_join_task)
     # The branch of streaming response.
-    (
-        branch_task
-        >> streaming_llm_task
-        >> post_streaming_conversation_task
-        >> openai_format_stream_task
-        >> result_join_task
-    )
+    (branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task)

 if __name__ == "__main__":
     if dag.leaf_nodes[0].dev_mode:
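In the new wiring, model_request_build_task joins two upstream results: the prompt built from templates plus chat history, and the original TriggerReqBody carrying model, stream, and context. The join function, build_model_request, turns that pair into a ModelRequest. A self-contained sketch of that final assembly step follows; the ModelRequest and ModelMessage APIs are the ones imported in this diff, the message contents are illustrative, and omitting the context argument assumes it is optional.

    from dbgpt.core import ModelMessage, ModelMessageRoleType, ModelRequest

    # Messages as they might look after prompt composition and history merging.
    messages = [
        ModelMessage(
            role=ModelMessageRoleType.SYSTEM,
            content="You are a data analyst assistant.",
        ),
        ModelMessage(
            role=ModelMessageRoleType.HUMAN,
            content="Show me the total sales of each city.",
        ),
    ]
    # Equivalent to what build_model_request does with the trigger's request body.
    request = ModelRequest.build_request(
        model="gpt-3.5-turbo",
        messages=messages,
        stream=False,
    )
    print(request)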
@@ -12,7 +12,7 @@
     # Fist round
     curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
     -H "Content-Type: application/json" -d '{
-        "model": "gpt-3.5-turbo",
+        "model": "'"$MODEL"'",
         "context": {
             "conv_uid": "uuid_conv_1234"
         },
@@ -22,7 +22,7 @@
     # Second round
     curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
     -H "Content-Type: application/json" -d '{
-        "model": "gpt-3.5-turbo",
+        "model": "'"$MODEL"'",
         "context": {
             "conv_uid": "uuid_conv_1234"
         },
@@ -34,7 +34,7 @@

     curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
     -H "Content-Type: application/json" -d '{
-        "model": "gpt-3.5-turbo",
+        "model": "'"$MODEL"'",
         "context": {
             "conv_uid": "uuid_conv_stream_1234"
         },
@@ -45,7 +45,7 @@
     # Second round
     curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
     -H "Content-Type: application/json" -d '{
-        "model": "gpt-3.5-turbo",
+        "model": "'"$MODEL"'",
         "context": {
             "conv_uid": "uuid_conv_stream_1234"
         },
@@ -59,19 +59,27 @@ import logging
 from typing import Dict, List, Optional, Union

 from dbgpt._private.pydantic import BaseModel, Field
-from dbgpt.core import InMemoryStorage, LLMClient
+from dbgpt.core import (
+    ChatPromptTemplate,
+    HumanPromptTemplate,
+    InMemoryStorage,
+    MessagesPlaceholder,
+    ModelMessage,
+    ModelRequest,
+    ModelRequestContext,
+    SystemPromptTemplate,
+)
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
 from dbgpt.core.operator import (
-    BufferedConversationMapperOperator,
+    ChatComposerInput,
+    ChatHistoryPromptComposerOperator,
     LLMBranchOperator,
+)
+from dbgpt.model.operator import (
     LLMOperator,
-    PostConversationOperator,
-    PostStreamingConversationOperator,
-    PreConversationOperator,
-    RequestBuildOperator,
+    OpenAIStreamingOutputOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator

 logger = logging.getLogger(__name__)

@@ -100,16 +108,15 @@ class TriggerReqBody(BaseModel):
     )


-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-
-
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
+async def build_model_request(
+    messages: List[ModelMessage], req_body: TriggerReqBody
+) -> ModelRequest:
+    return ModelRequest.build_request(
+        model=req_body.model,
+        messages=messages,
+        context=req_body.context,
+        stream=req_body.stream,
+    )


 with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
@@ -120,56 +127,53 @@ with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
         request_body=TriggerReqBody,
         streaming_predict_func=lambda req: req.stream,
     )
-    # Transform request body to model request.
-    request_handle_task = RequestBuildOperator()
-    # Pre-process conversation, use InMemoryStorage to store conversation.
-    pre_conversation_task = PreConversationOperator(
-        storage=InMemoryStorage(), message_storage=InMemoryStorage()
-    )
-    # Keep last k round conversation.
-    history_conversation_task = BufferedConversationMapperOperator(last_k_round=5)
-
-    # Save conversation to storage.
-    post_conversation_task = PostConversationOperator()
-    # Save streaming conversation to storage.
-    post_streaming_conversation_task = PostStreamingConversationOperator()
-
-    # Use LLMOperator to generate response.
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    prompt = ChatPromptTemplate(
+        messages=[
+            SystemPromptTemplate.from_template("You are a helpful chatbot."),
+            MessagesPlaceholder(variable_name="chat_history"),
+            HumanPromptTemplate.from_template("{user_input}"),
+        ]
+    )
+
+    composer_operator = ChatHistoryPromptComposerOperator(
+        prompt_template=prompt,
+        last_k_round=5,
+        storage=InMemoryStorage(),
+        message_storage=InMemoryStorage(),
+    )
+
+    # Use BaseLLMOperator to generate response.
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )

-    (
-        trigger
-        >> request_handle_task
-        >> pre_conversation_task
-        >> history_conversation_task
-        >> branch_task
-    )
+    req_handle_task = MapOperator(
+        lambda req: ChatComposerInput(
+            context=ModelRequestContext(
+                conv_uid=req.context.conv_uid, stream=req.stream
+            ),
+            prompt_dict={"user_input": req.messages},
+            model_dict={
+                "model": req.model,
+                "context": req.context,
+                "stream": req.stream,
+            },
+        )
+    )
+
+    trigger >> req_handle_task >> composer_operator >> branch_task

     # The branch of no streaming response.
-    (
-        branch_task
-        >> llm_task
-        >> post_conversation_task
-        >> model_parse_task
-        >> result_join_task
-    )
+    branch_task >> llm_task >> model_parse_task >> result_join_task
     # The branch of streaming response.
-    (
-        branch_task
-        >> streaming_llm_task
-        >> post_streaming_conversation_task
-        >> openai_format_stream_task
-        >> result_join_task
-    )
+    branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task

 if __name__ == "__main__":
     if multi_round_dag.leaf_nodes[0].dev_mode:
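The rewritten chat-history example funnels each HTTP request into a ChatComposerInput, and ChatHistoryPromptComposerOperator then merges prompt template, stored history, and model parameters in one step. A small sketch of that input object follows, with placeholder values standing in for the fields that the diff takes from TriggerReqBody; the class and field names mirror the lambda above.

    from dbgpt.core import ModelRequestContext
    from dbgpt.core.operator import ChatComposerInput

    composer_input = ChatComposerInput(
        # Conversation identity and streaming flag for history lookup.
        context=ModelRequestContext(conv_uid="uuid_conv_1234", stream=False),
        # Variables rendered into the prompt's "{user_input}" placeholder.
        prompt_dict={"user_input": "hello"},
        # Parameters forwarded to the final ModelRequest.
        model_dict={
            "model": "gpt-3.5-turbo",
            "context": {"conv_uid": "uuid_conv_1234"},
            "stream": False,
        },
    )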
@@ -31,3 +31,11 @@ with DAG("simple_dag_example") as dag:
     trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
     map_node = RequestHandleOperator()
     trigger >> map_node
+
+if __name__ == "__main__":
+    if dag.leaf_nodes[0].dev_mode:
+        from dbgpt.core.awel import setup_dev_environment
+
+        setup_dev_environment([dag])
+    else:
+        pass
@@ -8,9 +8,10 @@
     .. code-block:: shell

         DBGPT_SERVER="http://127.0.0.1:5555"
+        MODEL="gpt-3.5-turbo"
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello"
         }'

@@ -19,7 +20,7 @@

         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello",
             "stream": true
         }'
@@ -29,7 +30,7 @@

         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello"
         }'

@@ -40,13 +41,13 @@ from typing import Any, Dict, List, Optional, Union
 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
-from dbgpt.core.operator import (
-    LLMBranchOperator,
+from dbgpt.core.operator import LLMBranchOperator, RequestBuilderOperator
+from dbgpt.model.operator import (
     LLMOperator,
-    RequestBuildOperator,
+    MixinLLMOperator,
+    OpenAIStreamingOutputOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator

 logger = logging.getLogger(__name__)

@@ -59,18 +60,6 @@ class TriggerReqBody(BaseModel):
     stream: Optional[bool] = Field(default=False, description="Whether return stream")


-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-
-
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
-
-
 class MyModelToolOperator(
     MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]]
 ):
@@ -97,14 +86,14 @@ with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
         request_body=TriggerReqBody,
         streaming_predict_func=lambda req: req.stream,
     )
-    request_handle_task = RequestBuildOperator()
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    request_handle_task = RequestBuilderOperator()
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )
@@ -1,16 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from dbgpt.core import BaseOutputParser, PromptTemplate
|
from dbgpt.core import BaseOutputParser
|
||||||
from dbgpt.core.awel import DAG
|
from dbgpt.core.awel import DAG
|
||||||
from dbgpt.core.operator import LLMOperator, RequestBuildOperator
|
from dbgpt.core.operator import (
|
||||||
|
BaseLLMOperator,
|
||||||
|
PromptBuilderOperator,
|
||||||
|
RequestBuilderOperator,
|
||||||
|
)
|
||||||
from dbgpt.model import OpenAILLMClient
|
from dbgpt.model import OpenAILLMClient
|
||||||
|
|
||||||
with DAG("simple_sdk_llm_example_dag") as dag:
|
with DAG("simple_sdk_llm_example_dag") as dag:
|
||||||
prompt_task = PromptTemplate.from_template(
|
prompt_task = PromptBuilderOperator(
|
||||||
"Write a SQL of {dialect} to query all data of {table_name}."
|
"Write a SQL of {dialect} to query all data of {table_name}."
|
||||||
)
|
)
|
||||||
model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
|
model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo")
|
||||||
llm_task = LLMOperator(OpenAILLMClient())
|
llm_task = BaseLLMOperator(OpenAILLMClient())
|
||||||
out_parse_task = BaseOutputParser()
|
out_parse_task = BaseOutputParser()
|
||||||
prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
|
prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task
|
||||||
|
|
||||||
|
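After this change, the SDK example chains PromptBuilderOperator, RequestBuilderOperator, BaseLLMOperator, and BaseOutputParser. The sketch below shows how the finished DAG might be invoked; the operator chain is copied from the diff, while running the leaf node with asyncio and `call(call_data=...)`, and the need for OpenAI credentials, are assumptions about the AWEL runtime rather than something this hunk shows.

    import asyncio

    from dbgpt.core import BaseOutputParser
    from dbgpt.core.awel import DAG
    from dbgpt.core.operator import (
        BaseLLMOperator,
        PromptBuilderOperator,
        RequestBuilderOperator,
    )
    from dbgpt.model import OpenAILLMClient

    with DAG("simple_sdk_llm_example_dag") as dag:
        # Render the prompt, wrap it into a model request, call the model, parse the output.
        prompt_task = PromptBuilderOperator(
            "Write a SQL of {dialect} to query all data of {table_name}."
        )
        model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo")
        llm_task = BaseLLMOperator(OpenAILLMClient())
        out_parse_task = BaseOutputParser()
        prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task

    # Assumed invocation: call the leaf operator with the prompt variables.
    output = asyncio.run(
        out_parse_task.call(call_data={"dialect": "mysql", "table_name": "user"})
    )
    print(output)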
@@ -2,7 +2,7 @@ import asyncio
 import json
 from typing import Dict, List

-from dbgpt.core import PromptTemplate, SQLOutputParser
+from dbgpt.core import SQLOutputParser
 from dbgpt.core.awel import (
     DAG,
     InputOperator,
@@ -10,7 +10,11 @@ from dbgpt.core.awel import (
     MapOperator,
     SimpleCallDataInputSource,
 )
-from dbgpt.core.operator import LLMOperator, RequestBuildOperator
+from dbgpt.core.operator import (
+    BaseLLMOperator,
+    PromptBuilderOperator,
+    RequestBuilderOperator,
+)
 from dbgpt.datasource.operator.datasource_operator import DatasourceOperator
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
 from dbgpt.model import OpenAILLMClient
@@ -116,9 +120,9 @@ with DAG("simple_sdk_llm_sql_example") as dag:
     retriever_task = DatasourceRetrieverOperator(connection=db_connection)
     # Merge the input data and the table structure information.
     prompt_input_task = JoinOperator(combine_function=_join_func)
-    prompt_task = PromptTemplate.from_template(_sql_prompt())
-    model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo")
-    llm_task = LLMOperator(OpenAILLMClient())
+    prompt_task = PromptBuilderOperator(_sql_prompt())
+    model_pre_handle_task = RequestBuilderOperator(model="gpt-3.5-turbo")
+    llm_task = BaseLLMOperator(OpenAILLMClient())
     out_parse_task = SQLOutputParser()
     sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
     db_query_task = DatasourceOperator(connection=db_connection)