Add Support for Flexible Input Format for LLM and Chat Model Runs (#4805)
Previously, the client expected a strict 'prompt' or 'messages' input format and wouldn't permit running a chat model on prompts or an LLM on messages. Since many datasets may want to specify a custom key mapping to a string instead, relax this requirement. Also, add support for running a chat model on raw prompts and an LLM on chat messages through their respective fallbacks.
From https://github.com/hwchase17/langchain.git: commit 8dcad0f272 (parent a47c62fcba)
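Concretely, the relaxed format lets dataset rows carry prompts or serialized messages under arbitrary keys. A minimal sketch of the shapes the client now accepts, mirroring the _VALID_PROMPTS / _VALID_MESSAGES cases in the new unit tests below:

    # A serialized chat message, as produced by e.g. messages_to_dict.
    message_dict = {
        "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
        "type": "human",
    }

    # Prompt-style rows: 'prompt', 'prompts', or a single custom key.
    prompt_rows = [
        {"prompt": "foo", "other_key": ["bar", "baz"]},
        {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
        {"some_key": "foo"},
    ]

    # Message-style rows: 'messages' or a single custom key, holding one
    # batch (List[dict]) or several batches (List[List[dict]]).
    message_rows = [
        {"messages": [message_dict]},
        {"any_key": [[message_dict], [message_dict, message_dict]]},
    ]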
In langchain/client/langchain.py:

@@ -39,7 +39,14 @@ from langchain.client.models import (
     ListRunsQueryParams,
 )
 from langchain.llms.base import BaseLLM
-from langchain.schema import ChatResult, LLMResult, messages_from_dict
+from langchain.schema import (
+    BaseMessage,
+    ChatResult,
+    HumanMessage,
+    LLMResult,
+    get_buffer_string,
+    messages_from_dict,
+)
 from langchain.utils import raise_for_status_with_text, xor_args
 
 if TYPE_CHECKING:
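Two of the new imports carry the fallback paths: get_buffer_string flattens a batch of chat messages into a single prompt string for plain LLMs, and HumanMessage wraps raw prompts for chat models. A rough sketch of the flattening (the exact prefixes come from get_buffer_string's defaults):

    from langchain.schema import AIMessage, HumanMessage, get_buffer_string

    msgs = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
    print(get_buffer_string(msgs))
    # Expected to print something like:
    # Human: Hi
    # AI: Hello!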
@@ -50,6 +57,10 @@ logger = logging.getLogger(__name__)
 MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
 
 
+class InputFormatError(Exception):
+    """Raised when input format is invalid."""
+
+
 def _get_link_stem(url: str) -> str:
     scheme = urlsplit(url).scheme
     netloc_prefix = urlsplit(url).netloc.split(":")[0]
@@ -389,6 +400,76 @@ class LangChainPlusClient(BaseSettings):
         raise_for_status_with_text(response)
         return [Example(**dataset) for dataset in response.json()]
 
+    @staticmethod
+    def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
+        """Get prompts from inputs."""
+        if not inputs:
+            raise InputFormatError("Inputs should not be empty.")
+
+        prompts = []
+
+        if "prompt" in inputs:
+            if not isinstance(inputs["prompt"], str):
+                raise InputFormatError(
+                    "Expected string for 'prompt', got"
+                    f" {type(inputs['prompt']).__name__}"
+                )
+            prompts = [inputs["prompt"]]
+        elif "prompts" in inputs:
+            if not isinstance(inputs["prompts"], list) or not all(
+                isinstance(i, str) for i in inputs["prompts"]
+            ):
+                raise InputFormatError(
+                    "Expected list of strings for 'prompts',"
+                    f" got {type(inputs['prompts']).__name__}"
+                )
+            prompts = inputs["prompts"]
+        elif len(inputs) == 1:
+            prompt_ = next(iter(inputs.values()))
+            if isinstance(prompt_, str):
+                prompts = [prompt_]
+            elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_):
+                prompts = prompt_
+            else:
+                raise InputFormatError(
+                    f"LLM Run expects string prompt input. Got {inputs}"
+                )
+        else:
+            raise InputFormatError(
+                f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}"
+            )
+
+        return prompts
+
+    @staticmethod
+    def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
+        """Get Chat Messages from inputs."""
+        if not inputs:
+            raise InputFormatError("Inputs should not be empty.")
+
+        if "messages" in inputs:
+            single_input = inputs["messages"]
+        elif len(inputs) == 1:
+            single_input = next(iter(inputs.values()))
+        else:
+            raise InputFormatError(
+                f"Chat Run expects 'messages' in inputs. Got {inputs}"
+            )
+        if isinstance(single_input, list) and all(
+            isinstance(i, dict) for i in single_input
+        ):
+            raw_messages = [single_input]
+        elif isinstance(single_input, list) and all(
+            isinstance(i, list) for i in single_input
+        ):
+            raw_messages = single_input
+        else:
+            raise InputFormatError(
+                f"Chat Run expects List[dict] or List[List[dict]] 'messages'"
+                f" input. Got {inputs}"
+            )
+        return [messages_from_dict(batch) for batch in raw_messages]
+
     @staticmethod
     async def _arun_llm(
         llm: BaseLanguageModel,
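Between them, the two helpers normalize every accepted shape into a batch. A hypothetical interpreter session (the results follow from the code above; message_dict is the serialized human message from the first sketch):

    LangChainPlusClient._get_prompts({"prompt": "foo"})          # ["foo"]
    LangChainPlusClient._get_prompts({"some_key": ["a", "b"]})   # ["a", "b"]
    LangChainPlusClient._get_prompts({"a": "x", "b": "y"})       # raises InputFormatError

    LangChainPlusClient._get_messages({"any_key": [message_dict]})
    # [[HumanMessage(content="Foo")]], i.e. a single batch of BaseMessages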
@@ -396,16 +477,31 @@ class LangChainPlusClient(BaseSettings):
         langchain_tracer: LangChainTracer,
     ) -> Union[LLMResult, ChatResult]:
         if isinstance(llm, BaseLLM):
-            if "prompt" not in inputs:
-                raise ValueError(f"LLM Run requires 'prompt' input. Got {inputs}")
-            llm_prompt: str = inputs["prompt"]
-            llm_output = await llm.agenerate([llm_prompt], callbacks=[langchain_tracer])
+            try:
+                llm_prompts = LangChainPlusClient._get_prompts(inputs)
+                llm_output = await llm.agenerate(
+                    llm_prompts, callbacks=[langchain_tracer]
+                )
+            except InputFormatError:
+                llm_messages = LangChainPlusClient._get_messages(inputs)
+                buffer_strings = [
+                    get_buffer_string(messages) for messages in llm_messages
+                ]
+                llm_output = await llm.agenerate(
+                    buffer_strings, callbacks=[langchain_tracer]
+                )
         elif isinstance(llm, BaseChatModel):
-            if "messages" not in inputs:
-                raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
-            raw_messages: List[dict] = inputs["messages"]
-            messages = messages_from_dict(raw_messages)
-            llm_output = await llm.agenerate([messages], callbacks=[langchain_tracer])
+            try:
+                messages = LangChainPlusClient._get_messages(inputs)
+                llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
+            except InputFormatError:
+                prompts = LangChainPlusClient._get_prompts(inputs)
+                converted_messages: List[List[BaseMessage]] = [
+                    [HumanMessage(content=prompt)] for prompt in prompts
+                ]
+                llm_output = await llm.agenerate(
+                    converted_messages, callbacks=[langchain_tracer]
+                )
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
         return llm_output
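The try/except ordering gives each model its native format first: a plain LLM tries prompts and falls back to messages flattened with get_buffer_string, while a chat model tries messages and falls back to prompts wrapped as single HumanMessage turns. A hypothetical trace of the chat-model fallback on a prompt-style row:

    inputs = {"some_key": "foo"}                              # prompt-style row
    prompts = LangChainPlusClient._get_prompts(inputs)        # ["foo"]
    converted = [[HumanMessage(content=p)] for p in prompts]  # one single-turn batch
    # llm_output = await chat_model.agenerate(converted, callbacks=[langchain_tracer])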
@@ -562,18 +658,27 @@ class LangChainPlusClient(BaseSettings):
     ) -> Union[LLMResult, ChatResult]:
         """Run the language model on the example."""
         if isinstance(llm, BaseLLM):
-            if "prompt" not in inputs:
-                raise ValueError(f"LLM Run must contain 'prompt' key. Got {inputs}")
-            llm_prompt: str = inputs["prompt"]
-            llm_output = llm.generate([llm_prompt], callbacks=[langchain_tracer])
+            try:
+                llm_prompts = LangChainPlusClient._get_prompts(inputs)
+                llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
+            except InputFormatError:
+                llm_messages = LangChainPlusClient._get_messages(inputs)
+                buffer_strings = [
+                    get_buffer_string(messages) for messages in llm_messages
+                ]
+                llm_output = llm.generate(buffer_strings, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            if "messages" not in inputs:
-                raise ValueError(
-                    f"Chat Model Run must contain 'messages' key. Got {inputs}"
-                )
-            raw_messages: List[dict] = inputs["messages"]
-            messages = messages_from_dict(raw_messages)
-            llm_output = llm.generate([messages], callbacks=[langchain_tracer])
+            try:
+                messages = LangChainPlusClient._get_messages(inputs)
+                llm_output = llm.generate(messages, callbacks=[langchain_tracer])
+            except InputFormatError:
+                prompts = LangChainPlusClient._get_prompts(inputs)
+                converted_messages: List[List[BaseMessage]] = [
+                    [HumanMessage(content=prompt)] for prompt in prompts
+                ]
+                llm_output = llm.generate(
+                    converted_messages, callbacks=[langchain_tracer]
+                )
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
         return llm_output
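The synchronous path mirrors the async one, so any accepted shape can be passed straight to run_llm. A sketch using the fake models from the test suite, as the new tests below do:

    from unittest import mock

    from tests.unit_tests.llms.fake_chat_model import FakeChatModel
    from tests.unit_tests.llms.fake_llm import FakeLLM

    LangChainPlusClient.run_llm(FakeLLM(), {"some_key": "foo"}, mock.MagicMock())
    LangChainPlusClient.run_llm(FakeChatModel(), {"prompt": "foo"}, mock.MagicMock())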
In the client unit tests:

@@ -12,11 +12,14 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
 from langchain.callbacks.tracers.schemas import TracerSession
 from langchain.chains.base import Chain
 from langchain.client.langchain import (
+    InputFormatError,
     LangChainPlusClient,
     _get_link_stem,
     _is_localhost,
 )
 from langchain.client.models import Dataset, Example
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
+from tests.unit_tests.llms.fake_llm import FakeLLM
 
 _CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
 _TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@@ -230,3 +233,85 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         for uuid_ in uuids
     }
     assert results == expected
+
+
+_EXAMPLE_MESSAGE = {
+    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
+    "type": "human",
+}
+_VALID_MESSAGES = [
+    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
+    {"messages": [], "other_key": "value"},
+    {
+        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
+        "other_key": "value",
+    },
+    {"any_key": [_EXAMPLE_MESSAGE]},
+    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]},
+]
+_VALID_PROMPTS = [
+    {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
+    {"prompt": "foo", "other_key": ["bar", "baz"]},
+    {"some_key": "foo"},
+    {"some_key": ["foo", "bar"]},
+]
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    _VALID_MESSAGES,
+)
+def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
+    {"messages": []}
+    LangChainPlusClient._get_messages(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    _VALID_PROMPTS,
+)
+def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
+    LangChainPlusClient._get_prompts(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    [
+        {"prompts": "foo"},
+        {"prompt": ["foo"]},
+        {"some_key": 3},
+        {"some_key": "foo", "other_key": "bar"},
+    ],
+)
+def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
+    with pytest.raises(InputFormatError):
+        LangChainPlusClient._get_prompts(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    [
+        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
+        {
+            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
+            "other_key": "value",
+        },
+        {"prompts": "foo"},
+        {},
+    ],
+)
+def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
+    with pytest.raises(InputFormatError):
+        LangChainPlusClient._get_messages(inputs)
+
+
+@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
+def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeLLM()
+    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
+
+
+@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
+def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeChatModel()
+    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
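The parametrized cases cross both model types with both input families, so every fallback path is exercised. A hypothetical way to run just these tests (the file path is an assumption; it is not shown in this extract):

    import pytest

    # Adjust the path to wherever these tests live in the checkout.
    pytest.main(["tests/unit_tests/client/test_langchain.py", "-k", "formats or _get_"])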