mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-30 00:04:19 +00:00
standard-tests: add test for structured output (#23631)
- add test for structured output - fix bug with structured output for Azure - better testing on Groq (break out Mixtral + Llama3 and add xfails where needed)
This commit is contained in:
parent
6c1ba9731d
commit
390ee8d971
@ -9,7 +9,7 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
from langchain_groq import ChatGroq
|
||||
|
||||
|
||||
class TestGroqStandard(ChatModelIntegrationTests):
|
||||
class BaseTestGroq(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatGroq
|
||||
@ -17,3 +17,32 @@ class TestGroqStandard(ChatModelIntegrationTests):
|
||||
@pytest.mark.xfail(reason="Not yet implemented.")
|
||||
def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_message_histories_list_content(model)
|
||||
|
||||
|
||||
class TestGroqMixtral(BaseTestGroq):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
|
||||
)
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
super().test_structured_output(model)
|
||||
|
||||
|
||||
class TestGroqLlama(BaseTestGroq):
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "llama3-8b-8192",
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=("Fails with 'Failed to call a function. Please adjust your prompt.'")
|
||||
)
|
||||
def test_tool_message_histories_string_content(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_message_histories_string_content(model)
|
||||
|
@ -3,15 +3,35 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
@ -21,6 +41,21 @@ from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]]
|
||||
_DictOrPydantic = Union[Dict, _BM]
|
||||
|
||||
|
||||
class _AllReturnType(TypedDict):
|
||||
raw: BaseMessage
|
||||
parsed: Optional[_DictOrPydantic]
|
||||
parsing_error: Optional[BaseException]
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
|
||||
|
||||
class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"""`Azure OpenAI` Chat Completion API.
|
||||
|
||||
@ -233,6 +268,250 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
tool_choice = convert_to_openai_tool(tools[0])["function"]["name"]
|
||||
return super().bind_tools(tools, tool_choice=tool_choice, **kwargs)
|
||||
|
||||
# TODO: Fix typing.
|
||||
@overload # type: ignore[override]
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[_DictOrPydanticClass] = None,
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: Literal[True] = True,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _AllReturnType]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[_DictOrPydanticClass] = None,
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
...
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[_DictOrPydanticClass] = None,
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
|
||||
then the model output will be an object of that class. If a dict then
|
||||
the model output will be a dict. With a Pydantic class the returned
|
||||
attributes will be validated, whereas with a dict they will not be. If
|
||||
`method` is "function_calling" and `schema` is a dict, then the dict
|
||||
must match the OpenAI function-calling spec or be a valid JSON schema
|
||||
with top level 'title' and 'description' keys specified.
|
||||
method: The method for steering model generation, either "function_calling"
|
||||
or "json_mode". If "function_calling" then the schema will be converted
|
||||
to an OpenAI function and the returned model will make use of the
|
||||
function-calling API. If "json_mode" then OpenAI's JSON mode will be
|
||||
used. Note that if using "json_mode" then you must include instructions
|
||||
for formatting the output into the desired schema into the model call.
|
||||
include_raw: If False then only the parsed structured output is returned. If
|
||||
an error occurs during model output parsing it will be raised. If True
|
||||
then both the raw model response (a BaseMessage) and the parsed model
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
Returns:
|
||||
A Runnable that takes any ChatModel input and returns as output:
|
||||
|
||||
If include_raw is True then a dict with keys:
|
||||
raw: BaseMessage
|
||||
parsed: Optional[_DictOrPydantic]
|
||||
parsing_error: Optional[BaseException]
|
||||
|
||||
If include_raw is False then just _DictOrPydantic is returned,
|
||||
where _DictOrPydantic depends on the schema:
|
||||
|
||||
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
|
||||
class.
|
||||
|
||||
If schema is a dict then _DictOrPydantic is a dict.
|
||||
|
||||
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="gpt-35-turbo", temperature=0)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
|
||||
# -> AnswerWithJustification(
|
||||
# answer='They weigh the same',
|
||||
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
|
||||
# )
|
||||
|
||||
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="gpt-35-turbo", temperature=0)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
|
||||
structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
|
||||
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
Example: Function-calling, dict schema (method="function_calling", include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
|
||||
dict_schema = convert_to_openai_tool(AnswerWithJustification)
|
||||
llm = AzureChatOpenAI(azure_deployment="gpt-35-turbo", temperature=0)
|
||||
structured_llm = llm.with_structured_output(dict_schema)
|
||||
|
||||
structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
)
|
||||
# -> {
|
||||
# 'answer': 'They weigh the same',
|
||||
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
|
||||
# }
|
||||
|
||||
Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
|
||||
.. code-block::
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = AzureChatOpenAI(azure_deployment="gpt-35-turbo", temperature=0)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification,
|
||||
method="json_mode",
|
||||
include_raw=True
|
||||
)
|
||||
|
||||
structured_llm.invoke(
|
||||
"Answer the following question. "
|
||||
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
|
||||
"What's heavier a pound of bricks or a pound of feathers?"
|
||||
)
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
|
||||
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
|
||||
.. code-block::
|
||||
|
||||
structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)
|
||||
|
||||
structured_llm.invoke(
|
||||
"Answer the following question. "
|
||||
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
|
||||
"What's heavier a pound of bricks or a pound of feathers?"
|
||||
)
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
|
||||
# 'parsed': {
|
||||
# 'answer': 'They are both the same weight.',
|
||||
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
|
||||
# },
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
llm = self.bind_tools([schema], tool_choice=True)
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
)
|
||||
else:
|
||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=key_name, first_tool_only=True
|
||||
)
|
||||
elif method == "json_mode":
|
||||
llm = self.bind(response_format={"type": "json_object"})
|
||||
output_parser = (
|
||||
PydanticOutputParser(pydantic_object=schema)
|
||||
if is_pydantic_schema
|
||||
else JsonOutputParser()
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
f"'json_mode'. Received: '{method}'"
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
@ -1113,8 +1113,6 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
|
||||
.. code-block::
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)
|
||||
|
||||
structured_llm.invoke(
|
||||
|
@ -12,7 +12,6 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
@ -226,18 +225,3 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None:
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
assert result.response_metadata.get("model_name") is not None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need tool calling model deployed on azure")
|
||||
def test_openai_structured_output(llm: AzureChatOpenAI) -> None:
|
||||
class MyModel(BaseModel):
|
||||
"""A Person"""
|
||||
|
||||
name: str
|
||||
age: int
|
||||
|
||||
llm_structure = llm.with_structured_output(MyModel)
|
||||
result = llm_structure.invoke("I'm a 27 year old named Erick")
|
||||
assert isinstance(result, MyModel)
|
||||
assert result.name == "Erick"
|
||||
assert result.age == 27
|
||||
|
@ -12,6 +12,7 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_standard_tests.unit_tests.chat_models import (
|
||||
@ -139,6 +140,20 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message(full)
|
||||
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
class Joke(BaseModel):
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: str = Field(description="question to set up a joke")
|
||||
punchline: str = Field(description="answer to resolve the joke")
|
||||
|
||||
chat = model.with_structured_output(Joke)
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, Joke)
|
||||
|
||||
def test_tool_message_histories_string_content(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
|
Loading…
Reference in New Issue
Block a user