openai[patch]: fix json mode for Azure (#25488)

https://github.com/langchain-ai/langchain/issues/25479
https://github.com/langchain-ai/langchain/issues/25485

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
ccurme 2024-08-16 12:50:50 -04:00 committed by GitHub
parent 1fd1c1dca5
commit 01ecd0acba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 81 additions and 5 deletions

View File

@ -640,14 +640,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
}
if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
values["client"] = openai.AzureOpenAI(
**client_params, **sync_specific
).chat.completions
values["root_client"] = openai.AzureOpenAI(**client_params, **sync_specific)
values["client"] = values["root_client"].chat.completions
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncAzureOpenAI(
values["root_async_client"] = openai.AsyncAzureOpenAI(
**client_params, **async_specific
).chat.completions
)
values["async_client"] = values["root_async_client"].chat.completions
return values
def bind_tools(

View File

@ -1,5 +1,6 @@
"""Test AzureChatOpenAI wrapper."""
import json
import os
from typing import Any, Optional
@ -225,3 +226,39 @@ def test_openai_invoke(llm: AzureChatOpenAI) -> None:
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
assert result.response_metadata.get("model_name") is not None
def test_json_mode(llm: AzureChatOpenAI) -> None:
response = llm.invoke(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
)
assert isinstance(response.content, str)
assert json.loads(response.content) == {"a": 1}
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, str)
assert json.loads(full.content) == {"a": 1}
async def test_json_mode_async(llm: AzureChatOpenAI) -> None:
response = await llm.ainvoke(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
)
assert isinstance(response.content, str)
assert json.loads(response.content) == {"a": 1}
# Test streaming
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, str)
assert json.loads(full.content) == {"a": 1}

View File

@ -1,6 +1,7 @@
"""Test ChatOpenAI chat model."""
import base64
import json
from typing import Any, AsyncIterator, List, Literal, Optional, cast
import httpx
@ -865,3 +866,41 @@ def test_structured_output_strict(
chat.invoke("Tell me a joke about cats.")
with pytest.raises(openai.BadRequestError):
next(chat.stream("Tell me a joke about cats."))
def test_json_mode() -> None:
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
response = llm.invoke(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
)
assert isinstance(response.content, str)
assert json.loads(response.content) == {"a": 1}
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, str)
assert json.loads(full.content) == {"a": 1}
async def test_json_mode_async() -> None:
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
response = await llm.ainvoke(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
)
assert isinstance(response.content, str)
assert json.loads(response.content) == {"a": 1}
# Test streaming
full: Optional[BaseMessageChunk] = None
async for chunk in llm.astream(
"Return this as json: {'a': 1}", response_format={"type": "json_object"}
):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, str)
assert json.loads(full.content) == {"a": 1}