Compare commits

...

10 Commits

Author SHA1 Message Date
vowelparrot
1dd6b99b1b update 2023-06-13 22:26:57 -07:00
vowelparrot
bc476378d3 Merge branch 'master' into vwp/drafts/unit_testing 2023-06-13 22:26:10 -07:00
vowelparrot
cc210567ed Merge branch 'master' into vwp/drafts/unit_testing 2023-06-13 22:18:10 -07:00
vowelparrot
8f8df2d2fc Merge branch 'vwp/run_collector' into vwp/drafts/unit_testing 2023-06-13 22:17:53 -07:00
vowelparrot
bde440eb16 Add Run Collector Callback 2023-06-13 19:35:11 -07:00
vowelparrot
d6e8261edd Update 2023-06-13 17:01:12 -07:00
vowelparrot
4d692a8cad add test 2023-06-13 15:24:25 -07:00
vowelparrot
1016b7ecad Merge branch 'vwp/return_session_name' into vwp/drafts/unit_testing 2023-06-13 14:40:07 -07:00
vowelparrot
62cfabaaac Example 2023-06-13 14:39:22 -07:00
vowelparrot
dc922fc501 Return session name in runner response 2023-06-13 13:38:08 -07:00
2 changed files with 67 additions and 0 deletions

Binary file not shown.

View File

@@ -0,0 +1,67 @@
from pathlib import Path
from typing import Tuple
from uuid import uuid4
import pytest
from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchainplus_sdk.schemas import Example
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chat_models import ChatOpenAI
from langchain.evaluation.run_evaluators.implementations import (
get_criteria_evaluator,
get_qa_evaluator,
)
from langchain.sql_database import SQLDatabase
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.callbacks.tracers.schemas import Run
# Directory containing this test module; used to locate the bundled data files.
_DIR = Path(__file__).parent.resolve()
# Random hex token generated once per import and embedded in the tracing
# session name used by the run_example_pair fixture.
_TEST_RUN_ID = uuid4().hex
# The client and the evaluation LLM are created at import time and shared by
# everything in this module.
_CLIENT = LangChainPlusClient()
_EVAL_LLM = ChatOpenAI(temperature=0)
# Evaluators the test below is parametrized over; each one reads the chain's
# "query" input and "result" output.
_EVALUATORS = [
    get_qa_evaluator(
        _EVAL_LLM,
        input_key="query",
        prediction_key="result",
        answer_key="answer",
    ),
    get_criteria_evaluator(
        _EVAL_LLM,
        "helpfulness",
        input_key="query",
        prediction_key="result",
    ),
]
@pytest.fixture(scope="module")
def database() -> SQLDatabase:
    """Open the Chinook SQLite file under this module's ``data`` directory."""
    db_uri = f"sqlite:///{_DIR}/data/Chinook.db"
    return SQLDatabase.from_uri(db_uri)
@pytest.fixture(scope="module")
def chain_to_test(database: SQLDatabase) -> SQLDatabaseChain:
    """Build the SQL chain under test from a zero-temperature chat model."""
    return SQLDatabaseChain.from_llm(ChatOpenAI(temperature=0.0), database)
@pytest.fixture(
    scope="module",
    params=_CLIENT.list_examples(dataset_name="sql-qa-chinook"),
)
def run_example_pair(request, chain_to_test: SQLDatabaseChain) -> Tuple[Run, Example]:
    """Run the chain on one dataset example and return the collected run.

    The fixture is parametrized over whatever ``_CLIENT.list_examples``
    yields for the "sql-qa-chinook" dataset; ``request.param`` is the
    example for the current parametrization.

    Returns:
        A ``(run, example)`` tuple: the value popped from the run collector
        after invoking the chain, and the example that was used as input.
    """
    example: Example = request.param
    # TODO: Add context manager for this
    collector = RunCollectorCallbackHandler()
    session = f"test_chain_on_example-{_TEST_RUN_ID}"
    with tracing_v2_enabled(session_name=session, example_id=example.id):
        # The collector is passed explicitly so the run can be popped below.
        chain_to_test(example.inputs, callbacks=[collector])
    return (collector.pop(), example)
@pytest.mark.parametrize("evaluator", _EVALUATORS)
def test_chain_on_example(
    evaluator: RunEvaluator, run_example_pair: Tuple[Run, Example]
) -> None:
    """Check that ``evaluator`` scores the collected run as 1.

    The run is evaluated through ``_CLIENT.evaluate_run`` with the paired
    example supplied as the reference; any score other than 1 fails the test.
    """
    collected_run, reference = run_example_pair
    outcome = _CLIENT.evaluate_run(
        collected_run, evaluator, reference_example=reference
    )
    assert (
        outcome.score == 1
    ), f"My Chain failed evaluation {outcome.key}\n\n{outcome}"