fix(openai): tolerate prompt_cache_retention drift in streaming (#36925)

Mason Daugherty
2026-04-21 14:54:32 -04:00
committed by GitHub
parent acc54987fa
commit 488c6a73bb
2 changed files with 109 additions and 15 deletions


@@ -123,6 +123,7 @@ from pydantic import (
     ConfigDict,
     Field,
     SecretStr,
+    ValidationError,
     model_validator,
 )
 from pydantic.v1 import BaseModel as BaseModelV1
@@ -4594,7 +4595,28 @@ def _coerce_chunk_response(resp: Any) -> Any:
     if isinstance(resp, dict):
         from openai.types.responses import Response

-        return Response.model_validate(resp)
+        # Known mismatch: the API emits `prompt_cache_retention="in_memory"`
+        # while older `openai` packages declare only `"in-memory"` in the
+        # Literal (openai-python#2883). Pre-normalize so validation succeeds
+        # on currently released SDK versions.
+        if resp.get("prompt_cache_retention") == "in_memory":
+            resp = {**resp, "prompt_cache_retention": "in-memory"}
+        try:
+            return Response.model_validate(resp)
+        except ValidationError as e:
+            # The API sometimes drifts ahead of the installed SDK's Literal
+            # declarations. Fall back to a non-validating construct so
+            # streams still complete, and surface the drift so operators
+            # can upgrade.
+            logger.warning(
+                "OpenAI Responses payload failed SDK validation "
+                "(response id=%s); falling back to non-validating construct. "
+                "This usually means the OpenAI API has drifted ahead of the "
+                "installed `openai` package. Details: %s",
+                resp.get("id"),
+                e,
+            )
+            return Response.model_construct(**resp)
     return resp
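
To see why the fallback in the hunk above works, here is a minimal, self-contained sketch of the two pydantic behaviors it combines: strict `model_validate` rejecting a drifted Literal value, and non-validating `model_construct` keeping the stream alive. `FakeResponse` and its Literal options are illustrative stand-ins, not the SDK's `openai.types.responses.Response`.

# Stand-in model for the Literal-drift problem; `FakeResponse` is
# illustrative only, not the real SDK class.
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class FakeResponse(BaseModel):
    id: str
    prompt_cache_retention: Optional[Literal["in-memory", "24h"]] = None


payload = {"id": "resp_123", "prompt_cache_retention": "in_memory"}

# Step 1: known drift is pre-normalized, so strict validation succeeds.
if payload.get("prompt_cache_retention") == "in_memory":
    payload = {**payload, "prompt_cache_retention": "in-memory"}
assert FakeResponse.model_validate(payload).prompt_cache_retention == "in-memory"

# Step 2: unknown drift still fails strict validation...
unknown = {"id": "resp_456", "prompt_cache_retention": "72h"}
try:
    FakeResponse.model_validate(unknown)
except ValidationError:
    # ...so fall back to a non-validating construct, which preserves the
    # drifted value as-is instead of aborting the stream.
    resp = FakeResponse.model_construct(**unknown)
    assert resp.prompt_cache_retention == "72h"

The remaining hunks below update the unit-test file accordingly.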


@@ -47,6 +47,8 @@ from openai.types.shared.response_format_text import ResponseFormatText
 from langchain_openai import ChatOpenAI
 from tests.unit_tests.chat_models.test_base import MockSyncContextManager

+MODEL = "gpt-5.4"
+
 responses_stream = [
     ResponseCreatedEvent(
         response=Response(
@@ -56,7 +58,7 @@ responses_stream = [
             incomplete_details=None,
             instructions=None,
             metadata={},
-            model="o4-mini-2025-04-16",
+            model=MODEL,
             object="response",
             output=[],
             parallel_tool_calls=True,
@@ -88,7 +90,7 @@ responses_stream = [
             incomplete_details=None,
             instructions=None,
             metadata={},
-            model="o4-mini-2025-04-16",
+            model=MODEL,
             object="response",
             output=[],
             parallel_tool_calls=True,
@@ -534,7 +536,7 @@ responses_stream = [
             incomplete_details=None,
             instructions=None,
             metadata={},
-            model="o4-mini-2025-04-16",
+            model=MODEL,
             object="response",
             output=[
                 ResponseReasoningItem(
@@ -722,9 +724,7 @@ def _strip_none(obj: Any) -> Any:
     ],
 )
 def test_responses_stream(output_version: str, expected_content: list[dict]) -> None:
-    llm = ChatOpenAI(
-        model="o4-mini", use_responses_api=True, output_version=output_version
-    )
+    llm = ChatOpenAI(model=MODEL, use_responses_api=True, output_version=output_version)
     mock_client = MagicMock()

     def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
@@ -773,7 +773,7 @@ def test_responses_stream_with_image_generation_multiple_calls() -> None:
         {"type": "function", "name": "my_tool", "parameters": {}},
     ]
     llm = ChatOpenAI(
-        model="gpt-4o",
+        model=MODEL,
         use_responses_api=True,
         streaming=True,
     )
@@ -808,7 +808,7 @@ def test_responses_stream_function_call_preserves_namespace() -> None:
                 incomplete_details=None,
                 instructions=None,
                 metadata={},
-                model="gpt-4o-2025-01-01",
+                model=MODEL,
                 object="response",
                 output=[],
                 parallel_tool_calls=True,
@@ -838,7 +838,7 @@ def test_responses_stream_function_call_preserves_namespace() -> None:
                 incomplete_details=None,
                 instructions=None,
                 metadata={},
-                model="gpt-4o-2025-01-01",
+                model=MODEL,
                 object="response",
                 output=[],
                 parallel_tool_calls=True,
@@ -918,7 +918,7 @@ def test_responses_stream_function_call_preserves_namespace() -> None:
                 incomplete_details=None,
                 instructions=None,
                 metadata={},
-                model="gpt-4o-2025-01-01",
+                model=MODEL,
                 object="response",
                 output=[
                     ResponseFunctionToolCallItem(
@@ -958,9 +958,7 @@ def test_responses_stream_function_call_preserves_namespace() -> None:
         ),
     ]

-    llm = ChatOpenAI(
-        model="gpt-4o", use_responses_api=True, output_version="responses/v1"
-    )
+    llm = ChatOpenAI(model=MODEL, use_responses_api=True, output_version="responses/v1")
     mock_client = MagicMock()

     def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
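
The tests in the final hunk below all share one harness: `responses.create` is patched to return a context manager over a canned event list, so `llm.stream` exercises the coercion path with no network I/O. A minimal sketch of that pattern, assuming only that the real `MockSyncContextManager` (imported from tests/unit_tests/chat_models/test_base) behaves like this hypothetical stand-in:

# Hypothetical stand-in for MockSyncContextManager; assumed here to wrap a
# list of stream events as a sync context manager that iterates them.
from typing import Any, Iterator


class FakeStreamManager:
    def __init__(self, events: list[Any]) -> None:
        self._events = events

    def __enter__(self) -> Iterator[Any]:
        # Entering the "stream" yields the canned events one by one.
        return iter(self._events)

    def __exit__(self, *exc_info: Any) -> None:
        return None

With `mock_client.responses.create` returning such an object, each `for chunk in llm.stream("test")` loop replays the canned events through `_coerce_chunk_response` much as a live stream would.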
@@ -1001,7 +999,81 @@ def test_responses_stream_tolerates_dict_response_field() -> None:
     first_event.response = first_event.response.model_dump(mode="json")  # type: ignore[assignment]
     assert isinstance(first_event.response, dict)

-    llm = ChatOpenAI(model="o4-mini", use_responses_api=True)
+    llm = ChatOpenAI(model=MODEL, use_responses_api=True)
     mock_client = MagicMock()

     def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
         return MockSyncContextManager(stream)

     mock_client.responses.create = mock_create
     full: BaseMessageChunk | None = None
     with patch.object(llm, "root_client", mock_client):
         for chunk in llm.stream("test"):
             assert isinstance(chunk, AIMessageChunk)
             full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessageChunk)
     assert full.id == "resp_123"
+
+
+@pytest.mark.parametrize(
+    ("event_index", "event_type"),
+    [(0, ResponseCreatedEvent), (46, ResponseCompletedEvent)],
+)
+def test_responses_stream_normalizes_in_memory_prompt_cache_retention(
+    event_index: int, event_type: type
+) -> None:
+    """`prompt_cache_retention="in_memory"` from the API must not abort streams.
+
+    The API emits the underscore form, while older `openai` packages declare
+    only `"in-memory"` in the Literal (openai-python#2883).
+    `_coerce_chunk_response` should normalize it so that both the
+    `response.created` and `response.completed` handlers validate successfully.
+    """
+    stream = copy.deepcopy(responses_stream)
+    target = stream[event_index]
+    assert isinstance(target, event_type)
+    assert isinstance(target, (ResponseCreatedEvent, ResponseCompletedEvent))
+    dumped = target.response.model_dump(mode="json")
+    dumped["prompt_cache_retention"] = "in_memory"
+    target.response = dumped  # type: ignore[assignment]
+
+    llm = ChatOpenAI(model=MODEL, use_responses_api=True)
+    mock_client = MagicMock()
+
+    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
+        return MockSyncContextManager(stream)
+
+    mock_client.responses.create = mock_create
+    full: BaseMessageChunk | None = None
+    with patch.object(llm, "root_client", mock_client):
+        for chunk in llm.stream("test"):
+            assert isinstance(chunk, AIMessageChunk)
+            full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.id == "resp_123"
+    # The completed event drives usage/metadata aggregation, so assert that
+    # it survived coercion when that branch is exercised.
+    if event_type is ResponseCompletedEvent:
+        assert full.usage_metadata is not None
+
+
+def test_responses_stream_tolerates_unknown_literal_drift() -> None:
+    """API drift ahead of SDK Literal declarations must not abort streams.
+
+    When the API returns a value the installed SDK's Literal does not know
+    about, `_coerce_chunk_response` should fall back to a non-validating
+    construct so streaming still completes.
+    """
+    stream = copy.deepcopy(responses_stream)
+    first_event = stream[0]
+    assert isinstance(first_event, ResponseCreatedEvent)
+    dumped = first_event.response.model_dump(mode="json")
+    dumped["status"] = "something_new"
+    first_event.response = dumped  # type: ignore[assignment]
+
+    llm = ChatOpenAI(model=MODEL, use_responses_api=True)
+    mock_client = MagicMock()
+
+    def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: