mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 09:06:55 +00:00
feat(rag): Support rag retriever evaluation (#1291)
This commit is contained in:
@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
@@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
|
||||
Returns:
|
||||
TaskOutput[T]: The output object read from current source
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data: T) -> "InputSource[T]":
|
||||
"""Create an InputSource from data.
|
||||
|
||||
Args:
|
||||
data (T): The data to create the InputSource from.
|
||||
|
||||
Returns:
|
||||
InputSource[T]: The InputSource created from the data.
|
||||
"""
|
||||
from .task_impl import SimpleInputSource
|
||||
|
||||
return SimpleInputSource(data, streaming=False)
|
||||
|
||||
@classmethod
|
||||
def from_iterable(
|
||||
cls, iterable: Union[AsyncIterable[T], Iterable[T]]
|
||||
) -> "InputSource[T]":
|
||||
"""Create an InputSource from an iterable.
|
||||
|
||||
Args:
|
||||
iterable (List[T]): The iterable to create the InputSource from.
|
||||
|
||||
Returns:
|
||||
InputSource[T]: The InputSource created from the iterable.
|
||||
"""
|
||||
from .task_impl import SimpleInputSource
|
||||
|
||||
return SimpleInputSource(iterable, streaming=True)
|
||||
|
||||
@classmethod
|
||||
def from_callable(cls) -> "InputSource[T]":
|
||||
"""Create an InputSource from a callable."""
|
||||
from .task_impl import SimpleCallDataInputSource
|
||||
|
||||
return SimpleCallDataInputSource()
|
||||
|
@@ -261,13 +261,42 @@ def _is_async_iterator(obj):
|
||||
)
|
||||
|
||||
|
||||
def _is_async_iterable(obj):
|
||||
return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))
|
||||
|
||||
|
||||
def _is_iterator(obj):
|
||||
return (
|
||||
hasattr(obj, "__iter__")
|
||||
and callable(getattr(obj, "__iter__", None))
|
||||
and hasattr(obj, "__next__")
|
||||
and callable(getattr(obj, "__next__", None))
|
||||
)
|
||||
|
||||
|
||||
def _is_iterable(obj):
|
||||
return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))
|
||||
|
||||
|
||||
async def _to_async_iterator(obj) -> AsyncIterator:
|
||||
if _is_async_iterable(obj):
|
||||
async for item in obj:
|
||||
yield item
|
||||
elif _is_iterable(obj):
|
||||
for item in obj:
|
||||
yield item
|
||||
else:
|
||||
raise ValueError(f"Can not convert {obj} to AsyncIterator")
|
||||
|
||||
|
||||
class BaseInputSource(InputSource, ABC):
|
||||
"""The base class of InputSource."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, streaming: Optional[bool] = None) -> None:
|
||||
"""Create a BaseInputSource."""
|
||||
super().__init__()
|
||||
self._is_read = False
|
||||
self._streaming_data = streaming
|
||||
|
||||
@abstractmethod
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
@@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
|
||||
ValueError: If the input source is a stream and has been read.
|
||||
"""
|
||||
data = self._read_data(task_ctx)
|
||||
if _is_async_iterator(data):
|
||||
if self._streaming_data is None:
|
||||
streaming_data = _is_async_iterator(data) or _is_iterator(data)
|
||||
else:
|
||||
streaming_data = self._streaming_data
|
||||
if streaming_data:
|
||||
if self._is_read:
|
||||
raise ValueError(f"Input iterator {data} has been read!")
|
||||
output: TaskOutput = SimpleStreamTaskOutput(data)
|
||||
it_data = _to_async_iterator(data)
|
||||
output: TaskOutput = SimpleStreamTaskOutput(it_data)
|
||||
else:
|
||||
output = SimpleTaskOutput(data)
|
||||
self._is_read = True
|
||||
@@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
|
||||
class SimpleInputSource(BaseInputSource):
|
||||
"""The default implementation of InputSource."""
|
||||
|
||||
def __init__(self, data: Any) -> None:
|
||||
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
|
||||
"""Create a SimpleInputSource.
|
||||
|
||||
Args:
|
||||
data (Any): The input data.
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(streaming=streaming)
|
||||
self._data = data
|
||||
|
||||
def _read_data(self, task_ctx: TaskContext) -> Any:
|
||||
|
Reference in New Issue
Block a user