Files
langchain/tests/integration_tests/evals/unit_test.py
vowelparrot 1dd6b99b1b update
2023-06-13 22:26:57 -07:00

68 lines
2.3 KiB
Python

from pathlib import Path
from typing import Tuple
from uuid import uuid4
import pytest
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchainplus_sdk.schemas import Example
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.run_evaluators.implementations import (
get_criteria_evaluator,
get_qa_evaluator,
)
from langchain.sql_database import SQLDatabase
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.callbacks.tracers.schemas import Run
# Directory containing this test module; fixture data is read from ``data/``.
_DIR = Path(__file__).parent.resolve()

# Random hex suffix embedded in the tracing session name, so every trace
# produced by this module import shares one session.
_TEST_RUN_ID = uuid4().hex

# LangChainPlus client: used both to list dataset examples and to grade runs.
_CLIENT = LangChainPlusClient()

# Temperature-0 chat model that backs both evaluators below.
_EVAL_LLM = ChatOpenAI(temperature=0)

# Map the SQLDatabaseChain's ``query`` input and ``result`` output onto the
# fields each evaluator reads.
_IO_KEYS = {"input_key": "query", "prediction_key": "result"}

# Evaluators applied to every traced run. For the QA evaluator, ``answer_key``
# presumably names the reference answer in each example's outputs — confirm
# against the "sql-qa-chinook" dataset schema.
_EVALUATORS = [
    get_qa_evaluator(_EVAL_LLM, answer_key="answer", **_IO_KEYS),
    get_criteria_evaluator(_EVAL_LLM, "helpfulness", **_IO_KEYS),
]
@pytest.fixture(scope="module")
def database() -> SQLDatabase:
    """Return a ``SQLDatabase`` over the Chinook SQLite file in ``data/``.

    Module-scoped, so the database is opened once and shared by every test in
    this module.

    Raises:
        FileNotFoundError: if ``data/Chinook.db`` does not exist next to this
            test file. This is checked up front because SQLite creates a new,
            empty database when asked to open a missing file. The chain would
            then query a database with no tables and fail evaluation for a
            misleading reason (and leave a stray empty file behind).
    """
    db_path = _DIR / "data" / "Chinook.db"
    if not db_path.is_file():
        raise FileNotFoundError(
            f"Chinook database not found at {db_path}; "
            "it is required by these integration tests."
        )
    # URI kept identical to the original construction.
    return SQLDatabase.from_uri(f"sqlite:///{_DIR}/data/Chinook.db")
@pytest.fixture(scope="module")
def chain_to_test(database: SQLDatabase) -> SQLDatabaseChain:
    """Build the ``SQLDatabaseChain`` under test on top of ``database``.

    The chain gets its own temperature-0 ``ChatOpenAI`` instance, separate
    from the model used by the evaluators. Module scope means the chain is
    constructed once per test module.
    """
    return SQLDatabaseChain.from_llm(ChatOpenAI(temperature=0.0), database)
@pytest.fixture(
    scope="module",
    params=_CLIENT.list_examples(dataset_name="sql-qa-chinook"),
)
def run_example_pair(request, chain_to_test: SQLDatabaseChain) -> Tuple[Run, Example]:
    """Run ``chain_to_test`` on one dataset example and pair the trace with it.

    The fixture is parametrized over the examples returned by
    ``_CLIENT.list_examples(dataset_name="sql-qa-chinook")``. That call sits
    in the decorator, so it runs when the module is imported. It presumably
    needs a reachable LangChainPlus service at collection time.

    Returns:
        ``(run, example)``. ``run`` is obtained via ``pop()`` on the
        ``RunCollectorCallbackHandler`` attached to the chain call. This is
        presumably the top-level chain run — TODO confirm ``pop()`` semantics.
    """
    example: Example = request.param
    session_name = f"test_chain_on_example-{_TEST_RUN_ID}"
    # TODO: Add context manager for collecting runs.
    collector = RunCollectorCallbackHandler()
    with tracing_v2_enabled(session_name=session_name, example_id=example.id):
        chain_to_test(example.inputs, callbacks=[collector])
    traced_run = collector.pop()
    return traced_run, example
@pytest.mark.parametrize("evaluator", _EVALUATORS)
def test_chain_on_example(
    evaluator: RunEvaluator, run_example_pair: Tuple[Run, Example]
) -> None:
    """Grade one traced chain run with one evaluator; require a score of 1.

    The evaluation is delegated to ``_CLIENT.evaluate_run`` with the dataset
    example passed as the reference. The failure message includes the
    evaluator key and the full result, for debugging.
    """
    traced_run, reference = run_example_pair
    result = _CLIENT.evaluate_run(traced_run, evaluator, reference_example=reference)
    # The message stays inside the assert so it is only built on failure.
    assert result.score == 1, (
        f"My Chain failed evaluation {result.key}\n\n{result}"
    )