From b2eb4ff0fcbc60d1a216f04b6260a6d2225c6c7b Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Tue, 8 Aug 2023 11:59:30 -0700
Subject: [PATCH] Relax Validation in Eval (#8902)

Just check for missing keys
---
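A minimal sketch of the relaxed check (illustration only; the key names
below are hypothetical, not taken from the patch). The old validation
required the example's input keys to equal the chain's input keys exactly;
the new one only raises when a key the chain expects is absent, so extra
example keys now pass:

    # Hypothetical key names, for illustration only.
    chain_input_keys = ["question"]
    example_inputs = {"question": "What is 2 + 2?", "source": "unit-test"}

    # Old check: strict set equality -- the extra "source" key alone
    # would have raised InputFormatError.
    assert set(example_inputs) != set(chain_input_keys)

    # New check: only expected-but-missing keys are an error.
    missing_keys = set(chain_input_keys).difference(example_inputs)
    assert not missing_keys  # nothing the chain needs is absent, so no error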
 .../smith/evaluation/runner_utils.py          |  12 +-
 .../smith/evaluation/test_runner_utils.py     |   6 +-
 .../unit_tests/smith/test_runner_utils.py     | 349 ------------------
 3 files changed, 8 insertions(+), 359 deletions(-)
 delete mode 100644 libs/langchain/tests/unit_tests/smith/test_runner_utils.py

diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index 2a2e0a8edef..5b3d5775c49 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -314,28 +314,28 @@ def _validate_example_inputs_for_chain(
     """Validate that the example inputs match the chain input keys."""
     if input_mapper:
         first_inputs = input_mapper(first_example.inputs)
+        missing_keys = set(chain.input_keys).difference(first_inputs)
         if not isinstance(first_inputs, dict):
             raise InputFormatError(
                 "When using an input_mapper to prepare dataset example"
                 " inputs for a chain, the mapped value must be a dictionary."
                 f"\nGot: {first_inputs} of type {type(first_inputs)}."
             )
-        if not set(first_inputs.keys()) == set(chain.input_keys):
+        if missing_keys:
             raise InputFormatError(
-                "When using an input_mapper to prepare dataset example inputs"
-                " for a chain mapped value must have keys that match the chain's"
-                " expected input keys."
+                "Missing keys after loading example using input_mapper."
                 f"\nExpected: {chain.input_keys}. Got: {first_inputs.keys()}"
             )
     else:
         first_inputs = first_example.inputs
+        missing_keys = set(chain.input_keys).difference(first_inputs)
         if len(first_inputs) == 1 and len(chain.input_keys) == 1:
             # We can pass this through the run method.
             # Refrain from calling to validate.
             pass
-        elif not set(first_inputs.keys()) == set(chain.input_keys):
+        elif missing_keys:
             raise InputFormatError(
-                "Example inputs do not match chain input keys."
+                "Example inputs missing expected chain input keys."
                 " Please provide an input_mapper to convert the example.inputs"
                 " to a compatible format for the chain you wish to evaluate."
                 f"Expected: {chain.input_keys}. "
diff --git a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py
index 3bf9a1d32c8..de7f9e9434c 100644
--- a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py
+++ b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py
@@ -124,7 +124,7 @@ def test__validate_example_inputs_for_chain_input_mapper() -> None:
         assert "baz" in inputs
         return {"not foo": "foo", "not baz": "baz"}
 
-    with pytest.raises(InputFormatError, match="keys that match"):
+    with pytest.raises(InputFormatError, match="Missing keys after loading example"):
         _validate_example_inputs_for_chain(mock_, chain, wrong_output_keys)
 
     def input_mapper(inputs: dict) -> dict:
@@ -148,9 +148,7 @@ def test__validate_example_inputs_for_chain_single_input_multi_expect() -> None:
     mock_.inputs = {"foo": "bar"}
     chain = mock.MagicMock()
     chain.input_keys = ["def not foo", "oh here is another"]
-    with pytest.raises(
-        InputFormatError, match="Example inputs do not match chain input keys."
-    ):
+    with pytest.raises(InputFormatError, match="Example inputs missing expected"):
         _validate_example_inputs_for_chain(mock_, chain, None)
 
 
diff --git a/libs/langchain/tests/unit_tests/smith/test_runner_utils.py b/libs/langchain/tests/unit_tests/smith/test_runner_utils.py
deleted file mode 100644
index 3bf9a1d32c8..00000000000
--- a/libs/langchain/tests/unit_tests/smith/test_runner_utils.py
+++ /dev/null
@@ -1,349 +0,0 @@
-"""Test the LangSmith evaluation helpers."""
-import uuid
-from datetime import datetime
-from typing import Any, Dict, Iterator, List, Optional, Union
-from unittest import mock
-
-import pytest
-from langsmith.client import Client
-from langsmith.schemas import Dataset, Example
-
-from langchain.chains.base import Chain
-from langchain.chains.transform import TransformChain
-from langchain.schema.language_model import BaseLanguageModel
-from langchain.smith.evaluation.runner_utils import (
-    InputFormatError,
-    _get_messages,
-    _get_prompt,
-    _run_llm,
-    _run_llm_or_chain,
-    _validate_example_inputs_for_chain,
-    _validate_example_inputs_for_language_model,
-    arun_on_dataset,
-)
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM
-
-_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
-_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
-_EXAMPLE_MESSAGE = {
-    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
-    "type": "human",
-}
-_VALID_MESSAGES = [
-    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
-    {"messages": [], "other_key": "value"},
-    {
-        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]],
-        "other_key": "value",
-    },
-    {"any_key": [_EXAMPLE_MESSAGE]},
-    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]]},
-]
-_VALID_PROMPTS = [
-    {"prompts": ["foo"], "other_key": "value"},
-    {"prompt": "foo", "other_key": ["bar", "baz"]},
-    {"some_key": "foo"},
-    {"some_key": ["foo"]},
-]
-
-_INVALID_PROMPTS = (
-    [
-        {"prompts": "foo"},
-        {"prompt": ["foo"]},
-        {"some_key": 3},
-        {"some_key": "foo", "other_key": "bar"},
-    ],
-)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_MESSAGES,
-)
-def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
-    {"messages": []}
-    _get_messages(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_PROMPTS,
-)
-def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
-    _get_prompt(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_PROMPTS,
-)
-def test__validate_example_inputs_for_language_model(inputs: Dict[str, Any]) -> None:
-    mock_ = mock.MagicMock()
-    mock_.inputs = inputs
-    _validate_example_inputs_for_language_model(mock_, None)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _INVALID_PROMPTS,
-)
-def test__validate_example_inputs_for_language_model_invalid(
-    inputs: Dict[str, Any]
-) -> None:
-    mock_ = mock.MagicMock()
-    mock_.inputs = inputs
-    with pytest.raises(InputFormatError):
-        _validate_example_inputs_for_language_model(mock_, None)
-
-
-def test__validate_example_inputs_for_chain_single_input() -> None:
-    mock_ = mock.MagicMock()
-    mock_.inputs = {"foo": "bar"}
-    chain = mock.MagicMock()
-    chain.input_keys = ["def not foo"]
-    _validate_example_inputs_for_chain(mock_, chain, None)
-
-
-def test__validate_example_inputs_for_chain_input_mapper() -> None:
-    mock_ = mock.MagicMock()
-    mock_.inputs = {"foo": "bar", "baz": "qux"}
-    chain = mock.MagicMock()
-    chain.input_keys = ["not foo", "not baz", "not qux"]
-
-    def wrong_output_format(inputs: dict) -> str:
-        assert "foo" in inputs
-        assert "baz" in inputs
-        return "hehe"
-
-    with pytest.raises(InputFormatError, match="must be a dictionary"):
-        _validate_example_inputs_for_chain(mock_, chain, wrong_output_format)
-
-    def wrong_output_keys(inputs: dict) -> dict:
-        assert "foo" in inputs
-        assert "baz" in inputs
-        return {"not foo": "foo", "not baz": "baz"}
-
-    with pytest.raises(InputFormatError, match="keys that match"):
-        _validate_example_inputs_for_chain(mock_, chain, wrong_output_keys)
-
-    def input_mapper(inputs: dict) -> dict:
-        assert "foo" in inputs
-        assert "baz" in inputs
-        return {"not foo": inputs["foo"], "not baz": inputs["baz"], "not qux": "qux"}
-
-    _validate_example_inputs_for_chain(mock_, chain, input_mapper)
-
-
-def test__validate_example_inputs_for_chain_multi_io() -> None:
-    mock_ = mock.MagicMock()
-    mock_.inputs = {"foo": "bar", "baz": "qux"}
-    chain = mock.MagicMock()
-    chain.input_keys = ["foo", "baz"]
-    _validate_example_inputs_for_chain(mock_, chain, None)
-
-
-def test__validate_example_inputs_for_chain_single_input_multi_expect() -> None:
-    mock_ = mock.MagicMock()
-    mock_.inputs = {"foo": "bar"}
-    chain = mock.MagicMock()
-    chain.input_keys = ["def not foo", "oh here is another"]
-    with pytest.raises(
-        InputFormatError, match="Example inputs do not match chain input keys."
-    ):
-        _validate_example_inputs_for_chain(mock_, chain, None)
-
-
-@pytest.mark.parametrize("inputs", _INVALID_PROMPTS)
-def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
-    with pytest.raises(InputFormatError):
-        _get_prompt(inputs)
-
-
-def test_run_llm_or_chain_with_input_mapper() -> None:
-    example = Example(
-        id=uuid.uuid4(),
-        created_at=_CREATED_AT,
-        inputs={"the wrong input": "1", "another key": "2"},
-        outputs={"output": "2"},
-        dataset_id=str(uuid.uuid4()),
-    )
-
-    def run_val(inputs: dict) -> dict:
-        assert "the right input" in inputs
-        return {"output": "2"}
-
-    mock_chain = TransformChain(
-        input_variables=["the right input"],
-        output_variables=["output"],
-        transform=run_val,
-    )
-
-    def input_mapper(inputs: dict) -> dict:
-        assert "the wrong input" in inputs
-        return {"the right input": inputs["the wrong input"]}
-
-    result = _run_llm_or_chain(
-        example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
-    )
-    assert len(result) == 1
-    assert result[0] == {"output": "2", "the right input": "1"}
-    bad_result = _run_llm_or_chain(
-        example,
-        lambda: mock_chain,
-        n_repetitions=1,
-    )
-    assert len(bad_result) == 1
-    assert "Error" in bad_result[0]
-
-    # Try with LLM
-    def llm_input_mapper(inputs: dict) -> str:
-        assert "the wrong input" in inputs
-        return "the right input"
-
-    mock_llm = FakeLLM(queries={"the right input": "somenumber"})
-    result = _run_llm_or_chain(
-        example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
-    )
-    assert len(result) == 1
-    llm_result = result[0]
-    assert isinstance(llm_result, str)
-    assert llm_result == "somenumber"
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    [
-        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
-        {
-            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
-            "other_key": "value",
-        },
-        {"prompts": "foo"},
-        {},
-    ],
-)
-def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
-    with pytest.raises(InputFormatError):
-        _get_messages(inputs)
-
-
-@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
-def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
-    llm = FakeLLM()
-    _run_llm(llm, inputs, mock.MagicMock())
-
-
-@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
-def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
-    llm = FakeChatModel()
-    _run_llm(llm, inputs, mock.MagicMock())
-
-
-@pytest.mark.asyncio
-async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
-    dataset = Dataset(
-        id=uuid.uuid4(),
-        name="test",
-        description="Test dataset",
-        owner_id="owner",
-        created_at=_CREATED_AT,
-        tenant_id=_TENANT_ID,
-    )
-    uuids = [
-        "0c193153-2309-4704-9a47-17aee4fb25c8",
-        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
-        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
-        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
-        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
-    ]
-    examples = [
-        Example(
-            id=uuids[0],
-            created_at=_CREATED_AT,
-            inputs={"input": "1"},
-            outputs={"output": "2"},
-            dataset_id=str(uuid.uuid4()),
-        ),
-        Example(
-            id=uuids[1],
-            created_at=_CREATED_AT,
-            inputs={"input": "3"},
-            outputs={"output": "4"},
-            dataset_id=str(uuid.uuid4()),
-        ),
-        Example(
-            id=uuids[2],
-            created_at=_CREATED_AT,
-            inputs={"input": "5"},
-            outputs={"output": "6"},
-            dataset_id=str(uuid.uuid4()),
-        ),
-        Example(
-            id=uuids[3],
-            created_at=_CREATED_AT,
-            inputs={"input": "7"},
-            outputs={"output": "8"},
-            dataset_id=str(uuid.uuid4()),
-        ),
-        Example(
-            id=uuids[4],
-            created_at=_CREATED_AT,
-            inputs={"input": "9"},
-            outputs={"output": "10"},
-            dataset_id=str(uuid.uuid4()),
-        ),
-    ]
-
-    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
-        return dataset
-
-    def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
-        return iter(examples)
-
-    async def mock_arun_chain(
-        example: Example,
-        llm_or_chain: Union[BaseLanguageModel, Chain],
-        n_repetitions: int,
-        tags: Optional[List[str]] = None,
-        callbacks: Optional[Any] = None,
-        **kwargs: Any,
-    ) -> List[Dict[str, Any]]:
-        return [
-            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
-        ]
-
-    def mock_create_project(*args: Any, **kwargs: Any) -> Any:
-        proj = mock.MagicMock()
-        proj.id = "123"
-        return proj
-
-    with mock.patch.object(
-        Client, "read_dataset", new=mock_read_dataset
-    ), mock.patch.object(Client, "list_examples", new=mock_list_examples), mock.patch(
-        "langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
-        new=mock_arun_chain,
-    ), mock.patch.object(
-        Client, "create_project", new=mock_create_project
-    ):
-        client = Client(api_url="http://localhost:1984", api_key="123")
-        chain = mock.MagicMock()
-        chain.input_keys = ["foothing"]
-        num_repetitions = 3
-        results = await arun_on_dataset(
-            dataset_name="test",
-            llm_or_chain_factory=lambda: chain,
-            concurrency_level=2,
-            project_name="test_project",
-            num_repetitions=num_repetitions,
-            client=client,
-        )
-
-        expected = {
-            uuid_: [
-                {"result": f"Result for example {uuid.UUID(uuid_)}"}
-                for _ in range(num_repetitions)
-            ]
-            for uuid_ in uuids
-        }
-        assert results["results"] == expected