mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
This commit is contained in:
87
dbgpt/util/function_utils.py
Normal file
87
dbgpt/util/function_utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import Any, get_type_hints, get_origin, get_args
|
||||
from functools import wraps
|
||||
import inspect
|
||||
import asyncio
|
||||
|
||||
|
||||
def _is_instance_of_generic_type(obj, generic_type):
|
||||
"""Check if an object is an instance of a generic type."""
|
||||
if generic_type is Any:
|
||||
return True # Any type is compatible with any object
|
||||
|
||||
origin = get_origin(generic_type)
|
||||
if origin is None:
|
||||
return isinstance(obj, generic_type) # Handle non-generic types
|
||||
|
||||
args = get_args(generic_type)
|
||||
if not args:
|
||||
return isinstance(obj, origin)
|
||||
|
||||
# Check if object matches the generic origin (like list, dict)
|
||||
if not isinstance(obj, origin):
|
||||
return False
|
||||
|
||||
# For each item in the object, check if it matches the corresponding type argument
|
||||
for sub_obj, arg in zip(obj, args):
|
||||
# Skip check if the type argument is Any
|
||||
if arg is not Any and not isinstance(sub_obj, arg):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _sort_args(func, args, kwargs):
|
||||
sig = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
arg_types = [
|
||||
type_hints[param_name]
|
||||
for param_name in sig.parameters
|
||||
if param_name != "return" and param_name != "self"
|
||||
]
|
||||
|
||||
if "self" in sig.parameters:
|
||||
self_arg = [args[0]]
|
||||
other_args = args[1:]
|
||||
else:
|
||||
self_arg = []
|
||||
other_args = args
|
||||
|
||||
sorted_args = sorted(
|
||||
other_args,
|
||||
key=lambda x: next(
|
||||
i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
|
||||
),
|
||||
)
|
||||
return (*self_arg, *sorted_args), kwargs
|
||||
|
||||
|
||||
def rearrange_args_by_type(func):
|
||||
"""Decorator to rearrange the arguments of a function by type.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||
|
||||
@rearrange_args_by_type
|
||||
def sync_regular_function(a: int, b: str, c: float):
|
||||
return a, b, c
|
||||
|
||||
assert instance.sync_class_method(1, "b", 3.0) == (1, "b", 3.0)
|
||||
assert instance.sync_class_method("b", 3.0, 1) == (1, "b", 3.0)
|
||||
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
sorted_args, sorted_kwargs = _sort_args(func, args, kwargs)
|
||||
return func(*sorted_args, **sorted_kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
sorted_args, sorted_kwargs = _sort_args(func, args, kwargs)
|
||||
return await func(*sorted_args, **sorted_kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
@@ -10,11 +10,12 @@ needed), or truncating them so that they fit in a single LLM call.
|
||||
|
||||
import logging
|
||||
from string import Formatter
|
||||
from typing import Callable, List, Optional, Sequence
|
||||
from typing import Callable, List, Optional, Sequence, Set
|
||||
|
||||
from dbgpt._private.pydantic import Field, PrivateAttr, BaseModel
|
||||
|
||||
from dbgpt.util.global_helper import globals_helper
|
||||
from dbgpt.core.interface.prompt import get_template_vars
|
||||
from dbgpt._private.llm_metadata import LLMMetadata
|
||||
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
|
||||
|
||||
@@ -230,15 +231,3 @@ def get_empty_prompt_txt(template: str) -> str:
|
||||
all_kwargs = {**partial_kargs, **empty_kwargs}
|
||||
prompt = template.format(**all_kwargs)
|
||||
return prompt
|
||||
|
||||
|
||||
def get_template_vars(template_str: str) -> List[str]:
|
||||
"""Get template variables from a template string."""
|
||||
variables = []
|
||||
formatter = Formatter()
|
||||
|
||||
for _, variable_name, _, _ in formatter.parse(template_str):
|
||||
if variable_name:
|
||||
variables.append(variable_name)
|
||||
|
||||
return variables
|
||||
|
120
dbgpt/util/tests/test_function_utils.py
Normal file
120
dbgpt/util/tests/test_function_utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import pytest
|
||||
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||
|
||||
|
||||
class ChatPromptTemplate:
|
||||
pass
|
||||
|
||||
|
||||
class BaseMessage:
|
||||
pass
|
||||
|
||||
|
||||
class ModelMessage:
|
||||
pass
|
||||
|
||||
|
||||
class DummyClass:
|
||||
@rearrange_args_by_type
|
||||
async def class_method(self, a: int, b: str, c: float):
|
||||
return a, b, c
|
||||
|
||||
@rearrange_args_by_type
|
||||
async def merge_history(
|
||||
self,
|
||||
prompt: ChatPromptTemplate,
|
||||
history: List[BaseMessage],
|
||||
prompt_dict: Dict[str, Any],
|
||||
) -> List[ModelMessage]:
|
||||
return [type(prompt), type(history), type(prompt_dict)]
|
||||
|
||||
@rearrange_args_by_type
|
||||
def sync_class_method(self, a: int, b: str, c: float):
|
||||
return a, b, c
|
||||
|
||||
|
||||
@rearrange_args_by_type
|
||||
def sync_regular_function(a: int, b: str, c: float):
|
||||
return a, b, c
|
||||
|
||||
|
||||
@rearrange_args_by_type
|
||||
async def regular_function(a: int, b: str, c: float):
|
||||
return a, b, c
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_class_method_correct_order():
|
||||
instance = DummyClass()
|
||||
result = await instance.class_method(1, "b", 3.0)
|
||||
assert result == (1, "b", 3.0), "Class method failed with correct order"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_class_method_incorrect_order():
|
||||
instance = DummyClass()
|
||||
result = await instance.class_method("b", 3.0, 1)
|
||||
assert result == (1, "b", 3.0), "Class method failed with incorrect order"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_function_correct_order():
|
||||
result = await regular_function(1, "b", 3.0)
|
||||
assert result == (1, "b", 3.0), "Regular function failed with correct order"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_function_incorrect_order():
|
||||
result = await regular_function("b", 3.0, 1)
|
||||
assert result == (1, "b", 3.0), "Regular function failed with incorrect order"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_history_correct_order():
|
||||
instance = DummyClass()
|
||||
result = await instance.merge_history(
|
||||
ChatPromptTemplate(), [BaseMessage()], {"key": "value"}
|
||||
)
|
||||
assert result == [ChatPromptTemplate, list, dict], "Failed with correct order"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_history_incorrect_order_1():
|
||||
instance = DummyClass()
|
||||
result = await instance.merge_history(
|
||||
[BaseMessage()], ChatPromptTemplate(), {"key": "value"}
|
||||
)
|
||||
assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_history_incorrect_order_2():
|
||||
instance = DummyClass()
|
||||
result = await instance.merge_history(
|
||||
{"key": "value"}, [BaseMessage()], ChatPromptTemplate()
|
||||
)
|
||||
assert result == [ChatPromptTemplate, list, dict], "Failed with incorrect order 2"
|
||||
|
||||
|
||||
def test_sync_class_method_correct_order():
|
||||
instance = DummyClass()
|
||||
result = instance.sync_class_method(1, "b", 3.0)
|
||||
assert result == (1, "b", 3.0), "Sync class method failed with correct order"
|
||||
|
||||
|
||||
def test_sync_class_method_incorrect_order():
|
||||
instance = DummyClass()
|
||||
result = instance.sync_class_method("b", 3.0, 1)
|
||||
assert result == (1, "b", 3.0), "Sync class method failed with incorrect order"
|
||||
|
||||
|
||||
def test_sync_regular_function_correct_order():
|
||||
result = sync_regular_function(1, "b", 3.0)
|
||||
assert result == (1, "b", 3.0), "Sync regular function failed with correct order"
|
||||
|
||||
|
||||
def test_sync_regular_function_incorrect_order():
|
||||
result = sync_regular_function("b", 3.0, 1)
|
||||
assert result == (1, "b", 3.0), "Sync regular function failed with incorrect order"
|
Reference in New Issue
Block a user