diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 142e7eca1a8..f4e26253484 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -318,8 +318,14 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
     chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
 ) -> Optional[ChatGenerationChunk]:
+    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+        return None
     token_usage = chunk.get("usage")
-    choices = chunk.get("choices", [])
+    choices = (
+        chunk.get("choices", [])
+        # from beta.chat.completions.stream
+        or chunk.get("chunk", {}).get("choices", [])
+    )
     usage_metadata: Optional[UsageMetadata] = (
         _create_usage_metadata(token_usage) if token_usage else None
@@ -660,13 +666,24 @@ class BaseChatOpenAI(BaseChatModel):
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_client.beta.chat.completions.stream(**payload)
+            context_manager = response_stream
         else:
-            response = self.client.create(**payload)
-        with response:
+            if self.include_response_headers:
+                raw_response = self.client.with_raw_response.create(**payload)
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = self.client.create(**payload)
+            context_manager = response
+        with context_manager as response:
             is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -686,6 +703,16 @@ class BaseChatOpenAI(BaseChatModel):
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     def _generate(
         self,
@@ -794,13 +821,29 @@ class BaseChatOpenAI(BaseChatModel):
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_async_client.beta.chat.completions.stream(
+                **payload
+            )
+            context_manager = response_stream
         else:
-            response = await self.async_client.create(**payload)
-        async with response:
+            if self.include_response_headers:
+                raw_response = await self.async_client.with_raw_response.create(
+                    **payload
+                )
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = await self.async_client.create(**payload)
+            context_manager = response
+        async with context_manager as response:
             is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
@@ -820,6 +863,16 @@ class BaseChatOpenAI(BaseChatModel):
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = await response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     async def _agenerate(
         self,
@@ -1010,25 +1063,6 @@ class BaseChatOpenAI(BaseChatModel):
             num_tokens += 3
         return num_tokens
 
-    def _should_stream(
-        self,
-        *,
-        async_api: bool,
-        run_manager: Optional[
-            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
-        ] = None,
-        response_format: Optional[Union[dict, type]] = None,
-        **kwargs: Any,
-    ) -> bool:
-        if isinstance(response_format, type) and is_basemodel_subclass(response_format):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            return False
-
-        return super()._should_stream(
-            async_api=async_api, run_manager=run_manager, **kwargs
-        )
-
     @deprecated(
         since="0.2.1",
         alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools",
@@ -1531,6 +1565,25 @@ class BaseChatOpenAI(BaseChatModel):
                 filtered[k] = v
         return filtered
 
+    def _get_generation_chunk_from_completion(
+        self, completion: openai.BaseModel
+    ) -> ChatGenerationChunk:
+        """Get chunk from completion (e.g., from final completion of a stream)."""
+        chat_result = self._create_chat_result(completion)
+        chat_message = chat_result.generations[0].message
+        if isinstance(chat_message, AIMessage):
+            usage_metadata = chat_message.usage_metadata
+        else:
+            usage_metadata = None
+        message = AIMessageChunk(
+            content="",
+            additional_kwargs=chat_message.additional_kwargs,
+            usage_metadata=usage_metadata,
+        )
+        return ChatGenerationChunk(
+            message=message, generation_info=chat_result.llm_output
+        )
+
 
 class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
index 4ed531ad119..e9228df0730 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py
@@ -13,6 +13,7 @@ from langchain_core.messages import (
     HumanMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from pydantic import BaseModel
 
 from langchain_openai import AzureChatOpenAI
 from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -262,3 +263,37 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None:
     assert isinstance(full, AIMessageChunk)
     assert isinstance(full.content, str)
     assert json.loads(full.content) == {"a": 1}
+
+
+class Foo(BaseModel):
+    response: str
+
+
+def test_stream_response_format(llm: AzureChatOpenAI) -> None:
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    for chunk in llm.stream("how are ya", response_format=Foo):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
+
+
+async def test_astream_response_format(llm: AzureChatOpenAI) -> None:
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    async for chunk in llm.astream("how are ya", response_format=Foo):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 93c08ce2141..506799aef4b 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -1092,14 +1092,37 @@ class Foo(BaseModel):
 
 
 def test_stream_response_format() -> None:
-    list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    for chunk in ChatOpenAI(model="gpt-4o-mini").stream(
+        "how are ya", response_format=Foo
+    ):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 async def test_astream_response_format() -> None:
-    async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
         "how are ya", response_format=Foo
     ):
-        pass
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 @pytest.mark.parametrize("use_max_completion_tokens", [True, False])
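
Net effect of this diff: `stream()`/`astream()` calls that pass a Pydantic `response_format` are routed through the OpenAI SDK's `beta.chat.completions.stream` helper instead of falling back to a non-streaming call (the removed `_should_stream` override), and a final chunk built from `get_final_completion()` is appended so the parsed object reaches the caller. A minimal usage sketch mirroring the new integration tests (the model name and prompt are taken from those tests and are illustrative only):

```python
from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


llm = ChatOpenAI(model="gpt-4o-mini")

full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("how are ya", response_format=Foo):
    # Content now arrives across multiple chunks; chunks can be merged with `+`.
    full = chunk if full is None else full + chunk

assert isinstance(full, AIMessageChunk)
# The final chunk carries the parsed Pydantic object alongside the accumulated
# JSON string content.
print(full.additional_kwargs["parsed"])  # Foo(response="...")
print(full.content)  # raw JSON text matching the parsed object
```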