feat(rag): Support rag retriever evaluation (#1291)

This commit is contained in:
Fangyin Cheng
2024-03-14 13:06:57 +08:00
committed by GitHub
parent cd2dcc253c
commit adaa68eb00
34 changed files with 1452 additions and 67 deletions

View File

@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
@@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
Returns:
TaskOutput[T]: The output object read from current source
"""
@classmethod
def from_data(cls, data: T) -> "InputSource[T]":
    """Build an InputSource that wraps a single value.

    Args:
        data (T): The value the resulting source is created from.

    Returns:
        InputSource[T]: A ``SimpleInputSource`` constructed with
            ``streaming=False``.
    """
    # Imported at call time rather than at module level (presumably to avoid
    # a circular import with ``task_impl`` -- confirm).
    from .task_impl import SimpleInputSource

    plain_source = SimpleInputSource(data, streaming=False)
    return plain_source
@classmethod
def from_iterable(
    cls, iterable: Union[AsyncIterable[T], Iterable[T]]
) -> "InputSource[T]":
    """Build a streaming InputSource on top of an iterable.

    Args:
        iterable (Union[AsyncIterable[T], Iterable[T]]): The synchronous or
            asynchronous iterable the source is created from.

    Returns:
        InputSource[T]: A ``SimpleInputSource`` constructed with
            ``streaming=True``.
    """
    # Imported at call time rather than at module level (presumably to avoid
    # a circular import with ``task_impl`` -- confirm).
    from .task_impl import SimpleInputSource

    stream_source = SimpleInputSource(iterable, streaming=True)
    return stream_source
@classmethod
def from_callable(cls) -> "InputSource[T]":
    """Build an InputSource backed by ``SimpleCallDataInputSource``.

    Returns:
        InputSource[T]: A newly constructed ``SimpleCallDataInputSource``.
    """
    # Imported at call time rather than at module level (presumably to avoid
    # a circular import with ``task_impl`` -- confirm).
    from .task_impl import SimpleCallDataInputSource

    call_data_source = SimpleCallDataInputSource()
    return call_data_source

View File

@@ -261,13 +261,42 @@ def _is_async_iterator(obj):
)
def _is_async_iterable(obj):
    """Return whether ``obj`` exposes a callable ``__aiter__`` attribute.

    Args:
        obj: Any object.

    Returns:
        bool: True if ``obj`` has an ``__aiter__`` attribute and that
            attribute is callable, False otherwise.
    """
    # ``getattr`` with a ``None`` default already covers the missing-attribute
    # case (``callable(None)`` is False), so a separate ``hasattr`` lookup is
    # unnecessary.
    return callable(getattr(obj, "__aiter__", None))
def _is_iterator(obj):
    """Return whether ``obj`` exposes callable ``__iter__`` and ``__next__``.

    Args:
        obj: Any object.

    Returns:
        bool: True only if both ``__iter__`` and ``__next__`` are present on
            ``obj`` and callable, False otherwise.
    """
    # A missing attribute resolves to ``None``, which is not callable, so one
    # ``getattr`` per dunder replaces the ``hasattr`` + ``getattr`` pair.
    # ``__iter__`` is still checked first and short-circuits on failure.
    return all(
        callable(getattr(obj, dunder, None)) for dunder in ("__iter__", "__next__")
    )
def _is_iterable(obj):
    """Return whether ``obj`` exposes a callable ``__iter__`` attribute.

    Args:
        obj: Any object.

    Returns:
        bool: True if ``obj`` has an ``__iter__`` attribute and that
            attribute is callable, False otherwise.
    """
    # ``getattr`` with a ``None`` default already covers the missing-attribute
    # case (``callable(None)`` is False), so a separate ``hasattr`` lookup is
    # unnecessary.
    return callable(getattr(obj, "__iter__", None))
async def _to_async_iterator(obj) -> AsyncIterator:
    """Adapt a synchronous or asynchronous iterable into an async iterator.

    Args:
        obj: The object to adapt.

    Yields:
        The items of ``obj`` in iteration order.

    Raises:
        ValueError: If ``obj`` is neither async-iterable nor iterable. As this
            is an async generator, the error is raised when iteration starts,
            not when the function is called.
    """
    # The async protocol is checked first, so it wins for objects that
    # support both ``__aiter__`` and ``__iter__``.
    if _is_async_iterable(obj):
        async for element in obj:
            yield element
        return
    if not _is_iterable(obj):
        raise ValueError(f"Can not convert {obj} to AsyncIterator")
    for element in obj:
        yield element
class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""
def __init__(self, streaming: Optional[bool] = None) -> None:
    """Create a BaseInputSource.

    Args:
        streaming (Optional[bool]): Stored as ``self._streaming_data``. When
            left as None, the read path decides whether the data is a stream
            by inspecting it; otherwise the given flag is used as-is.
    """
    super().__init__()
    # Tracks whether the source has already been read.
    self._is_read = False
    self._streaming_data = streaming
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
@@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._streaming_data is None:
streaming_data = _is_async_iterator(data) or _is_iterator(data)
else:
streaming_data = self._streaming_data
if streaming_data:
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output: TaskOutput = SimpleStreamTaskOutput(data)
it_data = _to_async_iterator(data)
output: TaskOutput = SimpleStreamTaskOutput(it_data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
@@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
    """Create a SimpleInputSource.

    Args:
        data (Any): The input data.
        streaming (Optional[bool]): Forwarded unchanged to the
            ``BaseInputSource`` initializer. Defaults to None.
    """
    # Initialize the base class exactly once, passing the streaming flag on.
    super().__init__(streaming=streaming)
    self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any: