diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index be8eaffc616..13972802d5e 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -317,11 +317,21 @@ def _convert_delta_to_message_chunk(
 
 
 def _convert_chunk_to_generation_chunk(
-    chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
+    chunk: dict,
+    default_chunk_class: Type,
+    base_generation_info: Optional[Dict],
+    chunk_object: Any,
 ) -> Optional[ChatGenerationChunk]:
     if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
         return None
-    token_usage = chunk.get("usage")
+    if chunk.get("type") == "content.done" and hasattr(chunk_object, "parsed"):
+        return ChatGenerationChunk(
+            message=default_chunk_class(
+                content="", additional_kwargs={"parsed": chunk_object.parsed}
+            )
+        )
+
+    token_usage = chunk.get("usage") or chunk.get("chunk", {}).get("usage")
     choices = (
         chunk.get("choices", [])
         # from beta.chat.completions.stream
@@ -725,12 +735,15 @@ class BaseChatOpenAI(BaseChatModel):
             with context_manager as response:
                 is_first_chunk = True
                 for chunk in response:
-                    if not isinstance(chunk, dict):
-                        chunk = chunk.model_dump()
+                    if isinstance(chunk, dict):
+                        chunk_dict = chunk
+                    else:
+                        chunk_dict = chunk.model_dump()
                     generation_chunk = _convert_chunk_to_generation_chunk(
-                        chunk,
+                        chunk_dict,
                         default_chunk_class,
                         base_generation_info if is_first_chunk else {},
+                        chunk,
                     )
                     if generation_chunk is None:
                         continue
@@ -746,16 +759,6 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
-        if hasattr(response, "get_final_completion") and "response_format" in payload:
-            final_completion = response.get_final_completion()
-            generation_chunk = self._get_generation_chunk_from_completion(
-                final_completion
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
 
     def _generate(
         self,
@@ -893,12 +896,15 @@ class BaseChatOpenAI(BaseChatModel):
             async with context_manager as response:
                 is_first_chunk = True
                 async for chunk in response:
-                    if not isinstance(chunk, dict):
-                        chunk = chunk.model_dump()
+                    if isinstance(chunk, dict):
+                        chunk_dict = chunk
+                    else:
+                        chunk_dict = chunk.model_dump()
                     generation_chunk = _convert_chunk_to_generation_chunk(
-                        chunk,
+                        chunk_dict,
                         default_chunk_class,
                         base_generation_info if is_first_chunk else {},
+                        chunk,
                     )
                     if generation_chunk is None:
                         continue
@@ -914,16 +920,6 @@ class BaseChatOpenAI(BaseChatModel):
                     yield generation_chunk
         except openai.BadRequestError as e:
             _handle_openai_bad_request(e)
-        if hasattr(response, "get_final_completion") and "response_format" in payload:
-            final_completion = await response.get_final_completion()
-            generation_chunk = self._get_generation_chunk_from_completion(
-                final_completion
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
 
     async def _agenerate(
         self,
@@ -1475,28 +1471,6 @@ class BaseChatOpenAI(BaseChatModel):
                 filtered[k] = v
         return filtered
 
-    def _get_generation_chunk_from_completion(
-        self, completion: openai.BaseModel
-    ) -> ChatGenerationChunk:
-        """Get chunk from completion (e.g., from final completion of a stream)."""
-        chat_result = self._create_chat_result(completion)
-        chat_message = chat_result.generations[0].message
-        if isinstance(chat_message, AIMessage):
-            usage_metadata = chat_message.usage_metadata
-            # Skip tool_calls, already sent as chunks
-            if "tool_calls" in chat_message.additional_kwargs:
-                chat_message.additional_kwargs.pop("tool_calls")
-        else:
-            usage_metadata = None
-        message = AIMessageChunk(
-            content="",
-            additional_kwargs=chat_message.additional_kwargs,
-            usage_metadata=usage_metadata,
-        )
-        return ChatGenerationChunk(
-            message=message, generation_info=chat_result.llm_output
-        )
-
 
 class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.
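
For reference, a minimal usage sketch of the streaming path this diff changes: when a Pydantic `response_format` is in the request payload, the parsed object now arrives in the `content.done` chunk's `additional_kwargs["parsed"]` instead of a synthesized final-completion chunk. The `Joke` schema, model name, and prompt below are illustrative only, and the sketch assumes `response_format` passed to `stream()` is forwarded into the request payload (as the removed `"response_format" in payload` check implies).

```python
# Hypothetical end-to-end check of the new streaming behavior; schema, model
# name, and prompt are illustrative, not part of this diff.
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Joke(BaseModel):
    setup: str
    punchline: str


llm = ChatOpenAI(model="gpt-4o-mini")

full = None
for chunk in llm.stream("Tell me a joke", response_format=Joke):
    # Content deltas stream as before; with this patch, the final content.done
    # chunk carries the parsed Pydantic object in additional_kwargs rather
    # than a trailing chunk built from get_final_completion().
    full = chunk if full is None else full + chunk

parsed = full.additional_kwargs.get("parsed") if full is not None else None
if parsed is not None:
    print(parsed.setup)
    print(parsed.punchline)
```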