"""Test the LangSmith evaluation helpers."""
import uuid
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Union
from unittest import mock

import pytest
from langsmith.client import Client
from langsmith.schemas import Dataset, Example

from langchain.chains.base import Chain
from langchain.chains.transform import TransformChain
from langchain.schema.language_model import BaseLanguageModel
from langchain.smith.evaluation.runner_utils import (
InputFormatError,
_get_messages,
_get_prompt,
_run_llm,
_run_llm_or_chain,
_validate_example_inputs_for_chain,
_validate_example_inputs_for_language_model,
arun_on_dataset,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
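
# Shared fixtures for the tests below: a fixed timestamp and tenant ID, plus a
# canned chat message in the serialized dict format LangSmith examples use.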
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
_EXAMPLE_MESSAGE = {
"data": {"content": "Foo", "example": False, "additional_kwargs": {}},
"type": "human",
}
_VALID_MESSAGES = [
{"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
{"messages": [], "other_key": "value"},
{
"messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]],
"other_key": "value",
},
{"any_key": [_EXAMPLE_MESSAGE]},
{"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE]]},
]
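
# Payloads _get_prompt should accept: a one-element "prompts" list, a string
# "prompt", or a single key mapping to a string (or a one-string list).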
_VALID_PROMPTS = [
{"prompts": ["foo"], "other_key": "value"},
{"prompt": "foo", "other_key": ["bar", "baz"]},
{"some_key": "foo"},
{"some_key": ["foo"]},
]
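
# Payloads _get_prompt should reject: a bare string under "prompts", a list
# under "prompt", a non-string value, and two ambiguous string-valued keys.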
_INVALID_PROMPTS = [
    {"prompts": "foo"},
    {"prompt": ["foo"]},
    {"some_key": 3},
    {"some_key": "foo", "other_key": "bar"},
]
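

# Every valid message payload should parse without raising.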
@pytest.mark.parametrize(
"inputs",
_VALID_MESSAGES,
)
def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
    _get_messages(inputs)
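

# Every valid prompt payload should parse without raising.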
@pytest.mark.parametrize(
"inputs",
_VALID_PROMPTS,
)
def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
_get_prompt(inputs)
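

# An example whose inputs form a valid prompt payload should validate against
# a language model.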
@pytest.mark.parametrize(
"inputs",
_VALID_PROMPTS,
)
def test__validate_example_inputs_for_language_model(inputs: Dict[str, Any]) -> None:
mock_ = mock.MagicMock()
mock_.inputs = inputs
_validate_example_inputs_for_language_model(mock_, None)
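

# Invalid prompt payloads should fail validation against a language model.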
@pytest.mark.parametrize(
"inputs",
_INVALID_PROMPTS,
)
def test__validate_example_inputs_for_language_model_invalid(
inputs: Dict[str, Any]
) -> None:
mock_ = mock.MagicMock()
mock_.inputs = inputs
with pytest.raises(InputFormatError):
_validate_example_inputs_for_language_model(mock_, None)
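

# A chain expecting a single input key accepts an example with one input even
# when the key names differ.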
def test__validate_example_inputs_for_chain_single_input() -> None:
mock_ = mock.MagicMock()
mock_.inputs = {"foo": "bar"}
chain = mock.MagicMock()
chain.input_keys = ["def not foo"]
_validate_example_inputs_for_chain(mock_, chain, None)
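

# With several mismatched keys, validation fails unless an input mapper turns
# the example inputs into a dict keyed by the chain's expected input keys.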
def test__validate_example_inputs_for_chain_input_mapper() -> None:
mock_ = mock.MagicMock()
mock_.inputs = {"foo": "bar", "baz": "qux"}
chain = mock.MagicMock()
chain.input_keys = ["not foo", "not baz", "not qux"]
def wrong_output_format(inputs: dict) -> str:
assert "foo" in inputs
assert "baz" in inputs
return "hehe"
with pytest.raises(InputFormatError, match="must be a dictionary"):
_validate_example_inputs_for_chain(mock_, chain, wrong_output_format)
def wrong_output_keys(inputs: dict) -> dict:
assert "foo" in inputs
assert "baz" in inputs
return {"not foo": "foo", "not baz": "baz"}
with pytest.raises(InputFormatError, match="keys that match"):
_validate_example_inputs_for_chain(mock_, chain, wrong_output_keys)
def input_mapper(inputs: dict) -> dict:
assert "foo" in inputs
assert "baz" in inputs
return {"not foo": inputs["foo"], "not baz": inputs["baz"], "not qux": "qux"}
_validate_example_inputs_for_chain(mock_, chain, input_mapper)
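

# When example input keys exactly match the chain's input keys, no input
# mapper is needed.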
def test__validate_example_inputs_for_chain_multi_io() -> None:
mock_ = mock.MagicMock()
mock_.inputs = {"foo": "bar", "baz": "qux"}
chain = mock.MagicMock()
chain.input_keys = ["foo", "baz"]
_validate_example_inputs_for_chain(mock_, chain, None)
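

# A single example input cannot satisfy a chain that expects several keys.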
def test__validate_example_inputs_for_chain_single_input_multi_expect() -> None:
mock_ = mock.MagicMock()
mock_.inputs = {"foo": "bar"}
chain = mock.MagicMock()
chain.input_keys = ["def not foo", "oh here is another"]
with pytest.raises(
InputFormatError, match="Example inputs do not match chain input keys."
):
_validate_example_inputs_for_chain(mock_, chain, None)
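

# Invalid prompt payloads should raise InputFormatError from _get_prompt.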
@pytest.mark.parametrize("inputs", _INVALID_PROMPTS)
def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
with pytest.raises(InputFormatError):
_get_prompt(inputs)
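

# An input mapper lets a chain or LLM run on examples whose input keys do not
# match what it expects; without the mapper, the run records an error.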
def test_run_llm_or_chain_with_input_mapper() -> None:
example = Example(
id=uuid.uuid4(),
created_at=_CREATED_AT,
inputs={"the wrong input": "1", "another key": "2"},
outputs={"output": "2"},
dataset_id=str(uuid.uuid4()),
)
def run_val(inputs: dict) -> dict:
assert "the right input" in inputs
return {"output": "2"}
mock_chain = TransformChain(
input_variables=["the right input"],
output_variables=["output"],
transform=run_val,
)
def input_mapper(inputs: dict) -> dict:
assert "the wrong input" in inputs
return {"the right input": inputs["the wrong input"]}
result = _run_llm_or_chain(
example, lambda: mock_chain, n_repetitions=1, input_mapper=input_mapper
)
assert len(result) == 1
assert result[0] == {"output": "2", "the right input": "1"}
bad_result = _run_llm_or_chain(
example,
lambda: mock_chain,
n_repetitions=1,
)
assert len(bad_result) == 1
assert "Error" in bad_result[0]
# Try with LLM
def llm_input_mapper(inputs: dict) -> str:
assert "the wrong input" in inputs
return "the right input"
mock_llm = FakeLLM(queries={"the right input": "somenumber"})
result = _run_llm_or_chain(
example, mock_llm, n_repetitions=1, input_mapper=llm_input_mapper
)
assert len(result) == 1
llm_result = result[0]
assert isinstance(llm_result, str)
assert llm_result == "somenumber"
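

# Malformed message payloads: a non-"messages" key next to other keys, mixed
# nesting, a bare string under "prompts", and an empty dict are all rejected.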
@pytest.mark.parametrize(
"inputs",
[
{"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
{
"messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
"other_key": "value",
},
{"prompts": "foo"},
{},
],
)
def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
with pytest.raises(InputFormatError):
_get_messages(inputs)
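

# A plain LLM should run on every valid prompt and message payload.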
@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
llm = FakeLLM()
_run_llm(llm, inputs, mock.MagicMock())
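

# A chat model should likewise run on every valid message and prompt payload.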
@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
llm = FakeChatModel()
_run_llm(llm, inputs, mock.MagicMock())
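

# End-to-end run of arun_on_dataset with the LangSmith client and the chain
# runner patched out: expect one result list per example, keyed by example ID.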
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = Dataset(
id=uuid.uuid4(),
name="test",
description="Test dataset",
owner_id="owner",
created_at=_CREATED_AT,
tenant_id=_TENANT_ID,
)
uuids = [
"0c193153-2309-4704-9a47-17aee4fb25c8",
"0d11b5fd-8e66-4485-b696-4b55155c0c05",
"90d696f0-f10d-4fd0-b88b-bfee6df08b84",
"4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
"7b5a524c-80fa-4960-888e-7d380f9a11ee",
]
examples = [
Example(
id=uuids[0],
created_at=_CREATED_AT,
inputs={"input": "1"},
outputs={"output": "2"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[1],
created_at=_CREATED_AT,
inputs={"input": "3"},
outputs={"output": "4"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[2],
created_at=_CREATED_AT,
inputs={"input": "5"},
outputs={"output": "6"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[3],
created_at=_CREATED_AT,
inputs={"input": "7"},
outputs={"output": "8"},
dataset_id=str(uuid.uuid4()),
),
Example(
id=uuids[4],
created_at=_CREATED_AT,
inputs={"input": "9"},
outputs={"output": "10"},
dataset_id=str(uuid.uuid4()),
),
]
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
return dataset
def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
return iter(examples)
async def mock_arun_chain(
example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain],
n_repetitions: int,
tags: Optional[List[str]] = None,
callbacks: Optional[Any] = None,
**kwargs: Any,
) -> List[Dict[str, Any]]:
return [
{"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
]
def mock_create_project(*args: Any, **kwargs: Any) -> Any:
proj = mock.MagicMock()
proj.id = "123"
return proj
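
    # Patch the client's dataset, example, and project calls plus the chain
    # runner so the run touches no network and no real model.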
with mock.patch.object(
Client, "read_dataset", new=mock_read_dataset
), mock.patch.object(Client, "list_examples", new=mock_list_examples), mock.patch(
"langchain.smith.evaluation.runner_utils._arun_llm_or_chain",
new=mock_arun_chain,
), mock.patch.object(
Client, "create_project", new=mock_create_project
):
client = Client(api_url="http://localhost:1984", api_key="123")
chain = mock.MagicMock()
chain.input_keys = ["foothing"]
num_repetitions = 3
results = await arun_on_dataset(
dataset_name="test",
llm_or_chain_factory=lambda: chain,
concurrency_level=2,
project_name="test_project",
num_repetitions=num_repetitions,
client=client,
)
expected = {
uuid_: [
{"result": f"Result for example {uuid.UUID(uuid_)}"}
for _ in range(num_repetitions)
]
for uuid_ in uuids
}
assert results["results"] == expected