Separate Runner Functions from Client (#5079)
Extract the methods specific to running an LLM or Chain on a dataset into separate utility functions. This simplifies the client a bit and lets us separate the concerns of LCP (LangChainPlus) client details from running examples (e.g., for evals).
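After this change the runner helpers can be exercised without constructing a client, which is exactly what the new unit tests below do. A minimal sketch of that usage, assuming the run_llm(llm, inputs, callback_manager) call shape the new tests use (the callback manager is mocked here, as in the tests):

    from unittest import mock

    from langchain.client.runner_utils import run_llm
    from tests.unit_tests.llms.fake_llm import FakeLLM

    llm = FakeLLM()
    # "prompts" carries a list of strings; run_llm also accepts the
    # serialized-message dicts exercised by the tests below. Malformed
    # inputs raise InputFormatError instead of silently misbehaving.
    result = run_llm(llm, {"prompts": ["foo", "bar"]}, mock.MagicMock())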
@@ -12,14 +12,11 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
 from langchain.callbacks.tracers.schemas import TracerSession
 from langchain.chains.base import Chain
 from langchain.client.langchain import (
-    InputFormatError,
     LangChainPlusClient,
     _get_link_stem,
     _is_localhost,
 )
 from langchain.client.models import Dataset, Example
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM
 
 _CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
 _TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@@ -191,9 +188,9 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
 
     async def mock_arun_chain(
         example: Example,
-        tracer: Any,
         llm_or_chain: Union[BaseLanguageModel, Chain],
         n_repetitions: int,
+        tracer: Any,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
@@ -206,8 +203,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         LangChainPlusClient, "read_dataset", new=mock_read_dataset
     ), mock.patch.object(
         LangChainPlusClient, "list_examples", new=mock_list_examples
-    ), mock.patch.object(
-        LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
+    ), mock.patch(
+        "langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
     ), mock.patch.object(
         LangChainTracer, "ensure_session", new=mock_ensure_session
     ):
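Since _arun_llm_or_chain now lives at module level, the hunk above patches it by dotted path with mock.patch instead of mock.patch.object on the class. The same pattern in isolation; the stub body here is illustrative only:

    from unittest import mock

    # Patch the module-level function where it is looked up; the dotted
    # path comes straight from the hunk above.
    with mock.patch(
        "langchain.client.runner_utils._arun_llm_or_chain",
        new=lambda *args, **kwargs: [{"result": "stubbed"}],
    ):
        ...  # code under test resolves the patched name at call time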
@@ -233,85 +230,3 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         for uuid_ in uuids
     }
     assert results == expected
-
-
-_EXAMPLE_MESSAGE = {
-    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
-    "type": "human",
-}
-_VALID_MESSAGES = [
-    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
-    {"messages": [], "other_key": "value"},
-    {
-        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
-        "other_key": "value",
-    },
-    {"any_key": [_EXAMPLE_MESSAGE]},
-    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]},
-]
-_VALID_PROMPTS = [
-    {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
-    {"prompt": "foo", "other_key": ["bar", "baz"]},
-    {"some_key": "foo"},
-    {"some_key": ["foo", "bar"]},
-]
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_MESSAGES,
-)
-def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
-    {"messages": []}
-    LangChainPlusClient._get_messages(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_PROMPTS,
-)
-def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
-    LangChainPlusClient._get_prompts(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    [
-        {"prompts": "foo"},
-        {"prompt": ["foo"]},
-        {"some_key": 3},
-        {"some_key": "foo", "other_key": "bar"},
-    ],
-)
-def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
-    with pytest.raises(InputFormatError):
-        LangChainPlusClient._get_prompts(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    [
-        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
-        {
-            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
-            "other_key": "value",
-        },
-        {"prompts": "foo"},
-        {},
-    ],
-)
-def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
-    with pytest.raises(InputFormatError):
-        LangChainPlusClient._get_messages(inputs)
-
-
-@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
-def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
-    llm = FakeLLM()
-    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
-
-
-@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
-def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
-    llm = FakeChatModel()
-    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
tests/unit_tests/client/test_runner_utils.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+"""Test the LangChain+ client."""
+from typing import Any, Dict
+from unittest import mock
+
+import pytest
+
+from langchain.client.runner_utils import (
+    InputFormatError,
+    _get_messages,
+    _get_prompts,
+    run_llm,
+)
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+_EXAMPLE_MESSAGE = {
+    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
+    "type": "human",
+}
+_VALID_MESSAGES = [
+    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
+    {"messages": [], "other_key": "value"},
+    {
+        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
+        "other_key": "value",
+    },
+    {"any_key": [_EXAMPLE_MESSAGE]},
+    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]},
+]
+_VALID_PROMPTS = [
+    {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
+    {"prompt": "foo", "other_key": ["bar", "baz"]},
+    {"some_key": "foo"},
+    {"some_key": ["foo", "bar"]},
+]
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    _VALID_MESSAGES,
+)
+def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
+    {"messages": []}
+    _get_messages(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    _VALID_PROMPTS,
+)
+def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
+    _get_prompts(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    [
+        {"prompts": "foo"},
+        {"prompt": ["foo"]},
+        {"some_key": 3},
+        {"some_key": "foo", "other_key": "bar"},
+    ],
+)
+def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
+    with pytest.raises(InputFormatError):
+        _get_prompts(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    [
+        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
+        {
+            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
+            "other_key": "value",
+        },
+        {"prompts": "foo"},
+        {},
+    ],
+)
+def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
+    with pytest.raises(InputFormatError):
+        _get_messages(inputs)
+
+
+@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
+def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeLLM()
+    run_llm(llm, inputs, mock.MagicMock())
+
+
+@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
+def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeChatModel()
+    run_llm(llm, inputs, mock.MagicMock())
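The _EXAMPLE_MESSAGE dict above matches langchain's serialized message format; a short sketch under the assumption that langchain.schema.messages_from_dict is the intended consumer of that shape:

    from langchain.schema import HumanMessage, messages_from_dict

    serialized = {
        "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
        "type": "human",
    }
    # Round-trip the {"type": ..., "data": ...} record used by
    # _VALID_MESSAGES back into a message object.
    assert messages_from_dict([serialized]) == [HumanMessage(content="Foo")]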