diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 46a720f8da1..56220d9f531 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -2731,6 +2731,31 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         Always use ``extra_body`` for custom parameters, **not** ``model_kwargs``.
         Using ``model_kwargs`` for non-OpenAI parameters will cause API errors.
 
+    .. dropdown:: Prompt caching optimization
+
+        For high-volume applications with repetitive prompts, use ``prompt_cache_key``
+        per-invocation to improve cache hit rates and reduce costs:
+
+        .. code-block:: python
+
+            llm = ChatOpenAI(model="gpt-4o-mini")
+
+            response = llm.invoke(
+                messages,
+                prompt_cache_key="example-key-a",  # Routes to same machine for cache hits
+            )
+
+            customer_response = llm.invoke(messages, prompt_cache_key="example-key-b")
+            support_response = llm.invoke(messages, prompt_cache_key="example-key-c")
+
+            # Dynamic cache keys based on context
+            cache_key = f"example-key-{dynamic_suffix}"
+            response = llm.invoke(messages, prompt_cache_key=cache_key)
+
+        Cache keys help ensure that requests with the same prompt prefix are routed to
+        machines with an existing cache, reducing cost and latency for
+        cached tokens.
+
     """  # noqa: E501
 
     max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 48fa87bbb19..1bc191d418f 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1110,3 +1110,46 @@ def test_tools_and_structured_output() -> None:
     assert isinstance(aggregated["raw"], AIMessage)
     assert aggregated["raw"].tool_calls
     assert aggregated["parsed"] is None
+
+
+@pytest.mark.scheduled
+def test_prompt_cache_key_invoke() -> None:
+    """Test that prompt_cache_key works with invoke calls."""
+    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20)
+    messages = [HumanMessage("Say hello")]
+
+    # Test that invoke works with the prompt_cache_key parameter
+    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
+
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert len(response.content) > 0
+
+    # Test that a subsequent call with the same cache key also works
+    response2 = chat.invoke(messages, prompt_cache_key="integration-test-v1")
+
+    assert isinstance(response2, AIMessage)
+    assert isinstance(response2.content, str)
+    assert len(response2.content) > 0
+
+
+@pytest.mark.scheduled
+def test_prompt_cache_key_usage_methods_integration() -> None:
+    """Integration test for prompt_cache_key usage methods."""
+    messages = [HumanMessage("Say hi")]
+
+    # Test keyword argument method
+    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
+    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+    # Test model-level configuration via model_kwargs
+    chat_model_level = ChatOpenAI(
+        model="gpt-4o-mini",
+        max_completion_tokens=10,
+        model_kwargs={"prompt_cache_key": "integration-model-level-v1"},
+    )
+    response_model_level = chat_model_level.invoke(messages)
+    assert isinstance(response_model_level, AIMessage)
+    assert isinstance(response_model_level.content, str)
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_prompt_cache_key.py b/libs/partners/openai/tests/unit_tests/chat_models/test_prompt_cache_key.py
new file mode 100644
index 00000000000..1f6c8c5d583
--- /dev/null
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_prompt_cache_key.py
@@ -0,0 +1,84 @@
+"""Unit tests for the prompt_cache_key parameter."""
+
+from langchain_core.messages import HumanMessage
+
+from langchain_openai import ChatOpenAI
+
+
+def test_prompt_cache_key_parameter_inclusion() -> None:
+    """Test that prompt_cache_key is properly included in the request payload."""
+    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
+    messages = [HumanMessage("Hello")]
+
+    payload = chat._get_request_payload(messages, prompt_cache_key="test-cache-key")
+    assert "prompt_cache_key" in payload
+    assert payload["prompt_cache_key"] == "test-cache-key"
+
+
+def test_prompt_cache_key_explicit_none() -> None:
+    """Test that an explicit None prompt_cache_key is passed through to the payload."""
+    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
+    messages = [HumanMessage("Hello")]
+
+    # An explicit None is passed through, since the parameter is Optional in the API
+    payload = chat._get_request_payload(messages, prompt_cache_key=None)
+    assert "prompt_cache_key" in payload
+    assert payload["prompt_cache_key"] is None
+
+
+def test_prompt_cache_key_per_call() -> None:
+    """Test that prompt_cache_key can be passed per-call with different values."""
+    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
+    messages = [HumanMessage("Hello")]
+
+    # Test different cache keys per call
+    payload1 = chat._get_request_payload(messages, prompt_cache_key="cache-v1")
+    payload2 = chat._get_request_payload(messages, prompt_cache_key="cache-v2")
+
+    assert payload1["prompt_cache_key"] == "cache-v1"
+    assert payload2["prompt_cache_key"] == "cache-v2"
+
+    # Test dynamic cache key assignment
+    cache_keys = ["customer-v1", "support-v1", "feedback-v1"]
+
+    for cache_key in cache_keys:
+        payload = chat._get_request_payload(messages, prompt_cache_key=cache_key)
+        assert "prompt_cache_key" in payload
+        assert payload["prompt_cache_key"] == cache_key
+
+
+def test_prompt_cache_key_model_kwargs() -> None:
+    """Test prompt_cache_key via model_kwargs and per-call precedence."""
+    messages = [HumanMessage("Hello world")]
+
+    # Test model-level configuration via model_kwargs
+    chat = ChatOpenAI(
+        model="gpt-4o-mini",
+        max_completion_tokens=10,
+        model_kwargs={"prompt_cache_key": "model-level-cache"},
+    )
+    payload = chat._get_request_payload(messages)
+    assert "prompt_cache_key" in payload
+    assert payload["prompt_cache_key"] == "model-level-cache"
+
+    # Test that a per-call cache key overrides the model-level one
+    payload_override = chat._get_request_payload(
+        messages, prompt_cache_key="per-call-cache"
+    )
+    assert payload_override["prompt_cache_key"] == "per-call-cache"
+
+
+def test_prompt_cache_key_responses_api() -> None:
+    """Test that prompt_cache_key works with the Responses API."""
+    chat = ChatOpenAI(
+        model="gpt-4o-mini", use_responses_api=True, max_completion_tokens=10
+    )
+
+    messages = [HumanMessage("Hello")]
+    payload = chat._get_request_payload(
+        messages, prompt_cache_key="responses-api-cache-v1"
+    )
+
+    # prompt_cache_key should be present regardless of API type
+    assert "prompt_cache_key" in payload
+    assert payload["prompt_cache_key"] == "responses-api-cache-v1"
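Taken together, the docstring and tests above cover two ways to set ``prompt_cache_key``: per-call as an ``invoke`` keyword argument, and as a model-level default via ``model_kwargs``. A minimal end-to-end sketch of both, assuming a configured ``OPENAI_API_KEY``; the model name and cache-key values are illustrative:

```python
from langchain_core.messages import HumanMessage

from langchain_openai import ChatOpenAI

messages = [HumanMessage("Say hello")]

# Per-call: pass prompt_cache_key as an invoke kwarg (illustrative key name)
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.invoke(messages, prompt_cache_key="example-key-a")

# Model-level: set a default via model_kwargs; a per-call key still overrides it
llm_default = ChatOpenAI(
    model="gpt-4o-mini",
    model_kwargs={"prompt_cache_key": "example-default-key"},
)
response_default = llm_default.invoke(messages)
```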