test(openai): add tests for prompt_cache_key parameter and update docs (#32363)

Introduce tests that validate the behavior of the `prompt_cache_key`
parameter and its inclusion in request payloads for the `ChatOpenAI`
model, and update the class docstring with prompt-caching guidance.
Mason Daugherty 2025-08-07 15:29:47 -04:00 committed by GitHub
parent 68c70da33e
commit 145d38f7dd
3 changed files with 152 additions and 0 deletions


@@ -2731,6 +2731,31 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
    Always use ``extra_body`` for custom parameters, **not** ``model_kwargs``.
    Using ``model_kwargs`` for non-OpenAI parameters will cause API errors.

    .. dropdown:: Prompt caching optimization

        For high-volume applications with repetitive prompts, pass ``prompt_cache_key``
        per invocation to improve cache hit rates and reduce costs:

        .. code-block:: python

            llm = ChatOpenAI(model="gpt-4o-mini")

            response = llm.invoke(
                messages,
                prompt_cache_key="example-key-a",  # Routes to same machine for cache hits
            )

            customer_response = llm.invoke(messages, prompt_cache_key="example-key-b")
            support_response = llm.invoke(messages, prompt_cache_key="example-key-c")

            # Dynamic cache keys based on context
            cache_key = f"example-key-{dynamic_suffix}"
            response = llm.invoke(messages, prompt_cache_key=cache_key)

        Cache keys help ensure that requests with the same prompt prefix are
        routed to machines with an existing cache, reducing cost and latency
        for cached tokens.

    """  # noqa: E501

    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
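
The ``extra_body``-versus-``model_kwargs`` guidance in the docstring is worth a concrete illustration. Below is a minimal sketch, not part of this commit: ``prompt_cache_key`` is a real OpenAI parameter and so is passed per call, while ``use_beam_search`` and the local base URL are hypothetical stand-ins for an OpenAI-compatible backend with custom parameters.

    from langchain_openai import ChatOpenAI

    # Custom, non-OpenAI parameters belong in `extra_body`, which is sent
    # verbatim in the request body. `use_beam_search` and the base URL here
    # are hypothetical stand-ins for an OpenAI-compatible backend.
    proxy_llm = ChatOpenAI(
        model="gpt-4o-mini",
        base_url="http://localhost:8000/v1",
        extra_body={"use_beam_search": True},
    )

    # `model_kwargs` entries are passed as top-level arguments to the OpenAI
    # client, so unknown names are rejected by the API:
    # ChatOpenAI(model="gpt-4o-mini", model_kwargs={"use_beam_search": True})

    # `prompt_cache_key` is a genuine top-level OpenAI parameter, so it can
    # be passed per invocation, as the tests below exercise.
    llm = ChatOpenAI(model="gpt-4o-mini")
    response = llm.invoke("Say hello", prompt_cache_key="example-key-a")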


@@ -1110,3 +1110,46 @@ def test_tools_and_structured_output() -> None:
    assert isinstance(aggregated["raw"], AIMessage)
    assert aggregated["raw"].tool_calls
    assert aggregated["parsed"] is None


@pytest.mark.scheduled
def test_prompt_cache_key_invoke() -> None:
    """Test that prompt_cache_key works with invoke calls."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20)
    messages = [HumanMessage("Say hello")]

    # Test that invoke works with the prompt_cache_key parameter
    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert len(response.content) > 0

    # Test that a subsequent call with the same cache key also works
    response2 = chat.invoke(messages, prompt_cache_key="integration-test-v1")
    assert isinstance(response2, AIMessage)
    assert isinstance(response2.content, str)
    assert len(response2.content) > 0


@pytest.mark.scheduled
def test_prompt_cache_key_usage_methods_integration() -> None:
    """Integration test for prompt_cache_key usage methods."""
    messages = [HumanMessage("Say hi")]

    # Test the keyword-argument method
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)

    # Test model-level configuration via model_kwargs
    chat_model_level = ChatOpenAI(
        model="gpt-4o-mini",
        max_completion_tokens=10,
        model_kwargs={"prompt_cache_key": "integration-model-level-v1"},
    )
    response_model_level = chat_model_level.invoke(messages)
    assert isinstance(response_model_level, AIMessage)
    assert isinstance(response_model_level.content, str)
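
A plausible follow-up check, not part of this commit: langchain-core exposes cached-token counts on ``AIMessage.usage_metadata``, which could be used to observe actual cache hits. A sketch, assuming the provider reports cached tokens under ``input_token_details["cache_read"]`` and that the shared prefix exceeds OpenAI's minimum cacheable length (roughly 1024 tokens); otherwise the count stays 0.

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    chat = ChatOpenAI(model="gpt-4o-mini")
    messages = [HumanMessage("Say hello")]

    # The first call warms the cache for this key; a second identical call
    # may then report cached input tokens.
    chat.invoke(messages, prompt_cache_key="cache-probe-v1")
    second = chat.invoke(messages, prompt_cache_key="cache-probe-v1")

    if second.usage_metadata is not None:
        details = second.usage_metadata.get("input_token_details", {})
        print("cached input tokens:", details.get("cache_read", 0))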


@@ -0,0 +1,84 @@
"""Unit tests for the prompt_cache_key parameter."""

from langchain_core.messages import HumanMessage

from langchain_openai import ChatOpenAI


def test_prompt_cache_key_parameter_inclusion() -> None:
    """Test that prompt_cache_key is properly included in the request payload."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    payload = chat._get_request_payload(messages, prompt_cache_key="test-cache-key")
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "test-cache-key"


def test_prompt_cache_key_parameter_exclusion() -> None:
    """Test that prompt_cache_key behavior matches the OpenAI API."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    # Test with explicit None (OpenAI accepts None values; the field is Optional)
    payload = chat._get_request_payload(messages, prompt_cache_key=None)
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] is None


def test_prompt_cache_key_per_call() -> None:
    """Test that prompt_cache_key can be passed per call with different values."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    # Test different cache keys per call
    payload1 = chat._get_request_payload(messages, prompt_cache_key="cache-v1")
    payload2 = chat._get_request_payload(messages, prompt_cache_key="cache-v2")
    assert payload1["prompt_cache_key"] == "cache-v1"
    assert payload2["prompt_cache_key"] == "cache-v2"

    # Test dynamic cache key assignment
    cache_keys = ["customer-v1", "support-v1", "feedback-v1"]
    for cache_key in cache_keys:
        payload = chat._get_request_payload(messages, prompt_cache_key=cache_key)
        assert "prompt_cache_key" in payload
        assert payload["prompt_cache_key"] == cache_key


def test_prompt_cache_key_model_kwargs() -> None:
    """Test prompt_cache_key via model_kwargs and method precedence."""
    messages = [HumanMessage("Hello world")]

    # Test model-level configuration via model_kwargs
    chat = ChatOpenAI(
        model="gpt-4o-mini",
        max_completion_tokens=10,
        model_kwargs={"prompt_cache_key": "model-level-cache"},
    )
    payload = chat._get_request_payload(messages)
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "model-level-cache"

    # Test that a per-call cache key overrides the model-level one
    payload_override = chat._get_request_payload(
        messages, prompt_cache_key="per-call-cache"
    )
    assert payload_override["prompt_cache_key"] == "per-call-cache"


def test_prompt_cache_key_responses_api() -> None:
    """Test that prompt_cache_key works with the Responses API."""
    chat = ChatOpenAI(
        model="gpt-4o-mini", use_responses_api=True, max_completion_tokens=10
    )
    messages = [HumanMessage("Hello")]

    payload = chat._get_request_payload(
        messages, prompt_cache_key="responses-api-cache-v1"
    )

    # prompt_cache_key should be present regardless of API type
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "responses-api-cache-v1"