mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 10:34:30 +00:00
feat(rag): Support rag retriever evaluation (#1291)
This commit is contained in:
@@ -55,6 +55,7 @@ from .trigger.http_trigger import (
|
||||
CommonLLMHttpResponseBody,
|
||||
HttpTrigger,
|
||||
)
|
||||
from .trigger.iterator_trigger import IteratorTrigger
|
||||
|
||||
_request_http_trigger_available = False
|
||||
try:
|
||||
@@ -100,6 +101,7 @@ __all__ = [
|
||||
"TransformStreamAbsOperator",
|
||||
"Trigger",
|
||||
"HttpTrigger",
|
||||
"IteratorTrigger",
|
||||
"CommonLLMHTTPRequestContext",
|
||||
"CommonLLMHttpResponseBody",
|
||||
"CommonLLMHttpRequestBody",
|
||||
|
@@ -277,7 +277,7 @@ class InputOperator(BaseOperator, Generic[OUT]):
|
||||
return task_output
|
||||
|
||||
|
||||
class TriggerOperator(InputOperator, Generic[OUT]):
|
||||
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
|
||||
"""Operator node that triggers the DAG to run."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
|
@@ -60,8 +60,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
streaming_call=streaming_call,
|
||||
node_name_to_ids=job_manager._node_name_to_ids,
|
||||
)
|
||||
if node.dag:
|
||||
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
|
||||
# if node.dag:
|
||||
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx
|
||||
logger.info(
|
||||
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
|
||||
)
|
||||
@@ -76,8 +76,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
|
||||
if not streaming_call and node.dag:
|
||||
# streaming call not work for dag end
|
||||
await node.dag._after_dag_end()
|
||||
if node.dag:
|
||||
del self._running_dag_ctx[node.dag.dag_id]
|
||||
# if node.dag:
|
||||
# del self._running_dag_ctx[node.dag.dag_id]
|
||||
return dag_ctx
|
||||
|
||||
async def _execute_node(
|
||||
|
@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
@@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
|
||||
Returns:
|
||||
TaskOutput[T]: The output object read from current source
|
||||
"""
|
||||
|
||||
@classmethod
def from_data(cls, data: T) -> "InputSource[T]":
    """Wrap one already-available value in an InputSource.

    Args:
        data (T): The value the returned source will be built around.

    Returns:
        InputSource[T]: A ``SimpleInputSource`` constructed with
            ``streaming=False``.
    """
    # Function-scope import, as in the original.
    # NOTE(review): presumably this avoids a circular import between this module
    # and ``task_impl`` - confirm.
    from .task_impl import SimpleInputSource

    source = SimpleInputSource(data, streaming=False)
    return source
|
||||
|
||||
@classmethod
def from_iterable(
    cls, iterable: Union[AsyncIterable[T], Iterable[T]]
) -> "InputSource[T]":
    """Build a streaming InputSource over a sync or async iterable.

    Args:
        iterable (Union[AsyncIterable[T], Iterable[T]]): The items the
            returned source is built around.

    Returns:
        InputSource[T]: A ``SimpleInputSource`` constructed with
            ``streaming=True``.
    """
    # Function-scope import, as in the original.
    # NOTE(review): presumably this avoids a circular import between this module
    # and ``task_impl`` - confirm.
    from .task_impl import SimpleInputSource as _StreamSource

    return _StreamSource(iterable, streaming=True)
|
||||
|
||||
@classmethod
def from_callable(cls) -> "InputSource[T]":
    """Create an InputSource backed by ``SimpleCallDataInputSource``.

    Takes no arguments; the source is constructed with its defaults.

    Returns:
        InputSource[T]: A new ``SimpleCallDataInputSource`` instance.
    """
    # NOTE(review): presumably the returned source reads the data supplied when
    # the DAG is called - confirm against ``task_impl``.
    from .task_impl import SimpleCallDataInputSource

    call_source = SimpleCallDataInputSource()
    return call_source
|
||||
|
@@ -261,13 +261,42 @@ def _is_async_iterator(obj):
|
||||
)
|
||||
|
||||
|
||||
def _is_async_iterable(obj):
|
||||
return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))
|
||||
|
||||
|
||||
def _is_iterator(obj):
|
||||
return (
|
||||
hasattr(obj, "__iter__")
|
||||
and callable(getattr(obj, "__iter__", None))
|
||||
and hasattr(obj, "__next__")
|
||||
and callable(getattr(obj, "__next__", None))
|
||||
)
|
||||
|
||||
|
||||
def _is_iterable(obj):
|
||||
return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))
|
||||
|
||||
|
||||
async def _to_async_iterator(obj) -> AsyncIterator:
    """Expose ``obj`` as an async generator, whatever kind of iterable it is.

    Async iterables are consumed with ``async for``; ordinary iterables are
    consumed with a plain ``for``. The async check runs first.

    Raises:
        ValueError: If ``obj`` is neither an async iterable nor an iterable. As
            this is an async generator, the error surfaces when iteration
            starts, not when the function is called.
    """
    if _is_async_iterable(obj):
        async for element in obj:
            yield element
        return
    if not _is_iterable(obj):
        raise ValueError(f"Can not convert {obj} to AsyncIterator")
    for element in obj:
        yield element
|
||||
|
||||
|
||||
class BaseInputSource(InputSource, ABC):
|
||||
"""The base class of InputSource."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, streaming: Optional[bool] = None) -> None:
|
||||
"""Create a BaseInputSource."""
|
||||
super().__init__()
|
||||
self._is_read = False
|
||||
self._streaming_data = streaming
|
||||
|
||||
@abstractmethod
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
@@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
|
||||
ValueError: If the input source is a stream and has been read.
|
||||
"""
|
||||
data = self._read_data(task_ctx)
|
||||
if _is_async_iterator(data):
|
||||
if self._streaming_data is None:
|
||||
streaming_data = _is_async_iterator(data) or _is_iterator(data)
|
||||
else:
|
||||
streaming_data = self._streaming_data
|
||||
if streaming_data:
|
||||
if self._is_read:
|
||||
raise ValueError(f"Input iterator {data} has been read!")
|
||||
output: TaskOutput = SimpleStreamTaskOutput(data)
|
||||
it_data = _to_async_iterator(data)
|
||||
output: TaskOutput = SimpleStreamTaskOutput(it_data)
|
||||
else:
|
||||
output = SimpleTaskOutput(data)
|
||||
self._is_read = True
|
||||
@@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
|
||||
class SimpleInputSource(BaseInputSource):
|
||||
"""The default implementation of InputSource."""
|
||||
|
||||
def __init__(self, data: Any) -> None:
|
||||
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
|
||||
"""Create a SimpleInputSource.
|
||||
|
||||
Args:
|
||||
data (Any): The input data.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(streaming=streaming)
|
||||
self._data = data
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
|
0
dbgpt/core/awel/tests/trigger/__init__.py
Normal file
0
dbgpt/core/awel/tests/trigger/__init__.py
Normal file
118
dbgpt/core/awel/tests/trigger/test_iterator_trigger.py
Normal file
118
dbgpt/core/awel/tests/trigger/test_iterator_trigger.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import AsyncIterator
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.awel import (
|
||||
DAG,
|
||||
InputSource,
|
||||
MapOperator,
|
||||
StreamifyAbsOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger
|
||||
|
||||
|
||||
class NumberProducerOperator(StreamifyAbsOperator[int, int]):
    """Turn an integer ``n`` into the stream ``0, 1, ..., n - 1``."""

    async def streamify(self, n: int) -> AsyncIterator[int]:
        """Yield every integer of ``range(n)`` in ascending order."""
        for number in range(n):
            yield number
|
||||
|
||||
|
||||
class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
    """Square every number of the incoming stream."""

    async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
        """Yield ``x * x`` for each ``x`` read from ``data``, in arrival order."""
        async for value in data:
            squared = value * value
            yield squared
|
||||
|
||||
|
||||
async def _check_stream_results(stream_results, expected_len):
|
||||
assert len(stream_results) == expected_len
|
||||
for _, result in stream_results:
|
||||
i = 0
|
||||
async for num in result:
|
||||
assert num == i * i
|
||||
i += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_single_data():
    """A scalar input triggers exactly one DAG run, in both call modes."""
    # NOTE(review): the scraped diff lost its indentation; trigger() is assumed to
    # run after each DAG context has closed - confirm against the repository.
    with DAG("test_single_data"):
        source = IteratorTrigger(data=2)
        squarer = MapOperator(lambda x: x * x)
        source >> squarer
    outputs = await source.trigger()
    assert len(outputs) == 1
    # Each entry is an (input, output) tuple; only the output is checked here.
    assert outputs[0][1] == 4

    with DAG("test_single_data_stream"):
        source = IteratorTrigger(data=2, streaming_call=True)
        producer = NumberProducerOperator()
        squarer = MyStreamingOperator()
        source >> producer >> squarer
    streamed = await source.trigger()
    await _check_stream_results(streamed, 1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_data():
    """A list input yields one (input, output) tuple per element, in order."""
    # NOTE(review): the scraped diff lost its indentation; trigger() is assumed to
    # run after each DAG context has closed - confirm against the repository.
    with DAG("test_list_data"):
        source = IteratorTrigger(data=[0, 1, 2, 3])
        squarer = MapOperator(lambda x: x * x)
        source >> squarer
    outputs = await source.trigger()
    assert len(outputs) == 4
    assert outputs == [(0, 0), (1, 1), (2, 4), (3, 9)]

    with DAG("test_list_data_stream"):
        source = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
        producer = NumberProducerOperator()
        squarer = MyStreamingOperator()
        source >> producer >> squarer
    streamed = await source.trigger()
    await _check_stream_results(streamed, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_async_iterator_data():
    """An async generator input is consumed item by item, in both call modes."""

    async def async_iter():
        # A fresh generator is created per trigger, since one can be read only once.
        for number in range(4):
            yield number

    # NOTE(review): the scraped diff lost its indentation; trigger() is assumed to
    # run after each DAG context has closed - confirm against the repository.
    with DAG("test_async_iterator_data"):
        source = IteratorTrigger(data=async_iter())
        squarer = MapOperator(lambda x: x * x)
        source >> squarer
    outputs = await source.trigger()
    assert len(outputs) == 4
    assert outputs == [(0, 0), (1, 1), (2, 4), (3, 9)]

    with DAG("test_async_iterator_data_stream"):
        source = IteratorTrigger(data=async_iter(), streaming_call=True)
        producer = NumberProducerOperator()
        squarer = MyStreamingOperator()
        source >> producer >> squarer
    streamed = await source.trigger()
    await _check_stream_results(streamed, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_input_source_data():
    """An ``InputSource`` built from an iterable feeds the trigger item by item."""
    # NOTE(review): the scraped diff lost its indentation; trigger() is assumed to
    # run after each DAG context has closed - confirm against the repository.
    with DAG("test_input_source_data"):
        source = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
        squarer = MapOperator(lambda x: x * x)
        source >> squarer
    outputs = await source.trigger()
    assert len(outputs) == 4
    assert outputs == [(0, 0), (1, 1), (2, 4), (3, 9)]

    with DAG("test_input_source_data_stream"):
        source = IteratorTrigger(
            data=InputSource.from_iterable([0, 1, 2, 3]),
            streaming_call=True,
        )
        producer = NumberProducerOperator()
        squarer = MyStreamingOperator()
        source >> producer >> squarer
    streamed = await source.trigger()
    await _check_stream_results(streamed, 4)
|
@@ -2,16 +2,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic
|
||||
|
||||
from ..operators.common_operator import TriggerOperator
|
||||
from ..task.base import OUT
|
||||
|
||||
|
||||
class Trigger(TriggerOperator, ABC):
|
||||
class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
|
||||
"""Base class for all trigger classes.
|
||||
|
||||
Now only support http trigger.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def trigger(self) -> None:
|
||||
async def trigger(self, **kwargs) -> Any:
|
||||
"""Trigger the workflow or a specific operation in the workflow."""
|
||||
|
@@ -397,9 +397,9 @@ class HttpTrigger(Trigger):
|
||||
self._end_node: Optional[BaseOperator] = None
|
||||
self._register_to_app = register_to_app
|
||||
|
||||
async def trigger(self) -> None:
|
||||
async def trigger(self, **kwargs) -> Any:
|
||||
"""Trigger the DAG. Not used in HttpTrigger."""
|
||||
pass
|
||||
raise NotImplementedError("HttpTrigger does not support trigger directly")
|
||||
|
||||
def register_to_app(self) -> bool:
|
||||
"""Register the trigger to a FastAPI app.
|
||||
|
148
dbgpt/core/awel/trigger/iterator_trigger.py
Normal file
148
dbgpt/core/awel/trigger/iterator_trigger.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Trigger for iterator data."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast
|
||||
|
||||
from ..operators.base import BaseOperator
|
||||
from ..task.base import InputSource, TaskState
|
||||
from ..task.task_impl import DefaultTaskContext, _is_async_iterator, _is_iterable
|
||||
from .base import Trigger
|
||||
|
||||
IterDataType = Union[InputSource, Iterator, AsyncIterator, Any]
|
||||
|
||||
|
||||
async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIterator:
    """Yield the individual items described by ``iter_data``.

    The checks run in this order:

    1. An async iterator is consumed with ``async for``.
    2. Anything with a callable ``__iter__`` is consumed with ``for``. This
       includes strings and dicts, which are iterated element-wise.
    3. An ``InputSource`` is read once with a task context created from
       ``task_id``. A stream output is yielded item by item; any other output
       is yielded as a single item.
    4. Any other value is yielded once, unchanged.

    Args:
        iter_data (IterDataType): The data to iterate over.
        task_id (str): Task id used for the context when reading an InputSource.
    """
    if _is_async_iterator(iter_data):
        async for element in iter_data:  # type: ignore
            yield element
        return
    if _is_iterable(iter_data):
        for element in iter_data:  # type: ignore
            yield element
        return
    if not isinstance(iter_data, InputSource):
        # A plain value: the DAG runs exactly once with it.
        yield iter_data
        return

    read_ctx: DefaultTaskContext[Any] = DefaultTaskContext(
        task_id, TaskState.RUNNING, None
    )
    task_output = await iter_data.read(read_ctx)
    if not task_output.is_stream:
        yield task_output.output
        return
    async for element in task_output.output_stream:
        yield element
|
||||
|
||||
|
||||
class IteratorTrigger(Trigger):
    """Trigger that runs the DAG once for every item of the supplied data.

    ``trigger()`` returns one ``(input, output)`` tuple per item, where the output
    comes from the single leaf node of the DAG. The DAG therefore runs as many
    times as the data has items.
    """

    def __init__(
        self,
        data: IterDataType,
        parallel_num: int = 1,
        streaming_call: bool = False,
        **kwargs
    ):
        """Create a IteratorTrigger.

        Args:
            data (IterDataType): The data to iterate over. It may be an async
                iterator, an iterable, an ``InputSource`` or a single value.
            parallel_num (int, optional): Semaphore size used by ``trigger()``
                when it is not given its own ``parallel_num``. Defaults to 1.
            streaming_call (bool, optional): When True the leaf node is invoked
                with ``call_stream``, otherwise with ``call``. Defaults to False.
            **kwargs: Passed through unchanged to the parent constructor.
        """
        # Own state is stored before the parent constructor runs, as before.
        self._streaming_call = streaming_call
        self._parallel_num = parallel_num
        self._iter_data = data
        super().__init__(**kwargs)

    async def trigger(
        self, parallel_num: Optional[int] = None, **kwargs
    ) -> List[Tuple[Any, Any]]:
        """Run the DAG once per input item and collect the leaf node outputs.

        For a streaming call, the second element of every tuple is an async
        generator. It is only created while the semaphore is held; the caller
        consumes it afterwards.

        Examples:
            .. code-block:: python

                import asyncio
                from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator

                with DAG("test_dag") as dag:
                    trigger_task = IteratorTrigger([0, 1, 2, 3])
                    task = MapOperator(lambda x: x * x)
                    trigger_task >> task
                results = asyncio.run(trigger_task.trigger())
                # The first element of each tuple is the input data, the second
                # element is the output data of the leaf node.
                assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

            .. code-block:: python

                import asyncio
                from datasets import Dataset
                from dbgpt.core.awel import (
                    DAG,
                    IteratorTrigger,
                    MapOperator,
                    InputSource,
                )

                data_samples = {
                    "question": ["What is 1+1?", "What is 7*7?"],
                    "answer": [2, 49],
                }
                dataset = Dataset.from_dict(data_samples)
                with DAG("test_dag_stream") as dag:
                    trigger_task = IteratorTrigger(InputSource.from_iterable(dataset))
                    task = MapOperator(lambda x: x["answer"])
                    trigger_task >> task
                results = asyncio.run(trigger_task.trigger())
                assert results == [
                    ({"question": "What is 1+1?", "answer": 2}, 2),
                    ({"question": "What is 7*7?", "answer": 49}, 49),
                ]

        Args:
            parallel_num (Optional[int], optional): Semaphore size for this call.
                A falsy value (None or 0) falls back to the ``parallel_num``
                given to the constructor. Defaults to None.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            List[Tuple[Any, Any]]: One tuple per input item. The first element is
                the input data, the second element is the output of the leaf
                node.

        Raises:
            ValueError: If the trigger has no DAG, or the DAG does not have
                exactly one leaf node.
        """
        owner_dag = self.dag
        if not owner_dag:
            raise ValueError("DAG is not set for IteratorTrigger")
        leaves = owner_dag.leaf_nodes
        if len(leaves) != 1:
            raise ValueError("IteratorTrigger just support one leaf node in dag")
        leaf = cast(BaseOperator, leaves[0])
        stream_mode = self._streaming_call
        limiter = asyncio.Semaphore(parallel_num or self._parallel_num)
        source_id = self.node_id

        async def _stream_leaf(item: Any):
            # Re-yield the leaf node's stream so callers get a plain async generator.
            stream = await leaf.call_stream(item)
            async for chunk in stream:
                yield chunk

        async def _run_once(item: Any):
            async with limiter:
                if not stream_mode:
                    return item, await leaf.call(item)
                # Creating the generator does not start the leaf node.
                return item, _stream_leaf(item)

        # All inputs are read first; gather then runs the collected coroutines.
        pending = [
            _run_once(item)
            async for item in _to_async_iterator(self._iter_data, source_id)
        ]
        return await asyncio.gather(*pending)
|
Reference in New Issue
Block a user