Add Support for Flexible Input Format for LLM and Chat Model Runs (#4805)

Previously, the client expected a strict 'prompt' or 'messages' format
and wouldn't permit running a chat model or llm on prompts or messages
(respectively).

Since many datasets may want to specify a custom `key: string` pair,
relax this requirement.
Also, add support for running a chat model on raw prompts and LLM on
chat messages through their respective fallbacks.
This commit is contained in:
Zander Chase
2023-05-17 07:24:17 -07:00
committed by GitHub
parent a47c62fcba
commit 8dcad0f272
2 changed files with 210 additions and 20 deletions

View File

@@ -12,11 +12,14 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import TracerSession
from langchain.chains.base import Chain
from langchain.client.langchain import (
InputFormatError,
LangChainPlusClient,
_get_link_stem,
_is_localhost,
)
from langchain.client.models import Dataset, Example
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
_CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@@ -230,3 +233,85 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
for uuid_ in uuids
}
assert results == expected
# A single message in dict form: a "type" tag plus a "data" payload.
_EXAMPLE_MESSAGE = {
    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
    "type": "human",
}

# Inputs that test__get_messages_valid feeds to
# LangChainPlusClient._get_messages and expects to be accepted.
_VALID_MESSAGES = [
    # "messages" key with a flat list of one message, plus an unrelated key.
    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
    # "messages" key with an empty list, plus an unrelated key.
    {"messages": [], "other_key": "value"},
    # "messages" key with a list of message lists, plus an unrelated key.
    {
        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
        "other_key": "value",
    },
    # Single arbitrary key with a flat list of one message.
    {"any_key": [_EXAMPLE_MESSAGE]},
    # Single arbitrary key with a list of message lists.
    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]},
]

# Inputs that test__get_prompts_valid feeds to
# LangChainPlusClient._get_prompts and expects to be accepted.
_VALID_PROMPTS = [
    # "prompts" key with a list of strings, plus an unrelated key.
    {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
    # "prompt" key with a single string, plus an unrelated key.
    {"prompt": "foo", "other_key": ["bar", "baz"]},
    # Single arbitrary key with a single string.
    {"some_key": "foo"},
    # Single arbitrary key with a list of strings.
    {"some_key": ["foo", "bar"]},
]
@pytest.mark.parametrize(
    "inputs",
    _VALID_MESSAGES,
)
def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_messages`` accepts every entry of ``_VALID_MESSAGES``.

    The test passes when the call returns without raising; the returned
    value is not inspected.

    Args:
        inputs: One example input dict taken from ``_VALID_MESSAGES``.
    """
    # A stray, discarded ``{"messages": []}`` expression used to sit here;
    # it had no effect and has been removed.
    LangChainPlusClient._get_messages(inputs)
@pytest.mark.parametrize("inputs", _VALID_PROMPTS)
def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_prompts`` accepts every entry of ``_VALID_PROMPTS``.

    Only the absence of an exception is verified; the result is discarded.

    Args:
        inputs: One example input dict taken from ``_VALID_PROMPTS``.
    """
    extract_prompts = LangChainPlusClient._get_prompts
    extract_prompts(inputs)
@pytest.mark.parametrize(
    "inputs",
    [
        # "prompts" mapped to a bare string instead of a list.
        {"prompts": "foo"},
        # "prompt" mapped to a list instead of a bare string.
        {"prompt": ["foo"]},
        # Single arbitrary key mapped to an integer.
        {"some_key": 3},
        # Two arbitrary keys, neither named "prompt" nor "prompts".
        {"some_key": "foo", "other_key": "bar"},
    ],
)
def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_prompts`` raises ``InputFormatError`` on bad inputs.

    Args:
        inputs: One input dict that the test expects to be rejected.
    """
    extract_prompts = LangChainPlusClient._get_prompts
    with pytest.raises(InputFormatError):
        extract_prompts(inputs)
@pytest.mark.parametrize(
    "inputs",
    [
        # Two keys, neither named "messages".
        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
        # "messages" list mixing a nested message list with a bare message.
        {
            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
            "other_key": "value",
        },
        # Single key mapped to a plain string.
        {"prompts": "foo"},
        # Empty input dict.
        {},
    ],
)
def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
    """Check that ``_get_messages`` raises ``InputFormatError`` on bad inputs.

    Args:
        inputs: One input dict that the test expects to be rejected.
    """
    extract_messages = LangChainPlusClient._get_messages
    with pytest.raises(InputFormatError):
        extract_messages(inputs)
@pytest.mark.parametrize(
    "inputs",
    _VALID_PROMPTS + _VALID_MESSAGES,
)
def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
    """Smoke-test ``run_llm`` with a ``FakeLLM`` on every valid input shape.

    Runs over both the prompt-style and the message-style fixtures. The
    test only checks that no exception is raised; the third argument to
    ``run_llm`` is a ``MagicMock`` stand-in.

    Args:
        inputs: One example input dict from the combined fixture lists.
    """
    fake_llm = FakeLLM()
    LangChainPlusClient.run_llm(
        fake_llm,
        inputs,
        mock.MagicMock(),
    )
@pytest.mark.parametrize(
    "inputs",
    _VALID_MESSAGES + _VALID_PROMPTS,
)
def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
    """Smoke-test ``run_llm`` with a ``FakeChatModel`` on every valid input shape.

    Runs over both the message-style and the prompt-style fixtures. The
    test only checks that no exception is raised; the third argument to
    ``run_llm`` is a ``MagicMock`` stand-in.

    Args:
        inputs: One example input dict from the combined fixture lists.
    """
    fake_chat_model = FakeChatModel()
    LangChainPlusClient.run_llm(
        fake_chat_model,
        inputs,
        mock.MagicMock(),
    )