feat(rag): Support rag retriever evaluation (#1291)
@@ -7,6 +7,7 @@ from dbgpt.core.interface.cache import (  # noqa: F401
     CachePolicy,
     CacheValue,
 )
+from dbgpt.core.interface.embeddings import Embeddings  # noqa: F401
 from dbgpt.core.interface.llm import (  # noqa: F401
     DefaultMessageConverter,
     LLMClient,
@@ -103,4 +104,5 @@ __ALL__ = [
     "DefaultStorageItemAdapter",
     "QuerySpec",
     "StorageError",
+    "Embeddings",
 ]
@@ -55,6 +55,7 @@ from .trigger.http_trigger import (
     CommonLLMHttpResponseBody,
     HttpTrigger,
 )
+from .trigger.iterator_trigger import IteratorTrigger

 _request_http_trigger_available = False
 try:
@@ -100,6 +101,7 @@ __all__ = [
     "TransformStreamAbsOperator",
     "Trigger",
     "HttpTrigger",
+    "IteratorTrigger",
     "CommonLLMHTTPRequestContext",
     "CommonLLMHttpResponseBody",
     "CommonLLMHttpRequestBody",
@@ -277,7 +277,7 @@ class InputOperator(BaseOperator, Generic[OUT]):
         return task_output


-class TriggerOperator(InputOperator, Generic[OUT]):
+class TriggerOperator(InputOperator[OUT], Generic[OUT]):
     """Operator node that triggers the DAG to run."""

     def __init__(self, **kwargs) -> None:
@@ -60,8 +60,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
             streaming_call=streaming_call,
             node_name_to_ids=job_manager._node_name_to_ids,
         )
-        if node.dag:
-            self._running_dag_ctx[node.dag.dag_id] = dag_ctx
+        # if node.dag:
+        #     self._running_dag_ctx[node.dag.dag_id] = dag_ctx
         logger.info(
             f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
         )
@@ -76,8 +76,8 @@ class DefaultWorkflowRunner(WorkflowRunner):
         if not streaming_call and node.dag:
             # streaming call not work for dag end
             await node.dag._after_dag_end()
-        if node.dag:
-            del self._running_dag_ctx[node.dag.dag_id]
+        # if node.dag:
+        #     del self._running_dag_ctx[node.dag.dag_id]
         return dag_ctx

     async def _execute_node(
@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from typing import (
     Any,
+    AsyncIterable,
     AsyncIterator,
     Awaitable,
     Callable,
     Dict,
     Generic,
+    Iterable,
     List,
     Optional,
     TypeVar,
@@ -421,3 +423,40 @@ class InputSource(ABC, Generic[T]):
         Returns:
             TaskOutput[T]: The output object read from current source
         """
+
+    @classmethod
+    def from_data(cls, data: T) -> "InputSource[T]":
+        """Create an InputSource from data.
+
+        Args:
+            data (T): The data to create the InputSource from.
+
+        Returns:
+            InputSource[T]: The InputSource created from the data.
+        """
+        from .task_impl import SimpleInputSource
+
+        return SimpleInputSource(data, streaming=False)
+
+    @classmethod
+    def from_iterable(
+        cls, iterable: Union[AsyncIterable[T], Iterable[T]]
+    ) -> "InputSource[T]":
+        """Create an InputSource from an iterable.
+
+        Args:
+            iterable (List[T]): The iterable to create the InputSource from.
+
+        Returns:
+            InputSource[T]: The InputSource created from the iterable.
+        """
+        from .task_impl import SimpleInputSource
+
+        return SimpleInputSource(iterable, streaming=True)
+
+    @classmethod
+    def from_callable(cls) -> "InputSource[T]":
+        """Create an InputSource from a callable."""
+        from .task_impl import SimpleCallDataInputSource
+
+        return SimpleCallDataInputSource()
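
Aside (not part of the diff): the three factory methods above are the new public entry points for building input sources. A minimal sketch of how `from_data` plugs into a DAG, using the `IteratorTrigger` added later in this commit; the DAG name is illustrative:

# Sketch (not in the commit): wiring InputSource.from_data into a DAG.
import asyncio

from dbgpt.core.awel import DAG, InputSource, MapOperator
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger


async def main():
    with DAG("from_data_example"):
        # from_data wraps a single value (streaming=False under the hood)
        trigger = IteratorTrigger(data=InputSource.from_data(21))
        trigger >> MapOperator(lambda x: x * 2)
    results = await trigger.trigger()
    assert results == [(21, 42)]


asyncio.run(main())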
@@ -261,13 +261,42 @@ def _is_async_iterator(obj):
     )


+def _is_async_iterable(obj):
+    return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None))
+
+
+def _is_iterator(obj):
+    return (
+        hasattr(obj, "__iter__")
+        and callable(getattr(obj, "__iter__", None))
+        and hasattr(obj, "__next__")
+        and callable(getattr(obj, "__next__", None))
+    )
+
+
+def _is_iterable(obj):
+    return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None))
+
+
+async def _to_async_iterator(obj) -> AsyncIterator:
+    if _is_async_iterable(obj):
+        async for item in obj:
+            yield item
+    elif _is_iterable(obj):
+        for item in obj:
+            yield item
+    else:
+        raise ValueError(f"Can not convert {obj} to AsyncIterator")
+
+
 class BaseInputSource(InputSource, ABC):
     """The base class of InputSource."""

-    def __init__(self) -> None:
+    def __init__(self, streaming: Optional[bool] = None) -> None:
         """Create a BaseInputSource."""
         super().__init__()
         self._is_read = False
+        self._streaming_data = streaming

     @abstractmethod
     def _read_data(self, task_ctx: TaskContext) -> Any:
@@ -286,10 +315,15 @@ class BaseInputSource(InputSource, ABC):
             ValueError: If the input source is a stream and has been read.
         """
         data = self._read_data(task_ctx)
-        if _is_async_iterator(data):
+        if self._streaming_data is None:
+            streaming_data = _is_async_iterator(data) or _is_iterator(data)
+        else:
+            streaming_data = self._streaming_data
+        if streaming_data:
             if self._is_read:
                 raise ValueError(f"Input iterator {data} has been read!")
-            output: TaskOutput = SimpleStreamTaskOutput(data)
+            it_data = _to_async_iterator(data)
+            output: TaskOutput = SimpleStreamTaskOutput(it_data)
         else:
             output = SimpleTaskOutput(data)
         self._is_read = True
@@ -299,13 +333,13 @@ class BaseInputSource(InputSource, ABC):
 class SimpleInputSource(BaseInputSource):
     """The default implementation of InputSource."""

-    def __init__(self, data: Any) -> None:
+    def __init__(self, data: Any, streaming: Optional[bool] = None) -> None:
         """Create a SimpleInputSource.

         Args:
             data (Any): The input data.
         """
-        super().__init__()
+        super().__init__(streaming=streaming)
         self._data = data

     def _read_data(self, task_ctx: TaskContext) -> Any:
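
Aside (not part of the diff): the behavioral core of this change is that `read()` now honors an explicit `streaming` flag before falling back to sniffing iterators, and plain iterators are lifted to async iterators. A sketch of the observable behavior, assuming `read()` can be driven directly with a `DefaultTaskContext`, the same way the `IteratorTrigger` below does:

# Sketch (not in the commit): how the streaming flag affects read().
import asyncio

from dbgpt.core.awel import InputSource
from dbgpt.core.awel.task.base import TaskState
from dbgpt.core.awel.task.task_impl import DefaultTaskContext


async def main():
    ctx = DefaultTaskContext("task-1", TaskState.RUNNING, None)
    single = await InputSource.from_data([1, 2, 3]).read(ctx)
    # from_data forces streaming=False: the list is one task output.
    assert not single.is_stream and single.output == [1, 2, 3]

    ctx2 = DefaultTaskContext("task-2", TaskState.RUNNING, None)
    stream = await InputSource.from_iterable([1, 2, 3]).read(ctx2)
    # from_iterable forces streaming=True: the sync list is lifted to an
    # async iterator. A stream source can only be read once; a second
    # read raises ValueError.
    assert stream.is_stream
    assert [x async for x in stream.output_stream] == [1, 2, 3]


asyncio.run(main())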
dbgpt/core/awel/tests/trigger/__init__.py (new file, 0 lines)

dbgpt/core/awel/tests/trigger/test_iterator_trigger.py (new file, 118 lines)
@@ -0,0 +1,118 @@
+from typing import AsyncIterator
+
+import pytest
+
+from dbgpt.core.awel import (
+    DAG,
+    InputSource,
+    MapOperator,
+    StreamifyAbsOperator,
+    TransformStreamAbsOperator,
+)
+from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger
+
+
+class NumberProducerOperator(StreamifyAbsOperator[int, int]):
+    """Create a stream of numbers from 0 to `n-1`"""
+
+    async def streamify(self, n: int) -> AsyncIterator[int]:
+        for i in range(n):
+            yield i
+
+
+class MyStreamingOperator(TransformStreamAbsOperator[int, int]):
+    async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]:
+        async for i in data:
+            yield i * i
+
+
+async def _check_stream_results(stream_results, expected_len):
+    assert len(stream_results) == expected_len
+    for _, result in stream_results:
+        i = 0
+        async for num in result:
+            assert num == i * i
+            i += 1
+
+
+@pytest.mark.asyncio
+async def test_single_data():
+    with DAG("test_single_data"):
+        trigger_task = IteratorTrigger(data=2)
+        task = MapOperator(lambda x: x * x)
+        trigger_task >> task
+        results = await trigger_task.trigger()
+        assert len(results) == 1
+        assert results[0][1] == 4
+
+    with DAG("test_single_data_stream"):
+        trigger_task = IteratorTrigger(data=2, streaming_call=True)
+        number_task = NumberProducerOperator()
+        task = MyStreamingOperator()
+        trigger_task >> number_task >> task
+        stream_results = await trigger_task.trigger()
+        await _check_stream_results(stream_results, 1)
+
+
+@pytest.mark.asyncio
+async def test_list_data():
+    with DAG("test_list_data"):
+        trigger_task = IteratorTrigger(data=[0, 1, 2, 3])
+        task = MapOperator(lambda x: x * x)
+        trigger_task >> task
+        results = await trigger_task.trigger()
+        assert len(results) == 4
+        assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
+
+    with DAG("test_list_data_stream"):
+        trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True)
+        number_task = NumberProducerOperator()
+        task = MyStreamingOperator()
+        trigger_task >> number_task >> task
+        stream_results = await trigger_task.trigger()
+        await _check_stream_results(stream_results, 4)
+
+
+@pytest.mark.asyncio
+async def test_async_iterator_data():
+    async def async_iter():
+        for i in range(4):
+            yield i
+
+    with DAG("test_async_iterator_data"):
+        trigger_task = IteratorTrigger(data=async_iter())
+        task = MapOperator(lambda x: x * x)
+        trigger_task >> task
+        results = await trigger_task.trigger()
+        assert len(results) == 4
+        assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
+
+    with DAG("test_async_iterator_data_stream"):
+        trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True)
+        number_task = NumberProducerOperator()
+        task = MyStreamingOperator()
+        trigger_task >> number_task >> task
+        stream_results = await trigger_task.trigger()
+        await _check_stream_results(stream_results, 4)
+
+
+@pytest.mark.asyncio
+async def test_input_source_data():
+    with DAG("test_input_source_data"):
+        trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3]))
+        task = MapOperator(lambda x: x * x)
+        trigger_task >> task
+        results = await trigger_task.trigger()
+        assert len(results) == 4
+        assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
+
+    with DAG("test_input_source_data_stream"):
+        trigger_task = IteratorTrigger(
+            data=InputSource.from_iterable([0, 1, 2, 3]),
+            streaming_call=True,
+        )
+        number_task = NumberProducerOperator()
+        task = MyStreamingOperator()
+        trigger_task >> number_task >> task
+        stream_results = await trigger_task.trigger()
+        await _check_stream_results(stream_results, 4)
@@ -2,16 +2,18 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from typing import Any, Generic

 from ..operators.common_operator import TriggerOperator
+from ..task.base import OUT


-class Trigger(TriggerOperator, ABC):
+class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]):
     """Base class for all trigger classes.

     Now only support http trigger.
     """

     @abstractmethod
-    async def trigger(self) -> None:
+    async def trigger(self, **kwargs) -> Any:
         """Trigger the workflow or a specific operation in the workflow."""
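
Aside (not part of the diff): making `Trigger` generic and letting `trigger()` accept kwargs and return `Any` opens the door to non-HTTP triggers. A hypothetical subclass sketch (`OneShotTrigger` is illustrative, not in the codebase):

# Hypothetical sketch: a trigger that fires the DAG once with fixed data.
from typing import Any, cast

from dbgpt.core.awel.operators.base import BaseOperator
from dbgpt.core.awel.trigger.base import Trigger


class OneShotTrigger(Trigger):
    def __init__(self, data: Any, **kwargs) -> None:
        super().__init__(**kwargs)
        self._data = data

    async def trigger(self, **kwargs) -> Any:
        # Run the single leaf node of the attached DAG with the stored data.
        if not self.dag or len(self.dag.leaf_nodes) != 1:
            raise ValueError("OneShotTrigger requires a DAG with one leaf node")
        end_node = cast(BaseOperator, self.dag.leaf_nodes[0])
        return await end_node.call(self._data)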
@@ -397,9 +397,9 @@ class HttpTrigger(Trigger):
         self._end_node: Optional[BaseOperator] = None
         self._register_to_app = register_to_app

-    async def trigger(self) -> None:
+    async def trigger(self, **kwargs) -> Any:
         """Trigger the DAG. Not used in HttpTrigger."""
-        pass
+        raise NotImplementedError("HttpTrigger does not support trigger directly")

     def register_to_app(self) -> bool:
         """Register the trigger to a FastAPI app.
dbgpt/core/awel/trigger/iterator_trigger.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+"""Trigger for iterator data."""
+
+import asyncio
+from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast
+
+from ..operators.base import BaseOperator
+from ..task.base import InputSource, TaskState
+from ..task.task_impl import DefaultTaskContext, _is_async_iterator, _is_iterable
+from .base import Trigger
+
+IterDataType = Union[InputSource, Iterator, AsyncIterator, Any]
+
+
+async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIterator:
+    """Convert iter_data to an async iterator."""
+    if _is_async_iterator(iter_data):
+        async for item in iter_data:  # type: ignore
+            yield item
+    elif _is_iterable(iter_data):
+        for item in iter_data:  # type: ignore
+            yield item
+    elif isinstance(iter_data, InputSource):
+        task_ctx: DefaultTaskContext[Any] = DefaultTaskContext(
+            task_id, TaskState.RUNNING, None
+        )
+        data = await iter_data.read(task_ctx)
+        if data.is_stream:
+            async for item in data.output_stream:
+                yield item
+        else:
+            yield data.output
+    else:
+        yield iter_data
+
+
+class IteratorTrigger(Trigger):
+    """Trigger for iterator data.
+
+    Trigger the dag with iterator data.
+    Return the list of results of the leaf nodes in the dag.
+    The dag runs once for each element of the iterator data.
+    """
+
+    def __init__(
+        self,
+        data: IterDataType,
+        parallel_num: int = 1,
+        streaming_call: bool = False,
+        **kwargs
+    ):
+        """Create an IteratorTrigger.
+
+        Args:
+            data (IterDataType): The iterator data.
+            parallel_num (int, optional): The parallel number of the dag running.
+                Defaults to 1.
+            streaming_call (bool, optional): Whether the dag is a streaming call.
+                Defaults to False.
+        """
+        self._iter_data = data
+        self._parallel_num = parallel_num
+        self._streaming_call = streaming_call
+        super().__init__(**kwargs)
+
+    async def trigger(
+        self, parallel_num: Optional[int] = None, **kwargs
+    ) -> List[Tuple[Any, Any]]:
+        """Trigger the dag with iterator data.
+
+        If the dag is a streaming call, return a list of async iterators.
+
+        Examples:
+            .. code-block:: python
+
+                import asyncio
+                from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator
+
+                with DAG("test_dag") as dag:
+                    trigger_task = IteratorTrigger([0, 1, 2, 3])
+                    task = MapOperator(lambda x: x * x)
+                    trigger_task >> task
+                results = asyncio.run(trigger_task.trigger())
+                # First element of the tuple is the input data, the second element
+                # is the output data of the leaf node.
+                assert results == [(0, 0), (1, 1), (2, 4), (3, 9)]
+
+            .. code-block:: python
+
+                import asyncio
+                from datasets import Dataset
+                from dbgpt.core.awel import (
+                    DAG,
+                    IteratorTrigger,
+                    MapOperator,
+                    InputSource,
+                )
+
+                data_samples = {
+                    "question": ["What is 1+1?", "What is 7*7?"],
+                    "answer": [2, 49],
+                }
+                dataset = Dataset.from_dict(data_samples)
+                with DAG("test_dag_stream") as dag:
+                    trigger_task = IteratorTrigger(InputSource.from_iterable(dataset))
+                    task = MapOperator(lambda x: x["answer"])
+                    trigger_task >> task
+                results = asyncio.run(trigger_task.trigger())
+                assert results == [
+                    ({"question": "What is 1+1?", "answer": 2}, 2),
+                    ({"question": "What is 7*7?", "answer": 49}, 49),
+                ]
+
+        Args:
+            parallel_num (Optional[int], optional): The parallel number of the dag
+                running. Defaults to None.
+
+        Returns:
+            List[Tuple[Any, Any]]: The list of results of the leaf nodes in the dag.
+                The first element of the tuple is the input data, the second element
+                is the output data of the leaf node.
+        """
+        dag = self.dag
+        if not dag:
+            raise ValueError("DAG is not set for IteratorTrigger")
+        leaf_nodes = dag.leaf_nodes
+        if len(leaf_nodes) != 1:
+            raise ValueError("IteratorTrigger just support one leaf node in dag")
+        end_node = cast(BaseOperator, leaf_nodes[0])
+        streaming_call = self._streaming_call
+        semaphore = asyncio.Semaphore(parallel_num or self._parallel_num)
+        task_id = self.node_id
+
+        async def call_stream(call_data: Any):
+            async for out in await end_node.call_stream(call_data):
+                yield out
+
+        async def run_node(call_data: Any):
+            async with semaphore:
+                if streaming_call:
+                    task_output = call_stream(call_data)
+                else:
+                    task_output = await end_node.call(call_data)
+                return call_data, task_output
+
+        tasks = []
+        async for data in _to_async_iterator(self._iter_data, task_id):
+            tasks.append(run_node(data))
+        results = await asyncio.gather(*tasks)
+        return results
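
Aside (not part of the diff): the docstring examples above only exercise the non-streaming path. A sketch of `streaming_call=True`, modeled on the test file in this commit (`CountStream` is illustrative):

# Sketch: streaming_call=True yields (input, async-iterator) pairs per run.
import asyncio
from typing import AsyncIterator

from dbgpt.core.awel import DAG, StreamifyAbsOperator
from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger


class CountStream(StreamifyAbsOperator[int, int]):
    async def streamify(self, n: int) -> AsyncIterator[int]:
        for i in range(n):
            yield i


async def main():
    with DAG("streaming_example"):
        trigger = IteratorTrigger(data=[2, 3], streaming_call=True)
        trigger >> CountStream()
    for call_data, stream in await trigger.trigger():
        # 2 -> [0, 1]; 3 -> [0, 1, 2]
        print(call_data, [x async for x in stream])


asyncio.run(main())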
dbgpt/core/interface/embeddings.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+"""Interface for embedding models."""
+import asyncio
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class Embeddings(ABC):
+    """Interface for embedding models.
+
+    Refer to `Langchain Embeddings <https://github.com/langchain-ai/langchain/tree/
+    master/libs/langchain/langchain/embeddings>`_.
+    """
+
+    @abstractmethod
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed search docs."""
+
+    @abstractmethod
+    def embed_query(self, text: str) -> List[float]:
+        """Embed query text."""
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronously embed search docs."""
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.embed_documents, texts
+        )
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronously embed query text."""
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.embed_query, text
+        )
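
Aside (not part of the diff): a minimal concrete implementation sketch of this interface, handy for tests. `HashEmbeddings` is illustrative; real implementations would wrap an embedding model:

# Illustrative toy implementation of the Embeddings interface.
from typing import List

from dbgpt.core import Embeddings


class HashEmbeddings(Embeddings):
    """Deterministic toy embeddings derived from character codes."""

    def __init__(self, dim: int = 8):
        self._dim = dim

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        vec = [0.0] * self._dim
        for i, ch in enumerate(text):
            vec[i % self._dim] += ord(ch)
        # L2-normalize so cosine similarity behaves sensibly.
        norm = sum(v * v for v in vec) ** 0.5 or 1.0
        return [v / norm for v in vec]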
dbgpt/core/interface/evaluation.py (new file, 253 lines)
@@ -0,0 +1,253 @@
+"""Evaluation module."""
+import asyncio
+import string
+from abc import ABC, abstractmethod
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Callable,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+)
+
+from dbgpt._private.pydantic import BaseModel, Field
+from dbgpt.util.similarity_util import calculate_cosine_similarity
+
+from .embeddings import Embeddings
+from .llm import LLMClient
+
+if TYPE_CHECKING:
+    from dbgpt.core.awel.task.base import InputSource
+
+QueryType = Union[str, Any]
+PredictionType = Union[str, Any]
+ContextType = Union[str, Sequence[str], Any]
+DatasetType = Union["InputSource", Iterator, AsyncIterator]
+
+
+class BaseEvaluationResult(BaseModel):
+    """Base evaluation result."""
+
+    prediction: Optional[PredictionType] = Field(
+        None,
+        description="Prediction data(including the output of LLM, the data from "
+        "retrieval, etc.)",
+    )
+    contexts: Optional[ContextType] = Field(None, description="Context data")
+    score: Optional[float] = Field(None, description="Score for the prediction")
+    passing: Optional[bool] = Field(
+        None, description="Binary evaluation result (passing or not)"
+    )
+    metric_name: Optional[str] = Field(None, description="Name of the metric")
+
+
+class EvaluationResult(BaseEvaluationResult):
+    """Evaluation result.
+
+    Output of a BaseEvaluator.
+    """
+
+    query: Optional[QueryType] = Field(None, description="Query data")
+    raw_dataset: Optional[Any] = Field(None, description="Raw dataset")
+
+
+Q = TypeVar("Q")
+P = TypeVar("P")
+C = TypeVar("C")
+
+
+class EvaluationMetric(ABC, Generic[P, C]):
+    """Base class for evaluation metric."""
+
+    @property
+    def name(self) -> str:
+        """Name of the metric."""
+        return self.__class__.__name__
+
+    async def compute(
+        self,
+        prediction: P,
+        contexts: Optional[Sequence[C]] = None,
+    ) -> BaseEvaluationResult:
+        """Compute the evaluation metric.
+
+        Args:
+            prediction(P): The prediction data.
+            contexts(Optional[Sequence[C]]): The context data.
+
+        Returns:
+            BaseEvaluationResult: The evaluation result.
+        """
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.sync_compute, prediction, contexts
+        )
+
+    def sync_compute(
+        self,
+        prediction: P,
+        contexts: Optional[Sequence[C]] = None,
+    ) -> BaseEvaluationResult:
+        """Compute the evaluation metric.
+
+        Args:
+            prediction(P): The prediction data.
+            contexts(Optional[Sequence[C]]): The context data.
+
+        Returns:
+            BaseEvaluationResult: The evaluation result.
+        """
+        raise NotImplementedError("sync_compute is not implemented")
+
+
+class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]):
+    """Evaluation metric based on a function."""
+
+    def __init__(
+        self,
+        name: str,
+        func: Callable[
+            [P, Optional[Sequence[C]]],
+            BaseEvaluationResult,
+        ],
+    ):
+        """Create a FunctionMetric.
+
+        Args:
+            name(str): The name of the metric.
+            func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]):
+                The function to use for evaluation.
+        """
+        self._name = name
+        self.func = func
+
+    @property
+    def name(self) -> str:
+        """Name of the metric."""
+        return self._name
+
+    async def compute(
+        self,
+        prediction: P,
+        context: Optional[Sequence[C]] = None,
+    ) -> BaseEvaluationResult:
+        """Compute the evaluation metric."""
+        return self.func(prediction, context)
+
+
+class ExactMatchMetric(EvaluationMetric[str, str]):
+    """Exact match metric.
+
+    Only supports string predictions and contexts.
+    """
+
+    def __init__(self, ignore_case: bool = False, ignore_punctuation: bool = False):
+        """Create an ExactMatchMetric."""
+        self._ignore_case = ignore_case
+        self._ignore_punctuation = ignore_punctuation
+
+    async def compute(
+        self,
+        prediction: str,
+        contexts: Optional[Sequence[str]] = None,
+    ) -> BaseEvaluationResult:
+        """Compute the evaluation metric."""
+        if self._ignore_case:
+            prediction = prediction.lower()
+            if contexts:
+                contexts = [c.lower() for c in contexts]
+        if self._ignore_punctuation:
+            prediction = prediction.translate(str.maketrans("", "", string.punctuation))
+            if contexts:
+                contexts = [
+                    c.translate(str.maketrans("", "", string.punctuation))
+                    for c in contexts
+                ]
+        score = 0 if not contexts else float(prediction in contexts)
+        return BaseEvaluationResult(
+            prediction=prediction,
+            contexts=contexts,
+            score=score,
+        )
+
+
+class SimilarityMetric(EvaluationMetric[str, str]):
+    """Similarity metric.
+
+    Calculate the cosine similarity between a prediction and a list of contexts.
+    """
+
+    def __init__(self, embeddings: Embeddings):
+        """Create a SimilarityMetric with embeddings."""
+        self._embeddings = embeddings
+
+    def sync_compute(
+        self,
+        prediction: str,
+        contexts: Optional[Sequence[str]] = None,
+    ) -> BaseEvaluationResult:
+        """Compute the evaluation metric."""
+        if not contexts:
+            return BaseEvaluationResult(
+                prediction=prediction,
+                contexts=contexts,
+                score=0.0,
+            )
+        try:
+            import numpy as np
+        except ImportError:
+            raise ImportError("numpy is required for SimilarityMetric")
+
+        similarity: np.ndarray = calculate_cosine_similarity(
+            self._embeddings, prediction, contexts
+        )
+        return BaseEvaluationResult(
+            prediction=prediction,
+            contexts=contexts,
+            score=float(similarity.mean()),
+        )
+
+
+class Evaluator(ABC):
+    """Base Evaluator class."""
+
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+    ):
+        """Create an Evaluator."""
+        self.llm_client = llm_client
+
+    @abstractmethod
+    async def evaluate(
+        self,
+        dataset: DatasetType,
+        metrics: Optional[List[EvaluationMetric]] = None,
+        query_key: str = "query",
+        contexts_key: str = "contexts",
+        prediction_key: str = "prediction",
+        parallel_num: int = 1,
+        **kwargs
+    ) -> List[List[EvaluationResult]]:
+        """Run evaluation with a dataset and metrics.
+
+        Args:
+            dataset(DatasetType): The dataset to evaluate.
+            metrics(Optional[List[EvaluationMetric]]): The metrics to use for
+                evaluation.
+            query_key(str): The key for query in the dataset.
+            contexts_key(str): The key for contexts in the dataset.
+            prediction_key(str): The key for prediction in the dataset.
+            parallel_num(int): The number of parallel tasks.
+            kwargs: Additional arguments.
+
+        Returns:
+            List[List[EvaluationResult]]: The evaluation results; the length of the
+                result equals the length of the dataset. Each element in the list
+                is the list of evaluation results for the metrics.
+        """
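
Aside (not part of the diff): a small sketch of how the metric classes compose, using only classes from this new module plus an illustrative toy metric function:

# Illustrative usage of the metric classes defined above.
import asyncio

from dbgpt.core.interface.evaluation import (
    BaseEvaluationResult,
    ExactMatchMetric,
    FunctionMetric,
)


def _length_ratio(prediction, contexts) -> BaseEvaluationResult:
    # Toy metric: ratio of prediction length to the first context's length.
    score = len(prediction) / len(contexts[0]) if contexts else 0.0
    return BaseEvaluationResult(prediction=prediction, contexts=contexts, score=score)


async def main():
    contexts = ["paris is the capital of france."]
    metrics = [
        ExactMatchMetric(ignore_case=True),
        FunctionMetric("length_ratio", _length_ratio),
    ]
    for metric in metrics:
        result = await metric.compute("Paris is the capital of France.", contexts)
        print(metric.name, result.score)  # both print 1.0 here


asyncio.run(main())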