feat(rag): Support rag retriever evaluation (#1291)

This commit is contained in:
Fangyin Cheng
2024-03-14 13:06:57 +08:00
committed by GitHub
parent cd2dcc253c
commit adaa68eb00
34 changed files with 1452 additions and 67 deletions

View File

@@ -7,6 +7,7 @@ from dbgpt.core.interface.cache import ( # noqa: F401
CachePolicy,
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
@@ -103,4 +104,5 @@ __ALL__ = [
"DefaultStorageItemAdapter",
"QuerySpec",
"StorageError",
"Embeddings",
]

View File

@@ -55,6 +55,7 @@ from .trigger.http_trigger import (
CommonLLMHttpResponseBody,
HttpTrigger,
)
from .trigger.iterator_trigger import IteratorTrigger
_request_http_trigger_available = False
try:
@@ -100,6 +101,7 @@ __all__ = [
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"IteratorTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
"CommonLLMHttpRequestBody",

View File

@@ -277,7 +277,7 @@ class InputOperator(BaseOperator, Generic[OUT]):
return task_output
class TriggerOperator(InputOperator, Generic[OUT]):
class TriggerOperator(InputOperator[OUT], Generic[OUT]):
"""Operator node that triggers the DAG to run."""
def __init__(self, **kwargs) -> None:

View File

@@ -60,8 +60,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
)
if node.dag:
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
# if node.dag:
# self._running_dag_ctx[node.dag.dag_id] = dag_ctx
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
)
@@ -76,8 +76,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()
if node.dag:
del self._running_dag_ctx[node.dag.dag_id]
# if node.dag:
# del self._running_dag_ctx[node.dag.dag_id]
return dag_ctx
async def _execute_node(

View File

@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
@@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
Returns:
TaskOutput[T]: The output object read from current source
"""
@classmethod
def from_data(cls, data: T) -> "InputSource[T]":
    """Build an InputSource that wraps one non-streaming value.

    Args:
        data (T): The value to wrap.

    Returns:
        InputSource[T]: A ``SimpleInputSource`` created with
            ``streaming=False``.
    """
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import with task_impl -- confirm).
    from .task_impl import SimpleInputSource

    source = SimpleInputSource(data, streaming=False)
    return source
@classmethod
def from_iterable(
    cls, iterable: Union[AsyncIterable[T], Iterable[T]]
) -> "InputSource[T]":
    """Build a streaming InputSource from a sync or async iterable.

    Args:
        iterable (Union[AsyncIterable[T], Iterable[T]]): The iterable whose
            items the source will stream.

    Returns:
        InputSource[T]: A ``SimpleInputSource`` created with
            ``streaming=True``.
    """
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import with task_impl -- confirm).
    from .task_impl import SimpleInputSource

    source = SimpleInputSource(iterable, streaming=True)
    return source
@classmethod
def from_callable(cls) -> "InputSource[T]":
    """Build an InputSource backed by ``SimpleCallDataInputSource``.

    NOTE(review): presumably the returned source reads the call data passed
    when the DAG is invoked -- confirm against ``task_impl``.

    Returns:
        InputSource[T]: A new ``SimpleCallDataInputSource`` instance.
    """
    # Imported at call time rather than at module import time.
    from .task_impl import SimpleCallDataInputSource

    source = SimpleCallDataInputSource()
    return source

View File

@@ -261,13 +261,42 @@ def _is_async_iterator(obj):
)
def _is_async_iterable(obj):
return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))
def _is_iterator(obj):
return (
hasattr(obj, "__iter__")
and callable(getattr(obj, "__iter__", None))
and hasattr(obj, "__next__")
and callable(getattr(obj, "__next__", None))
)
def _is_iterable(obj):
return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))
async def _to_async_iterator(obj) -> AsyncIterator:
    """Expose a sync or async iterable as an async iterator.

    Async iterables are checked first, then plain iterables.

    Args:
        obj: The object to iterate over.

    Yields:
        Each item of ``obj`` in its original order.

    Raises:
        ValueError: If ``obj`` is neither async iterable nor iterable. Because
            this is an async generator, the error surfaces on first iteration,
            not when the function is called.
    """
    if _is_async_iterable(obj):
        async for element in obj:
            yield element
        return
    if not _is_iterable(obj):
        raise ValueError(f"Can not convert {obj} to AsyncIterator")
    for element in obj:
        yield element
class BaseInputSource(InputSource, ABC):
"""The base class of InputSource."""
def __init__(self) -> None:
def __init__(self, streaming: Optional[bool] = None) -> None:
"""Create a BaseInputSource."""
super().__init__()
self._is_read = False
self._streaming_data = streaming
@abstractmethod
def _read_data(self, task_ctx: TaskContext) -> Any:
@@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
ValueError: If the input source is a stream and has been read.
"""
data = self._read_data(task_ctx)
if _is_async_iterator(data):
if self._streaming_data is None:
streaming_data = _is_async_iterator(data) or _is_iterator(data)
else:
streaming_data = self._streaming_data
if streaming_data:
if self._is_read:
raise ValueError(f"Input iterator {data} has been read!")
output: TaskOutput = SimpleStreamTaskOutput(data)
it_data = _to_async_iterator(data)
output: TaskOutput = SimpleStreamTaskOutput(it_data)
else:
output = SimpleTaskOutput(data)
self._is_read = True
@@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
class SimpleInputSource(BaseInputSource):
"""The default implementation of InputSource."""
def __init__(self, data: Any) -> None:
def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
"""Create a SimpleInputSource.
Args:
data (Any): The input data.
"""
super().__init__()
super().__init__(streaming=streaming)
self._data = data
def _read_data(self, task_ctx: TaskContext) -> Any:

View File

@@ -0,0 +1,118 @@
from typing import AsyncIterator
import pytest
from dbgpt.core.awel import (
DAG,
InputSource,
MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger
class NumberProducerOperator(StreamifyAbsOperator[int, int]):
    """Turn an integer `n` into the stream 0, 1, ..., `n-1`."""

    async def streamify(self, n: int) -> AsyncIterator[int]:
        """Yield every integer in ``range(n)`` in ascending order."""
        for number in range(n):
            yield number
class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
    """Square every number of the incoming stream."""

    async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
        """Yield ``value * value`` for each value read from ``data``."""
        async for value in data:
            squared = value * value
            yield squared
async def _check_stream_results(stream_results, expected_len):
assert len(stream_results) == expected_len
for _, result in stream_results:
i = 0
async for num in result:
assert num == i * i
i += 1
@pytest.mark.asyncio
async def test_single_data():
    """A scalar input should run the DAG exactly once in both call modes."""
    # Non-streaming call: one (input, output) tuple, output is 2 * 2.
    with DAG("test_single_data"):
        trigger_task = IteratorTrigger(data=2)
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
    results = await trigger_task.trigger()
    assert len(results) == 1
    assert results[0][1] == 4

    # Streaming call: the input 2 is expanded to 0, 1 and each value squared.
    with DAG("test_single_data_stream"):
        trigger_task = IteratorTrigger(data=2, streaming_call=True)
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
    stream_results = await trigger_task.trigger()
    await _check_stream_results(stream_results, 1)
@pytest.mark.asyncio
async def test_list_data():
    """A list input should run the DAG once per element, keeping order."""
    # Non-streaming call: each result pairs the input with its square.
    with DAG("test_list_data"):
        trigger_task = IteratorTrigger(data=[0, 1, 2, 3])
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
    results = await trigger_task.trigger()
    assert len(results) == 4
    assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

    # Streaming call: one output stream per list element.
    with DAG("test_list_data_stream"):
        trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
    stream_results = await trigger_task.trigger()
    await _check_stream_results(stream_results, 4)
@pytest.mark.asyncio
async def test_async_iterator_data():
    """An async generator input should run the DAG once per yielded item."""

    async def async_iter():
        # Fresh async generator yielding 0..3; called again for the second DAG
        # because a generator can only be consumed once.
        for i in range(4):
            yield i

    # Non-streaming call: each result pairs the input with its square.
    with DAG("test_async_iterator_data"):
        trigger_task = IteratorTrigger(data=async_iter())
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
    results = await trigger_task.trigger()
    assert len(results) == 4
    assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

    # Streaming call: one output stream per yielded item.
    with DAG("test_async_iterator_data_stream"):
        trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True)
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
    stream_results = await trigger_task.trigger()
    await _check_stream_results(stream_results, 4)
@pytest.mark.asyncio
async def test_input_source_data():
    """An InputSource built from an iterable should run the DAG per item."""
    # Non-streaming call: each result pairs the input with its square.
    with DAG("test_input_source_data"):
        trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
        task = MapOperator(lambda x: x * x)
        trigger_task >> task
    results = await trigger_task.trigger()
    assert len(results) == 4
    assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

    # Streaming call: one output stream per item read from the InputSource.
    with DAG("test_input_source_data_stream"):
        trigger_task = IteratorTrigger(
            data=InputSource.from_iterable([0, 1, 2, 3]),
            streaming_call=True,
        )
        number_task = NumberProducerOperator()
        task = MyStreamingOperator()
        trigger_task >> number_task >> task
    stream_results = await trigger_task.trigger()
    await _check_stream_results(stream_results, 4)

View File

@@ -2,16 +2,18 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic
from ..operators.common_operator import TriggerOperator
from ..task.base import OUT
class Trigger(TriggerOperator, ABC):
class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
"""Base class for all trigger classes.
Now only support http trigger.
"""
@abstractmethod
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -397,9 +397,9 @@ class HttpTrigger(Trigger):
self._end_node: Optional[BaseOperator] = None
self._register_to_app = register_to_app
async def trigger(self) -> None:
async def trigger(self, **kwargs) -> Any:
"""Trigger the DAG. Not used in HttpTrigger."""
pass
raise NotImplementedError("HttpTrigger does not support trigger directly")
def register_to_app(self) -> bool:
"""Register the trigger to a FastAPI app.

View File

@@ -0,0 +1,148 @@
"""Trigger for iterator data."""
import asyncio
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast
from ..operators.base import BaseOperator
from ..task.base import InputSource, TaskState
from ..task.task_impl import DefaultTaskContext, _is_async_iterator, _is_iterable
from .base import Trigger
IterDataType = Union[InputSource, Iterator, AsyncIterator, Any]
async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIterator:
    """Convert ``iter_data`` to an async iterator of individual input items.

    Dispatch order:
        1. An async iterator is consumed with ``async for``.
        2. A ``str``, ``bytes`` or ``bytearray`` is yielded once, as a whole.
        3. Any other object with a callable ``__iter__`` is consumed with
           ``for``.
        4. An ``InputSource`` is read once; its stream items (or its single
           output) are forwarded.
        5. Anything else is yielded once, unchanged.

    Args:
        iter_data (IterDataType): The data handed to the trigger.
        task_id (str): Id used for the temporary task context created to read
            an ``InputSource``.

    Yields:
        Any: One item per DAG run.
    """
    if _is_async_iterator(iter_data):
        async for item in iter_data:  # type: ignore
            yield item
    elif isinstance(iter_data, (str, bytes, bytearray)):
        # Text/binary values are iterable, but a single query string must be
        # one datum; iterating it would trigger one DAG run per character.
        yield iter_data
    elif _is_iterable(iter_data):
        for item in iter_data:  # type: ignore
            yield item
    elif isinstance(iter_data, InputSource):
        # Reading an InputSource requires a task context; build a throwaway
        # one bound to the trigger's task id.
        task_ctx: DefaultTaskContext[Any] = DefaultTaskContext(
            task_id, TaskState.RUNNING, None
        )
        data = await iter_data.read(task_ctx)
        if data.is_stream:
            async for item in data.output_stream:
                yield item
        else:
            yield data.output
    else:
        # A plain scalar: the DAG runs exactly once with it.
        yield iter_data
class IteratorTrigger(Trigger):
    """Trigger for iterator data.

    Trigger the dag once per item of the iterator data and return the list of
    results produced by the (single) leaf node of the dag.
    The times of dag running is the length of the iterator data.
    """

    def __init__(
        self,
        data: IterDataType,
        parallel_num: int = 1,
        streaming_call: bool = False,
        **kwargs
    ):
        """Create a IteratorTrigger.

        Args:
            data (IterDataType): The iterator data.
            parallel_num (int, optional): The parallel number of the dag running.
                Defaults to 1.
            streaming_call (bool, optional): Whether the dag is a streaming call.
                Defaults to False.
            **kwargs: Forwarded unchanged to the ``Trigger`` constructor.
        """
        self._iter_data = data
        self._parallel_num = parallel_num
        self._streaming_call = streaming_call
        super().__init__(**kwargs)

    async def trigger(
        self, parallel_num: Optional[int] = None, **kwargs
    ) -> List[Tuple[Any, Any]]:
        """Trigger the dag with iterator data.

        If the dag is a streaming call, return the list of async iterator.

        Examples:
            .. code-block:: python

                import asyncio
                from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator

                with DAG("test_dag") as dag:
                    trigger_task = IteratorTrigger([0, 1, 2, 3])
                    task = MapOperator(lambda x: x * x)
                    trigger_task >> task
                results = asyncio.run(trigger_task.trigger())
                # First element of the tuple is the input data, the second element is the
                # output data of the leaf node.
                assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]

            .. code-block:: python

                import asyncio
                from datasets import Dataset
                from dbgpt.core.awel import (
                    DAG,
                    IteratorTrigger,
                    MapOperator,
                    InputSource,
                )

                data_samples = {
                    "question": ["What is 1+1?", "What is 7*7?"],
                    "answer": [2, 49],
                }
                dataset = Dataset.from_dict(data_samples)
                with DAG("test_dag_stream") as dag:
                    trigger_task = IteratorTrigger(InputSource.from_iterable(dataset))
                    task = MapOperator(lambda x: x["answer"])
                    trigger_task >> task
                results = asyncio.run(trigger_task.trigger())
                assert results == [
                    ({"question": "What is 1+1?", "answer": 2}, 2),
                    ({"question": "What is 7*7?", "answer": 49}, 49),
                ]

        Args:
            parallel_num (Optional[int], optional): The parallel number of the dag
                running. Falls back to the value given to the constructor when
                None (or 0). Defaults to None.

        Returns:
            List[Tuple[Any, Any]]: The list of results of the leaf nodes in the dag.
                The first element of the tuple is the input data, the second element is
                the output data of the leaf node.

        Raises:
            ValueError: If no DAG is set, if the DAG does not have exactly one
                leaf node, or if the effective parallel number is less than 1.
        """
        dag = self.dag
        if not dag:
            raise ValueError("DAG is not set for IteratorTrigger")
        leaf_nodes = dag.leaf_nodes
        if len(leaf_nodes) != 1:
            raise ValueError("IteratorTrigger just support one leaf node in dag")
        end_node = cast(BaseOperator, leaf_nodes[0])
        streaming_call = self._streaming_call

        effective_parallel = parallel_num or self._parallel_num
        if effective_parallel < 1:
            # Semaphore(0) would make every run wait forever and gather() would
            # never return; fail fast instead of hanging.
            raise ValueError(
                f"parallel_num must be at least 1, got {effective_parallel}"
            )
        semaphore = asyncio.Semaphore(effective_parallel)
        task_id = self.node_id

        async def call_stream(call_data: Any):
            # Re-yield the leaf node's stream so callers receive an async
            # generator they can consume lazily.
            async for out in await end_node.call_stream(call_data):
                yield out

        async def run_node(call_data: Any):
            async with semaphore:
                if streaming_call:
                    # NOTE: this only creates the async generator; the stream is
                    # consumed by the caller, outside of the semaphore.
                    task_output = call_stream(call_data)
                else:
                    task_output = await end_node.call(call_data)
                return call_data, task_output

        # Collect one coroutine per input item; gather() keeps input order.
        tasks = [
            run_node(data)
            async for data in _to_async_iterator(self._iter_data, task_id)
        ]
        results = await asyncio.gather(*tasks)
        return results

View File

@@ -0,0 +1,32 @@
"""Interface for embedding models."""
import asyncio
from abc import ABC, abstractmethod
from typing import List
class Embeddings(ABC):
    """Interface for embedding models.

    Subclasses implement the two synchronous methods; the asynchronous
    variants run them in the running event loop's default executor.

    Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
    master/libs/langchain/langchain/embeddings>`_.
    """

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs."""
        # Offload the blocking call so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        vectors = await loop.run_in_executor(None, self.embed_documents, texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text."""
        # Offload the blocking call so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        vector = await loop.run_in_executor(None, self.embed_query, text)
        return vector

View File

@@ -0,0 +1,253 @@
"""Evaluation module."""
import asyncio
import string
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Generic,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
)
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.util.similarity_util import calculate_cosine_similarity
from .embeddings import Embeddings
from .llm import LLMClient
if TYPE_CHECKING:
from dbgpt.core.awel.task.base import InputSource
QueryType = Union[str, Any]
PredictionType = Union[str, Any]
ContextType = Union[str, Sequence[str], Any]
DatasetType = Union["InputSource", Iterator, AsyncIterator]
class BaseEvaluationResult(BaseModel):
    """Base evaluation result.

    Every field is optional and defaults to None.
    """

    # What was produced by the system under evaluation.
    prediction: Optional[PredictionType] = Field(
        default=None,
        description="Prediction data(including the output of LLM, "
        "the data from retrieval, etc.)",
    )
    # Reference data the prediction is compared against.
    contexts: Optional[ContextType] = Field(
        default=None, description="Context data"
    )
    # Numeric outcome of the metric.
    score: Optional[float] = Field(
        default=None, description="Score for the prediction"
    )
    # Pass/fail outcome of the metric.
    passing: Optional[bool] = Field(
        default=None, description="Binary evaluation result (passing or not)"
    )
    # Which metric produced this result.
    metric_name: Optional[str] = Field(
        default=None, description="Name of the metric"
    )
class EvaluationResult(BaseEvaluationResult):
    """Evaluation result.

    Output of an BaseEvaluator. Adds the query and the raw dataset record to
    the fields of ``BaseEvaluationResult``.
    """

    # The query the prediction was produced for.
    query: Optional[QueryType] = Field(default=None, description="Query data")
    # The untouched dataset record.
    raw_dataset: Optional[Any] = Field(default=None, description="Raw dataset")
Q = TypeVar("Q")
P = TypeVar("P")
C = TypeVar("C")
class EvaluationMetric(ABC, Generic[P, C]):
    """Base class for evaluation metric.

    Subclasses override either ``sync_compute`` (run in an executor by the
    default ``compute``) or ``compute`` itself.
    """

    @property
    def name(self) -> str:
        """Name of the metric."""
        return type(self).__name__

    async def compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Runs ``sync_compute`` in the running loop's default executor.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.

        Returns:
            BaseEvaluationResult: The evaluation result.
        """
        loop = asyncio.get_running_loop()
        pending = loop.run_in_executor(
            None, self.sync_compute, prediction, contexts
        )
        return await pending

    def sync_compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.

        Returns:
            BaseEvaluationResult: The evaluation result.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError("sync_compute is not implemented")
class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
    """Evaluation metric based on a function."""

    def __init__(
        self,
        name: str,
        func: Callable[
            [P, Optional[Sequence[C]]],
            BaseEvaluationResult,
        ],
    ):
        """Create a FunctionMetric.

        Args:
            name(str): The name of the metric.
            func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]):
                The function to use for evaluation.
        """
        self._name = name
        self.func = func

    @property
    def name(self) -> str:
        """Name of the metric."""
        return self._name

    async def compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
        *,
        context: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric by calling the wrapped function.

        The function is called directly (not in an executor).

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data. Named like the
                base class parameter so ``compute(p, contexts=...)`` works for
                every metric.
            context(Optional[Sequence[C]]): Keyword-only alias kept for callers
                of the former parameter name; used only when ``contexts`` is
                None.

        Returns:
            BaseEvaluationResult: The value returned by the wrapped function.
        """
        if contexts is None:
            contexts = context
        return self.func(prediction, contexts)

    def sync_compute(
        self,
        prediction: P,
        contexts: Optional[Sequence[C]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric synchronously.

        The wrapped function is synchronous, so it is called directly instead
        of inheriting the base implementation that raises NotImplementedError.

        Args:
            prediction(P): The prediction data.
            contexts(Optional[Sequence[C]]): The context data.

        Returns:
            BaseEvaluationResult: The value returned by the wrapped function.
        """
        return self.func(prediction, contexts)
class ExactMatchMetric(EvaluationMetric[str, str]):
    """Exact match metric.

    Scores 1.0 when the prediction equals one of the contexts, otherwise 0.
    Just support string prediction and context.
    """

    def __init__(self, ignore_case: bool = False, ignore_punctuation: bool = False):
        """Create an ExactMatchMetric.

        Args:
            ignore_case(bool): Lower-case prediction and contexts first.
            ignore_punctuation(bool): Remove ``string.punctuation`` characters
                from prediction and contexts first.
        """
        self._ignore_case = ignore_case
        self._ignore_punctuation = ignore_punctuation

    async def compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        The returned result carries the normalized prediction and contexts,
        not the raw inputs.
        """
        normalized = prediction
        candidates = contexts

        if self._ignore_case:
            normalized = normalized.lower()
            if candidates:
                candidates = [candidate.lower() for candidate in candidates]

        if self._ignore_punctuation:
            # One table deletes every punctuation character in a single pass.
            strip_table = str.maketrans("", "", string.punctuation)
            normalized = normalized.translate(strip_table)
            if candidates:
                candidates = [
                    candidate.translate(strip_table) for candidate in candidates
                ]

        if candidates:
            score = float(normalized in candidates)
        else:
            score = 0
        return BaseEvaluationResult(
            prediction=normalized,
            contexts=candidates,
            score=score,
        )
class SimilarityMetric(EvaluationMetric[str, str]):
    """Similarity metric.

    Calculate the cosine similarity between a prediction and a list of contexts.
    The score is the mean of the similarities.
    """

    def __init__(self, embeddings: Embeddings):
        """Create a SimilarityMetric with embeddings."""
        self._embeddings = embeddings

    def sync_compute(
        self,
        prediction: str,
        contexts: Optional[Sequence[str]] = None,
    ) -> BaseEvaluationResult:
        """Compute the evaluation metric.

        Returns a score of 0.0 without importing numpy when ``contexts`` is
        empty or None.

        Raises:
            ImportError: If numpy is not installed and contexts are given.
        """
        score = 0.0
        if contexts:
            # numpy is only needed on this path, so it is imported lazily.
            try:
                import numpy as np
            except ImportError:
                raise ImportError("numpy is required for SimilarityMetric")

            similarity: np.ndarray = calculate_cosine_similarity(
                self._embeddings, prediction, contexts
            )
            score = float(similarity.mean())
        return BaseEvaluationResult(
            prediction=prediction,
            contexts=contexts,
            score=score,
        )
class Evaluator(ABC):
"""Base Evaluator class."""
def __init__(
self,
llm_client: Optional[LLMClient] = None,
):
"""Create an Evaluator."""
self.llm_client = llm_client
@abstractmethod
async def evaluate(
self,
dataset: DatasetType,
metrics: Optional[List[EvaluationMetric]] = None,
query_key: str = "query",
contexts_key: str = "contexts",
prediction_key: str = "prediction",
parallel_num: int = 1,
**kwargs
) -> List[List[EvaluationResult]]:
"""Run evaluation with a dataset and metrics.
Args:
dataset(DatasetType): The dataset to evaluate.
metrics(Optional[List[EvaluationMetric]]): The metrics to use for
evaluation.
query_key(str): The key for query in the dataset.
contexts_key(str): The key for contexts in the dataset.
prediction_key(str): The key for prediction in the dataset.
parallel_num(int): The number of parallel tasks.
kwargs: Additional arguments.
Returns:
List[List[EvaluationResult]]: The evaluation results, the length of the
result equals to the length of the dataset. The first element in the
list is the list of evaluation results for metrics.
"""