From 01ecd0acba22f8bf84c56a603c3e277ce5bc7ae6 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Fri, 16 Aug 2024 12:50:50 -0400
Subject: [PATCH] openai[patch]: fix json mode for Azure (#25488)

https://github.com/langchain-ai/langchain/issues/25479
https://github.com/langchain-ai/langchain/issues/25485

---------

Co-authored-by: Bagatur
---
 .../langchain_openai/chat_models/azure.py     | 10 ++---
 .../chat_models/test_azure.py                 | 37 ++++++++++++++++++
 .../chat_models/test_base.py                  | 39 +++++++++++++++++++
 3 files changed, 81 insertions(+), 5 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py
index eaf31a56a33..1882ad147dc 100644
--- a/libs/partners/openai/langchain_openai/chat_models/azure.py
+++ b/libs/partners/openai/langchain_openai/chat_models/azure.py
@@ -640,14 +640,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
         }
         if not values.get("client"):
             sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.AzureOpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
+            values["root_client"] = openai.AzureOpenAI(**client_params, **sync_specific)
+            values["client"] = values["root_client"].chat.completions
         if not values.get("async_client"):
             async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncAzureOpenAI(
+            values["root_async_client"] = openai.AsyncAzureOpenAI(
                 **client_params, **async_specific
-            ).chat.completions
+            )
+            values["async_client"] = values["root_async_client"].chat.completions
         return values
 
     def bind_tools(
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
index 3436f165d7c..b62e9dfdd1a 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
@@ -1,5 +1,6 @@
 """Test AzureChatOpenAI wrapper."""
 
+import json
 import os
 from typing import Any, Optional
 
@@ -225,3 +226,39 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
     assert result.response_metadata.get("model_name") is not None
+
+
+def test_json_mode(llm: AzureChatOpenAI) -> None:
+    response = llm.invoke(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    )
+    assert isinstance(response.content, str)
+    assert json.loads(response.content) == {"a": 1}
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    ):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert json.loads(full.content) == {"a": 1}
+
+
+async def test_json_mode_async(llm: AzureChatOpenAI) -> None:
+    response = await llm.ainvoke(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    )
+    assert isinstance(response.content, str)
+    assert json.loads(response.content) == {"a": 1}
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    ):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert json.loads(full.content) == {"a": 1}
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 3e03755b765..f1eb6c39a5f 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1,6 +1,7 @@
 """Test ChatOpenAI chat model."""
 
 import base64
+import json
 from typing import Any, AsyncIterator, List, Literal, Optional, cast
 
 import httpx
@@ -865,3 +866,41 @@ def test_structured_output_strict(
         chat.invoke("Tell me a joke about cats.")
     with pytest.raises(openai.BadRequestError):
         next(chat.stream("Tell me a joke about cats."))
+
+
+def test_json_mode() -> None:
+    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+    response = llm.invoke(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    )
+    assert isinstance(response.content, str)
+    assert json.loads(response.content) == {"a": 1}
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    ):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert json.loads(full.content) == {"a": 1}
+
+
+async def test_json_mode_async() -> None:
+    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+    response = await llm.ainvoke(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    )
+    assert isinstance(response.content, str)
+    assert json.loads(response.content) == {"a": 1}
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream(
+        "Return this as json: {'a': 1}", response_format={"type": "json_object"}
+    ):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert isinstance(full.content, str)
+    assert json.loads(full.content) == {"a": 1}
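For context, a minimal usage sketch of the json-mode path these tests exercise, mirroring the integration tests above. The deployment name, API version, and environment variable names below are placeholder assumptions for illustration, not part of this patch:

```python
# Sketch of json mode on AzureChatOpenAI after this fix (not part of the patch).
# Assumes AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are set in the
# environment; "my-deployment" and "2024-06-01" are hypothetical placeholders
# for a real deployment name and API version.
import json

from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(azure_deployment="my-deployment", api_version="2024-06-01")

# response_format is forwarded to the underlying Azure chat completions client.
# Before this patch, AzureChatOpenAI set only `client`/`async_client` and left
# `root_client`/`root_async_client` unset, so base-class code paths that rely
# on the root client objects failed on Azure (see the linked issues).
response = llm.invoke(
    "Return this as json: {'a': 1}", response_format={"type": "json_object"}
)
print(json.loads(response.content))  # expected: {'a': 1}
```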