Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-12 14:23:58 +00:00)

test(openai): add tests for prompt_cache_key parameter and update docs (#32363)

Introduce tests to validate the behavior and inclusion of the `prompt_cache_key` parameter in request payloads for the `ChatOpenAI` model.

parent 68c70da33e
commit 145d38f7dd
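
For context, a minimal sketch (not part of the diff) of the payload check the new unit tests perform; it uses the private `_get_request_payload` helper, so this is an internal detail that may change:

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

chat = ChatOpenAI(model="gpt-4o-mini")
# Build the request payload without calling the API and confirm the
# per-call cache key is forwarded as-is.
payload = chat._get_request_payload(
    [HumanMessage("Hello")], prompt_cache_key="example-key"
)
assert payload["prompt_cache_key"] == "example-key"
```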
@@ -2731,6 +2731,31 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
        Always use ``extra_body`` for custom parameters, **not** ``model_kwargs``.
        Using ``model_kwargs`` for non-OpenAI parameters will cause API errors.

    .. dropdown:: Prompt caching optimization

        For high-volume applications with repetitive prompts, use ``prompt_cache_key``
        per-invocation to improve cache hit rates and reduce costs:

        .. code-block:: python

            llm = ChatOpenAI(model="gpt-4o-mini")

            response = llm.invoke(
                messages,
                prompt_cache_key="example-key-a",  # Routes to same machine for cache hits
            )

            customer_response = llm.invoke(messages, prompt_cache_key="example-key-b")
            support_response = llm.invoke(messages, prompt_cache_key="example-key-c")

            # Dynamic cache keys based on context
            cache_key = f"example-key-{dynamic_suffix}"
            response = llm.invoke(messages, prompt_cache_key=cache_key)

        Cache keys help ensure requests with the same prompt prefix are routed to
        machines with existing cache, providing cost reduction and latency improvement on
        cached tokens.

    """  # noqa: E501

    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
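
As a usage note beyond the docstring example: a hedged sketch of how a cache hit can be observed after passing `prompt_cache_key`. It assumes LangChain's standard `usage_metadata` reporting (with an `input_token_details["cache_read"]` count) and a prompt long enough for OpenAI's prompt caching to apply; short prompts will typically report 0 cached tokens.

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")
# A long, stable prefix is what makes prompt caching worthwhile;
# OpenAI only caches prompts above a minimum token length.
messages = [HumanMessage("<long shared instructions> ... Question: what changed?")]

first = llm.invoke(messages, prompt_cache_key="support-flow-v1")
second = llm.invoke(messages, prompt_cache_key="support-flow-v1")

# usage_metadata is LangChain's standard usage field; cache_read is the
# number of input tokens served from the provider-side prompt cache.
details = (second.usage_metadata or {}).get("input_token_details", {})
print("cached input tokens:", details.get("cache_read", 0))
```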
@@ -1110,3 +1110,46 @@ def test_tools_and_structured_output() -> None:
    assert isinstance(aggregated["raw"], AIMessage)
    assert aggregated["raw"].tool_calls
    assert aggregated["parsed"] is None


@pytest.mark.scheduled
def test_prompt_cache_key_invoke() -> None:
    """Test that prompt_cache_key works with invoke calls."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20)
    messages = [HumanMessage("Say hello")]

    # Test that invoke works with prompt_cache_key parameter
    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")

    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert len(response.content) > 0

    # Test that subsequent call with same cache key also works
    response2 = chat.invoke(messages, prompt_cache_key="integration-test-v1")

    assert isinstance(response2, AIMessage)
    assert isinstance(response2.content, str)
    assert len(response2.content) > 0


@pytest.mark.scheduled
def test_prompt_cache_key_usage_methods_integration() -> None:
    """Integration test for prompt_cache_key usage methods."""
    messages = [HumanMessage("Say hi")]

    # Test keyword argument method
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)

    # Test model-level via model_kwargs
    chat_model_level = ChatOpenAI(
        model="gpt-4o-mini",
        max_completion_tokens=10,
        model_kwargs={"prompt_cache_key": "integration-model-level-v1"},
    )
    response_model_level = chat_model_level.invoke(messages)
    assert isinstance(response_model_level, AIMessage)
    assert isinstance(response_model_level.content, str)
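
The integration tests cover two ways of supplying the key (per-call keyword argument and `model_kwargs`). A third option, not covered by these tests, is binding the key once so every downstream call reuses it; this sketch assumes the standard `Runnable.bind` keyword forwarding rather than anything specific to this diff:

```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

# Bound kwargs are forwarded to each underlying chat-completions call,
# so the cache key travels with the runnable wherever it is composed.
cached_llm = ChatOpenAI(model="gpt-4o-mini").bind(prompt_cache_key="support-flow-v1")

response = cached_llm.invoke([HumanMessage("Say hi")])
print(response.content)
```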
@@ -0,0 +1,84 @@
"""Unit tests for prompt_cache_key parameter."""

from langchain_core.messages import HumanMessage

from langchain_openai import ChatOpenAI


def test_prompt_cache_key_parameter_inclusion() -> None:
    """Test that prompt_cache_key parameter is properly included in request payload."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    payload = chat._get_request_payload(messages, prompt_cache_key="test-cache-key")
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "test-cache-key"


def test_prompt_cache_key_parameter_exclusion() -> None:
    """Test that prompt_cache_key parameter behavior matches OpenAI API."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    # Test with explicit None (the parameter is Optional, so the API accepts None)
    payload = chat._get_request_payload(messages, prompt_cache_key=None)
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] is None


def test_prompt_cache_key_per_call() -> None:
    """Test that prompt_cache_key can be passed per-call with different values."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    # Test different cache keys per call
    payload1 = chat._get_request_payload(messages, prompt_cache_key="cache-v1")
    payload2 = chat._get_request_payload(messages, prompt_cache_key="cache-v2")

    assert payload1["prompt_cache_key"] == "cache-v1"
    assert payload2["prompt_cache_key"] == "cache-v2"

    # Test dynamic cache key assignment
    cache_keys = ["customer-v1", "support-v1", "feedback-v1"]

    for cache_key in cache_keys:
        payload = chat._get_request_payload(messages, prompt_cache_key=cache_key)
        assert "prompt_cache_key" in payload
        assert payload["prompt_cache_key"] == cache_key


def test_prompt_cache_key_model_kwargs() -> None:
    """Test prompt_cache_key via model_kwargs and method precedence."""
    messages = [HumanMessage("Hello world")]

    # Test model-level via model_kwargs
    chat = ChatOpenAI(
        model="gpt-4o-mini",
        max_completion_tokens=10,
        model_kwargs={"prompt_cache_key": "model-level-cache"},
    )
    payload = chat._get_request_payload(messages)
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "model-level-cache"

    # Test that per-call cache key overrides model-level
    payload_override = chat._get_request_payload(
        messages, prompt_cache_key="per-call-cache"
    )
    assert payload_override["prompt_cache_key"] == "per-call-cache"


def test_prompt_cache_key_responses_api() -> None:
    """Test that prompt_cache_key works with Responses API."""
    chat = ChatOpenAI(
        model="gpt-4o-mini", use_responses_api=True, max_completion_tokens=10
    )

    messages = [HumanMessage("Hello")]
    payload = chat._get_request_payload(
        messages, prompt_cache_key="responses-api-cache-v1"
    )

    # prompt_cache_key should be present regardless of API type
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "responses-api-cache-v1"