mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
openai[patch]: support Responses streaming in AzureChatOpenAI (#31641)
Resolves https://github.com/langchain-ai/langchain/issues/31303, https://github.com/langchain-ai/langchain/issues/31624
This commit is contained in:
parent
bc1b5ffc91
commit
b610859633
@ -4,14 +4,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Awaitable
|
from collections.abc import AsyncIterator, Awaitable, Iterator
|
||||||
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
|
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from langchain_core.language_models import LanguageModelInput
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import LangSmithParams
|
from langchain_core.language_models.chat_models import LangSmithParams
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.outputs import ChatResult
|
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.utils import from_env, secret_from_env
|
from langchain_core.utils import from_env, secret_from_env
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
@ -736,6 +736,24 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
|
|
||||||
return chat_result
|
return chat_result
|
||||||
|
|
||||||
|
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
|
||||||
|
"""Route to Chat Completions or Responses API."""
|
||||||
|
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||||
|
return super()._stream_responses(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super()._stream(*args, **kwargs)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
"""Route to Chat Completions or Responses API."""
|
||||||
|
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||||
|
async for chunk in super()._astream_responses(*args, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
async for chunk in super()._astream(*args, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
def with_structured_output(
|
def with_structured_output(
|
||||||
self,
|
self,
|
||||||
schema: Optional[_DictOrPydanticClass] = None,
|
schema: Optional[_DictOrPydanticClass] = None,
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||||
|
|
||||||
@ -39,6 +40,38 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureOpenAIResponses(ChatModelIntegrationTests):
|
||||||
|
@property
|
||||||
|
def chat_model_class(self) -> type[BaseChatModel]:
|
||||||
|
return AzureChatOpenAI
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_model_params(self) -> dict:
|
||||||
|
return {
|
||||||
|
"deployment_name": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"openai_api_version": OPENAI_API_VERSION,
|
||||||
|
"azure_endpoint": OPENAI_API_BASE,
|
||||||
|
"use_responses_api": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_image_inputs(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_image_urls(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_json_mode(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Unsupported.")
|
||||||
|
def test_stop_sequence(self, model: BaseChatModel) -> None:
|
||||||
|
super().test_stop_sequence(model)
|
||||||
|
|
||||||
|
|
||||||
class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
|
class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
|
||||||
"""Test a legacy model."""
|
"""Test a legacy model."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user