unit tests

Bagatur 2025-03-11 20:32:09 -07:00
parent d662b095ca
commit 1d10d0d66f
4 changed files with 530 additions and 19 deletions

View File

@@ -2833,13 +2833,13 @@ def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
)
}
# for compatibility with chat completion calls.
response_metadata["model_name"] = response.get("model")
response_metadata["model_name"] = response_metadata.get("model")
if response.usage:
usage_metadata = _create_usage_metadata_responses(response.usage.model_dump())
else:
usage_metadata = None
content_blocks = []
content_blocks: list = []
tool_calls = []
invalid_tool_calls = []
additional_kwargs: dict = {}
@@ -2898,7 +2898,7 @@ def _construct_lc_result_from_response_api(response: Response) -> ChatResult:
else:
additional_kwargs["tool_outputs"] = [tool_output]
message = AIMessage(
content=content_blocks or None, # type: ignore[arg-type]
content=content_blocks,
id=msg_id,
usage_metadata=usage_metadata,
response_metadata=response_metadata,

View File

@@ -1286,7 +1286,7 @@ def test_web_search() -> None:
llm.invoke(
"what about a negative one",
tools=[{"type": "web_search_preview"}],
response_id=response.response_metadata["id"]
response_id=response.response_metadata["id"],
)
_check_response(response)

View File

@@ -3,7 +3,7 @@
import json
from functools import partial
from types import TracebackType
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -19,13 +19,29 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableLambda
from openai.types.responses import ResponseOutputMessage
from openai.types.responses.response import IncompleteDetails, Response, ResponseUsage
from openai.types.responses.response_error import ResponseError
from openai.types.responses.response_file_search_tool_call import (
ResponseFileSearchToolCall,
Result,
)
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_function_web_search import (
ResponseFunctionWebSearch,
)
from openai.types.responses.response_output_refusal import ResponseOutputRefusal
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_usage import OutputTokensDetails
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_FUNCTION_CALL_IDS_MAP_KEY,
_construct_lc_result_from_response_api,
_convert_dict_to_message,
_convert_message_to_dict,
_convert_to_openai_response_format,
@@ -862,7 +878,7 @@ def test_nested_structured_output_strict() -> None:
setup: str
punchline: str
self_evaluation: SelfEvaluation
_evaluation: SelfEvaluation
llm.with_structured_output(JokeWithEvaluation, method="json_schema")
@@ -936,3 +952,497 @@ def test_structured_outputs_parser() -> None:
assert isinstance(deserialized, ChatGeneration)
result = output_parser.invoke(deserialized.message)
assert result == parsed_response
def test__construct_lc_result_from_response_api_error_handling() -> None:
"""Test that errors in the response are properly raised."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
error=ResponseError(message="Test error", code="server_error"),
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[],
)
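# A populated `error` field should surface as a raised ValueError carrying the message.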
with pytest.raises(ValueError) as excinfo:
_construct_lc_result_from_response_api(response)
assert "Test error" in str(excinfo.value)
def test__construct_lc_result_from_response_api_basic_text_response() -> None:
"""Test a basic text response with no tools or special features."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputText(
type="output_text", text="Hello, world!", annotations=[]
)
],
role="assistant",
status="completed",
)
],
usage=ResponseUsage(
input_tokens=10,
output_tokens=3,
total_tokens=13,
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
),
)
result = _construct_lc_result_from_response_api(response)
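# Output text should come back as a list of typed content blocks, not a bare string.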
assert isinstance(result, ChatResult)
assert len(result.generations) == 1
assert isinstance(result.generations[0], ChatGeneration)
assert isinstance(result.generations[0].message, AIMessage)
assert result.generations[0].message.content == [
{"type": "text", "text": "Hello, world!", "annotations": []}
]
assert result.generations[0].message.id == "msg_123"
assert result.generations[0].message.usage_metadata
assert result.generations[0].message.usage_metadata["input_tokens"] == 10
assert result.generations[0].message.usage_metadata["output_tokens"] == 3
assert result.generations[0].message.usage_metadata["total_tokens"] == 13
assert result.generations[0].message.response_metadata["id"] == "resp_123"
assert result.generations[0].message.response_metadata["model_name"] == "gpt-4o"
def test__construct_lc_result_from_response_api_multiple_text_blocks() -> None:
"""Test a response with multiple text blocks."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputText(
type="output_text", text="First part", annotations=[]
),
ResponseOutputText(
type="output_text", text="Second part", annotations=[]
),
],
role="assistant",
status="completed",
)
],
)
result = _construct_lc_result_from_response_api(response)
assert len(result.generations[0].message.content) == 2
assert result.generations[0].message.content[0]["text"] == "First part" # type: ignore
assert result.generations[0].message.content[1]["text"] == "Second part" # type: ignore
def test__construct_lc_result_from_response_api_refusal_response() -> None:
"""Test a response with a refusal."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputRefusal(
type="refusal", refusal="I cannot assist with that request."
)
],
role="assistant",
status="completed",
)
],
)
result = _construct_lc_result_from_response_api(response)
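# A refusal yields no content blocks; the refusal text lands in additional_kwargs.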
assert result.generations[0].message.content == []
assert (
result.generations[0].message.additional_kwargs["refusal"]
== "I cannot assist with that request."
)
def test__construct_lc_result_from_response_api_function_call_valid_json() -> None:
"""Test a response with a valid function call."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseFunctionToolCall(
type="function_call",
id="func_123",
call_id="call_123",
name="get_weather",
arguments='{"location": "New York", "unit": "celsius"}',
)
],
)
result = _construct_lc_result_from_response_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
assert len(msg.tool_calls) == 1
assert msg.tool_calls[0]["type"] == "tool_call"
assert msg.tool_calls[0]["name"] == "get_weather"
assert msg.tool_calls[0]["id"] == "call_123"
assert msg.tool_calls[0]["args"] == {"location": "New York", "unit": "celsius"}
assert _FUNCTION_CALL_IDS_MAP_KEY in result.generations[0].message.additional_kwargs
assert (
result.generations[0].message.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY][
"call_123"
]
== "func_123"
)
def test__construct_lc_result_from_response_api_function_call_invalid_json() -> None:
"""Test a response with an invalid JSON function call."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseFunctionToolCall(
type="function_call",
id="func_123",
call_id="call_123",
name="get_weather",
arguments='{"location": "New York", "unit": "celsius"',
# Missing closing brace
)
],
)
result = _construct_lc_result_from_response_api(response)
msg: AIMessage = cast(AIMessage, result.generations[0].message)
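# Malformed arguments are captured as an invalid tool call with the raw string intact.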
assert len(msg.invalid_tool_calls) == 1
assert msg.invalid_tool_calls[0]["type"] == "invalid_tool_call"
assert msg.invalid_tool_calls[0]["name"] == "get_weather"
assert msg.invalid_tool_calls[0]["id"] == "call_123"
assert (
msg.invalid_tool_calls[0]["args"]
== '{"location": "New York", "unit": "celsius"'
)
assert "error" in msg.invalid_tool_calls[0]
assert _FUNCTION_CALL_IDS_MAP_KEY in result.generations[0].message.additional_kwargs
def test__construct_lc_result_from_response_api_complex_response() -> None:
"""Test a complex response with multiple output types."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputText(
type="output_text",
text="Here's the information you requested:",
annotations=[],
)
],
role="assistant",
status="completed",
),
ResponseFunctionToolCall(
type="function_call",
id="func_123",
call_id="call_123",
name="get_weather",
arguments='{"location": "New York"}',
),
],
metadata=dict(key1="value1", key2="value2"),
incomplete_details=IncompleteDetails(reason="max_output_tokens"),
status="completed",
user="user_123",
)
result = _construct_lc_result_from_response_api(response)
# Check message content
assert result.generations[0].message.content == [
{
"type": "text",
"text": "Here's the information you requested:",
"annotations": [],
}
]
# Check tool calls
msg: AIMessage = cast(AIMessage, result.generations[0].message)
assert len(msg.tool_calls) == 1
assert msg.tool_calls[0]["name"] == "get_weather"
# Check metadata
assert result.generations[0].message.response_metadata["id"] == "resp_123"
assert result.generations[0].message.response_metadata["metadata"] == {
"key1": "value1",
"key2": "value2",
}
assert result.generations[0].message.response_metadata["incomplete_details"] == {
"reason": "max_output_tokens"
}
assert result.generations[0].message.response_metadata["status"] == "completed"
assert result.generations[0].message.response_metadata["user"] == "user_123"
def test__construct_lc_result_from_response_api_no_usage_metadata() -> None:
"""Test a response without usage metadata."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputText(
type="output_text", text="Hello, world!", annotations=[]
)
],
role="assistant",
status="completed",
)
],
# No usage field
)
result = _construct_lc_result_from_response_api(response)
assert cast(AIMessage, result.generations[0].message).usage_metadata is None
def test__construct_lc_result_from_response_api_web_search_response() -> None:
"""Test a response with web search output."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseFunctionWebSearch(
id="websearch_123", type="web_search_call", status="completed"
)
],
)
result = _construct_lc_result_from_response_api(response)
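# Built-in tool invocations such as web search are exposed via additional_kwargs["tool_outputs"].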
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["type"]
== "web_search_call"
)
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["id"]
== "websearch_123"
)
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["status"]
== "completed"
)
def test__construct_lc_result_from_response_api_file_search_response() -> None:
"""Test a response with file search output."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseFileSearchToolCall(
id="filesearch_123",
type="file_search_call",
status="completed",
queries=["python code", "langchain"],
results=[
Result(
file_id="file_123",
filename="example.py",
score=0.95,
text="def hello_world() -> None:\n print('Hello, world!')",
attributes={"language": "python", "size": 42},
)
],
)
],
)
result = _construct_lc_result_from_response_api(response)
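# File search calls pass through with their queries and scored results intact.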
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 1
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["type"]
== "file_search_call"
)
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["id"]
== "filesearch_123"
)
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["status"]
== "completed"
)
assert result.generations[0].message.additional_kwargs["tool_outputs"][0][
"queries"
] == ["python code", "langchain"]
assert (
len(
result.generations[0].message.additional_kwargs["tool_outputs"][0][
"results"
]
)
== 1
)
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["results"][
0
]["file_id"]
== "file_123"
)
assert (
result.generations[0].message.additional_kwargs["tool_outputs"][0]["results"][
0
]["score"]
== 0.95
)
def test__construct_lc_result_from_response_api_mixed_search_responses() -> None:
"""Test a response with both web search and file search outputs."""
response = Response(
id="resp_123",
created_at=1234567890,
model="gpt-4o",
object="response",
parallel_tool_calls=True,
tools=[],
tool_choice="auto",
output=[
ResponseOutputMessage(
type="message",
id="msg_123",
content=[
ResponseOutputText(
type="output_text", text="Here's what I found:", annotations=[]
)
],
role="assistant",
status="completed",
),
ResponseFunctionWebSearch(
id="websearch_123", type="web_search_call", status="completed"
),
ResponseFileSearchToolCall(
id="filesearch_123",
type="file_search_call",
status="completed",
queries=["python code"],
results=[
Result(
file_id="file_123",
filename="example.py",
score=0.95,
text="def hello_world() -> None:\n print('Hello, world!')",
)
],
),
],
)
result = _construct_lc_result_from_response_api(response)
# Check message content
assert result.generations[0].message.content == [
{"type": "text", "text": "Here's what I found:", "annotations": []}
]
# Check tool outputs
assert "tool_outputs" in result.generations[0].message.additional_kwargs
assert len(result.generations[0].message.additional_kwargs["tool_outputs"]) == 2
# Check web search output
web_search = next(
output
for output in result.generations[0].message.additional_kwargs["tool_outputs"]
if output["type"] == "web_search_call"
)
assert web_search["id"] == "websearch_123"
assert web_search["status"] == "completed"
# Check file search output
file_search = next(
output
for output in result.generations[0].message.additional_kwargs["tool_outputs"]
if output["type"] == "file_search_call"
)
assert file_search["id"] == "filesearch_123"
assert file_search["queries"] == ["python code"]
assert file_search["results"][0]["filename"] == "example.py"

uv.lock
View File

@@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.9, <4.0"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -2152,7 +2153,7 @@ wheels = [
[[package]]
name = "langchain"
version = "0.3.19"
version = "0.3.20"
source = { editable = "libs/langchain" }
dependencies = [
{ name = "async-timeout", marker = "python_full_version < '3.11'" },
@@ -2191,6 +2192,7 @@ requires-dist = [
{ name = "requests", specifier = ">=2,<3" },
{ name = "sqlalchemy", specifier = ">=1.4,<3" },
]
provides-extras = ["community", "anthropic", "openai", "cohere", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai"]
[package.metadata.requires-dev]
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
@@ -2259,7 +2261,7 @@ typing = [
[[package]]
name = "langchain-anthropic"
version = "0.3.8"
version = "0.3.9"
source = { editable = "libs/partners/anthropic" }
dependencies = [
{ name = "anthropic" },
@@ -2360,7 +2362,7 @@ typing = [
[[package]]
name = "langchain-community"
version = "0.3.18"
version = "0.3.19"
source = { editable = "libs/community" }
dependencies = [
{ name = "aiohttp" },
@@ -2385,8 +2387,7 @@ requires-dist = [
{ name = "langchain", editable = "libs/langchain" },
{ name = "langchain-core", editable = "libs/core" },
{ name = "langsmith", specifier = ">=0.1.125,<0.4" },
{ name = "numpy", marker = "python_full_version < '3.12'", specifier = ">=1.26.4,<2" },
{ name = "numpy", marker = "python_full_version >= '3.12'", specifier = ">=1.26.2,<3" },
{ name = "numpy", specifier = ">=1.26.2,<3" },
{ name = "pydantic-settings", specifier = ">=2.4.0,<3.0.0" },
{ name = "pyyaml", specifier = ">=5.3" },
{ name = "requests", specifier = ">=2,<3" },
@@ -2450,7 +2451,7 @@ typing = [
[[package]]
name = "langchain-core"
version = "0.3.40"
version = "0.3.43"
source = { editable = "libs/core" }
dependencies = [
{ name = "jsonpatch" },
@@ -2573,7 +2574,7 @@ dependencies = [
[[package]]
name = "langchain-groq"
version = "0.2.4"
version = "0.2.5"
source = { editable = "libs/partners/groq" }
dependencies = [
{ name = "groq" },
@@ -2732,7 +2733,7 @@ typing = []
[[package]]
name = "langchain-openai"
version = "0.3.7"
version = "0.3.8"
source = { editable = "libs/partners/openai" }
dependencies = [
{ name = "langchain-core" },
@@ -2743,7 +2744,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "langchain-core", editable = "libs/core" },
{ name = "openai", specifier = ">=1.58.1,<2.0.0" },
{ name = "openai", specifier = ">=1.66.0,<2.0.0" },
{ name = "tiktoken", specifier = ">=0.7,<1" },
]
@@ -3630,7 +3631,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.61.1"
version = "1.66.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -3642,9 +3643,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d9/cf/61e71ce64cf0a38f029da0f9a5f10c9fa0e69a7a977b537126dac50adfea/openai-1.61.1.tar.gz", hash = "sha256:ce1851507218209961f89f3520e06726c0aa7d0512386f0f977e3ac3e4f2472e", size = 350784 }
sdist = { url = "https://files.pythonhosted.org/packages/d8/e1/b3e1fda1aa32d4f40d4de744e91de4de65c854c3e53c63342e4b5f9c5995/openai-1.66.2.tar.gz", hash = "sha256:9b3a843c25f81ee09b6469d483d9fba779d5c6ea41861180772f043481b0598d", size = 397041 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9a/b6/2e2a011b2dc27a6711376808b4cd8c922c476ea0f1420b39892117fa8563/openai-1.61.1-py3-none-any.whl", hash = "sha256:72b0826240ce26026ac2cd17951691f046e5be82ad122d20a8e1b30ca18bd11e", size = 463126 },
{ url = "https://files.pythonhosted.org/packages/2c/6f/3315b3583ffe3e31c55b446cb22d2a7c235e65ca191674fffae62deb3c11/openai-1.66.2-py3-none-any.whl", hash = "sha256:75194057ee6bb8b732526387b6041327a05656d976fc21c064e21c8ac6b07999", size = 567268 },
]
[[package]]