diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index 86aff674295..1f3d5f3344c 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -68,27 +68,17 @@ formats. The functions are used internally by ChatOpenAI. import json from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Union, cast +from typing import Any, Union, cast from langchain_core.messages import ( AIMessage, AIMessageChunk, DocumentCitation, NonStandardAnnotation, - ReasoningContentBlock, UrlCitation, is_data_content_block, ) -if TYPE_CHECKING: - from langchain_core.messages import ( - Base64ContentBlock, - NonStandardContentBlock, - ReasoningContentBlock, - TextContentBlock, - ToolCallContentBlock, - ) - _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__" @@ -284,15 +274,13 @@ def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage: """Mutate a Chat Completions message to v1 format.""" if isinstance(message.content, str): if message.content: - block: TextContentBlock = {"type": "text", "text": message.content} - message.content = [block] + message.content = [{"type": "text", "text": message.content}] else: message.content = [] for tool_call in message.tool_calls: if id_ := tool_call.get("id"): - tool_call_block: ToolCallContentBlock = {"type": "tool_call", "id": id_} - message.content.append(tool_call_block) + message.content.append({"type": "tool_call", "id": id_}) if "tool_calls" in message.additional_kwargs: _ = message.additional_kwargs.pop("tool_calls") @@ -336,31 +324,31 @@ def _convert_annotation_to_v1( annotation_type = annotation.get("type") if annotation_type == "url_citation": - new_annotation: UrlCitation = {"type": "url_citation", "url": annotation["url"]} + url_citation: UrlCitation = {"type": "url_citation", "url": annotation["url"]} for field in ("title", "start_index", "end_index"): if field in annotation: - new_annotation[field] = annotation[field] - return new_annotation + url_citation[field] = annotation[field] + return url_citation elif annotation_type == "file_citation": - new_annotation: DocumentCitation = {"type": "document_citation"} + document_citation: DocumentCitation = {"type": "document_citation"} if "filename" in annotation: - new_annotation["title"] = annotation["filename"] + document_citation["title"] = annotation["filename"] for field in ("file_id", "index"): # OpenAI-specific if field in annotation: - new_annotation[field] = annotation[field] - return new_annotation + document_citation[field] = annotation[field] # type: ignore[literal-required] + return document_citation # TODO: standardise container_file_citation? else: - new_annotation: NonStandardAnnotation = { + non_standard_annotation: NonStandardAnnotation = { "type": "non_standard_annotation", "value": annotation, } - return new_annotation + return non_standard_annotation -def _explode_reasoning(block: dict[str, Any]) -> Iterable[ReasoningContentBlock]: +def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]: if block.get("type") != "reasoning" or "summary" not in block: yield block return @@ -383,7 +371,7 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[ReasoningContentBlock] new_block["reasoning"] = part.get("text", "") if idx == 0: new_block.update(first_only) - yield cast(ReasoningContentBlock, new_block) + yield new_block def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: @@ -393,6 +381,8 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: def _iter_blocks() -> Iterable[dict[str, Any]]: for block in message.content: + if not isinstance(block, dict): + continue block_type = block.get("type") if block_type == "text": @@ -408,11 +398,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: elif block_type == "image_generation_call" and ( result := block.get("result") ): - new_block: Base64ContentBlock = { - "type": "image", - "source_type": "base64", - "data": result, - } + new_block = {"type": "image", "source_type": "base64", "data": result} if output_format := block.get("output_format"): new_block["mime_type"] = f"image/{output_format}" for extra_key in ( @@ -430,10 +416,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: yield new_block elif block_type == "function_call": - new_block: ToolCallContentBlock = { - "type": "tool_call", - "id": block.get("call_id", ""), - } + new_block = {"type": "tool_call", "id": block.get("call_id", "")} if "id" in block: new_block["item_id"] = block["id"] for extra_key in ("arguments", "name", "index"): @@ -442,10 +425,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage: yield new_block else: - new_block: NonStandardContentBlock = { - "type": "non_standard", - "value": block, - } + new_block = {"type": "non_standard", "value": block} if "index" in new_block["value"]: new_block["index"] = new_block["value"].pop("index") yield new_block diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 9cd4161ec88..39a2a6b2d91 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -3803,7 +3803,7 @@ def _construct_lc_result_from_responses_api( ) if image_generation_call.output_format: mime_type = f"image/{image_generation_call.output_format}" - for block in message.content: + for block in message.beta_content: # type: ignore[assignment] # OK to mutate output message if ( block.get("type") == "image" @@ -4009,7 +4009,7 @@ def _convert_responses_chunk_to_generation_chunk( } ) else: - block = {"type": "reasoning", "reasoning": ""} + block: dict = {"type": "reasoning", "reasoning": ""} if chunk.summary_index > 0: _advance(chunk.output_index, chunk.summary_index) block["id"] = chunk.item_id @@ -4050,7 +4050,7 @@ def _convert_responses_chunk_to_generation_chunk( _convert_to_v03_ai_message(message, has_reasoning=has_reasoning), ) elif output_version == "v1": - message = _convert_to_v1_from_responses(message) + message = cast(AIMessageChunk, _convert_to_v1_from_responses(message)) else: pass return ( diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index 2db68db409f..a20c7a46113 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -472,6 +472,7 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - "content": "Write and run code to answer the question: what is 3^3?", } response = llm_with_tools.invoke([input_message]) + assert isinstance(response, AIMessage) _check_response(response) if output_version == "v0": tool_outputs = [ @@ -481,12 +482,16 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - ] elif output_version == "responses/v1": tool_outputs = [ - item for item in response.content if item["type"] == "code_interpreter_call" + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" ] else: # v1 tool_outputs = [ - item["value"] for item in response.content if item["type"] == "non_standard" + item["value"] + for item in response.beta_content + if item["type"] == "non_standard" ] assert tool_outputs[0]["type"] == "code_interpreter_call" assert len(tool_outputs) == 1 @@ -511,11 +516,15 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) - ] elif output_version == "responses/v1": tool_outputs = [ - item for item in response.content if item["type"] == "code_interpreter_call" + item + for item in response.content + if isinstance(item, dict) and item["type"] == "code_interpreter_call" ] else: tool_outputs = [ - item["value"] for item in response.content if item["type"] == "non_standard" + item["value"] + for item in response.beta_content + if item["type"] == "non_standard" ] assert tool_outputs[0]["type"] == "code_interpreter_call" assert tool_outputs @@ -675,14 +684,16 @@ def test_image_generation_streaming(output_version: str) -> None: tool_output = next( block for block in complete_ai_message.content - if block["type"] == "image_generation_call" + if isinstance(block, dict) and block["type"] == "image_generation_call" ) assert set(tool_output.keys()).issubset(expected_keys) else: # v1 standard_keys = {"type", "source_type", "data", "id", "status", "index"} tool_output = next( - block for block in complete_ai_message.content if block["type"] == "image" + block + for block in complete_ai_message.beta_content + if block["type"] == "image" ) assert set(standard_keys).issubset(tool_output.keys()) @@ -711,6 +722,7 @@ def test_image_generation_multi_turn(output_version: str) -> None: {"role": "user", "content": "Draw a random short word in green font."} ] ai_message = llm_with_tools.invoke(chat_history) + assert isinstance(ai_message, AIMessage) _check_response(ai_message) expected_keys = { @@ -732,13 +744,13 @@ def test_image_generation_multi_turn(output_version: str) -> None: tool_output = next( block for block in ai_message.content - if block["type"] == "image_generation_call" + if isinstance(block, dict) and block["type"] == "image_generation_call" ) assert set(tool_output.keys()).issubset(expected_keys) else: standard_keys = {"type", "source_type", "data", "id", "status"} tool_output = next( - block for block in ai_message.content if block["type"] == "image" + block for block in ai_message.beta_content if block["type"] == "image" ) assert set(standard_keys).issubset(tool_output.keys()) @@ -774,6 +786,7 @@ def test_image_generation_multi_turn(output_version: str) -> None: ) ai_message2 = llm_with_tools.invoke(chat_history) + assert isinstance(ai_message2, AIMessage) _check_response(ai_message2) if output_version == "v0": @@ -783,12 +796,12 @@ def test_image_generation_multi_turn(output_version: str) -> None: tool_output = next( block for block in ai_message2.content - if block["type"] == "image_generation_call" + if isinstance(block, dict) and block["type"] == "image_generation_call" ) assert set(tool_output.keys()).issubset(expected_keys) else: standard_keys = {"type", "source_type", "data", "id", "status"} tool_output = next( - block for block in ai_message2.content if block["type"] == "image" + block for block in ai_message2.beta_content if block["type"] == "image" ) assert set(standard_keys).issubset(tool_output.keys())