openai: lint

Chester Curme 2025-07-11 14:07:47 -04:00
parent 679a9e7c8f
commit ce369125f3
3 changed files with 45 additions and 52 deletions

View File

@@ -68,27 +68,17 @@ formats. The functions are used internally by ChatOpenAI.
 
 import json
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import Any, Union, cast
 
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     DocumentCitation,
     NonStandardAnnotation,
-    ReasoningContentBlock,
     UrlCitation,
     is_data_content_block,
 )
 
-if TYPE_CHECKING:
-    from langchain_core.messages import (
-        Base64ContentBlock,
-        NonStandardContentBlock,
-        ReasoningContentBlock,
-        TextContentBlock,
-        ToolCallContentBlock,
-    )
-
 _FUNCTION_CALL_IDS_MAP_KEY = "__openai_function_call_ids__"
@@ -284,15 +274,13 @@ def _convert_to_v1_from_chat_completions(message: AIMessage) -> AIMessage:
     """Mutate a Chat Completions message to v1 format."""
     if isinstance(message.content, str):
         if message.content:
-            block: TextContentBlock = {"type": "text", "text": message.content}
-            message.content = [block]
+            message.content = [{"type": "text", "text": message.content}]
         else:
             message.content = []
     for tool_call in message.tool_calls:
         if id_ := tool_call.get("id"):
-            tool_call_block: ToolCallContentBlock = {"type": "tool_call", "id": id_}
-            message.content.append(tool_call_block)
+            message.content.append({"type": "tool_call", "id": id_})
 
     if "tool_calls" in message.additional_kwargs:
         _ = message.additional_kwargs.pop("tool_calls")
@@ -336,31 +324,31 @@ def _convert_annotation_to_v1(
     annotation_type = annotation.get("type")
 
     if annotation_type == "url_citation":
-        new_annotation: UrlCitation = {"type": "url_citation", "url": annotation["url"]}
+        url_citation: UrlCitation = {"type": "url_citation", "url": annotation["url"]}
         for field in ("title", "start_index", "end_index"):
             if field in annotation:
-                new_annotation[field] = annotation[field]
-        return new_annotation
+                url_citation[field] = annotation[field]
+        return url_citation
 
     elif annotation_type == "file_citation":
-        new_annotation: DocumentCitation = {"type": "document_citation"}
+        document_citation: DocumentCitation = {"type": "document_citation"}
         if "filename" in annotation:
-            new_annotation["title"] = annotation["filename"]
+            document_citation["title"] = annotation["filename"]
         for field in ("file_id", "index"):  # OpenAI-specific
             if field in annotation:
-                new_annotation[field] = annotation[field]
-        return new_annotation
+                document_citation[field] = annotation[field]  # type: ignore[literal-required]
+        return document_citation
 
     # TODO: standardise container_file_citation?
     else:
-        new_annotation: NonStandardAnnotation = {
+        non_standard_annotation: NonStandardAnnotation = {
            "type": "non_standard_annotation",
            "value": annotation,
        }
-        return new_annotation
+        return non_standard_annotation
 
 
-def _explode_reasoning(block: dict[str, Any]) -> Iterable[ReasoningContentBlock]:
+def _explode_reasoning(block: dict[str, Any]) -> Iterable[dict[str, Any]]:
     if block.get("type") != "reasoning" or "summary" not in block:
         yield block
         return
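The renames above are the core of the lint fix: mypy rejects re-annotating one variable name (new_annotation) with three different TypedDict types in the same function, so each branch now binds its own name. A rough, self-contained illustration of the pattern; the TypedDicts here are simplified stand-ins, not the real langchain_core classes.

from typing import Any, TypedDict, Union


class UrlCitation(TypedDict):  # simplified stand-in
    type: str
    url: str


class DocumentCitation(TypedDict):  # simplified stand-in
    type: str


def convert(annotation: dict[str, Any]) -> Union[UrlCitation, DocumentCitation]:
    if annotation.get("type") == "url_citation":
        # Each branch gets its own name; reusing one name for both TypedDict
        # types would trigger a mypy redefinition error.
        url_citation: UrlCitation = {"type": "url_citation", "url": annotation["url"]}
        return url_citation
    document_citation: DocumentCitation = {"type": "document_citation"}
    return document_citation

The added "# type: ignore[literal-required]" addresses the same checker: TypedDict indexing requires literal string keys, so assigning through the loop variable field needs an explicit ignore.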
@@ -383,7 +371,7 @@ def _explode_reasoning(block: dict[str, Any]) -> Iterable[ReasoningContentBlock]
             new_block["reasoning"] = part.get("text", "")
         if idx == 0:
             new_block.update(first_only)
-        yield cast(ReasoningContentBlock, new_block)
+        yield new_block
 
 
 def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
@@ -393,6 +381,8 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
 
     def _iter_blocks() -> Iterable[dict[str, Any]]:
         for block in message.content:
+            if not isinstance(block, dict):
+                continue
             block_type = block.get("type")
 
             if block_type == "text":
@@ -408,11 +398,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
             elif block_type == "image_generation_call" and (
                 result := block.get("result")
             ):
-                new_block: Base64ContentBlock = {
-                    "type": "image",
-                    "source_type": "base64",
-                    "data": result,
-                }
+                new_block = {"type": "image", "source_type": "base64", "data": result}
                 if output_format := block.get("output_format"):
                     new_block["mime_type"] = f"image/{output_format}"
                 for extra_key in (
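A rough sketch of the image branch above: an image_generation_call result becomes a base64 "image" block, with the mime type derived from the reported output format. The block values here are hypothetical.

# Hypothetical image_generation_call block from the Responses API.
block = {"type": "image_generation_call", "result": "aGVsbG8=", "output_format": "png"}

new_block = {"type": "image", "source_type": "base64", "data": block["result"]}
if output_format := block.get("output_format"):
    new_block["mime_type"] = f"image/{output_format}"

assert new_block["mime_type"] == "image/png"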
@@ -430,10 +416,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
                 yield new_block
 
             elif block_type == "function_call":
-                new_block: ToolCallContentBlock = {
-                    "type": "tool_call",
-                    "id": block.get("call_id", ""),
-                }
+                new_block = {"type": "tool_call", "id": block.get("call_id", "")}
                 if "id" in block:
                     new_block["item_id"] = block["id"]
                 for extra_key in ("arguments", "name", "index"):
@@ -442,10 +425,7 @@ def _convert_to_v1_from_responses(message: AIMessage) -> AIMessage:
                 yield new_block
 
             else:
-                new_block: NonStandardContentBlock = {
-                    "type": "non_standard",
-                    "value": block,
-                }
+                new_block = {"type": "non_standard", "value": block}
                 if "index" in new_block["value"]:
                     new_block["index"] = new_block["value"].pop("index")
                 yield new_block
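As a hedged sketch of the fallback branch above: unrecognized Responses blocks get wrapped as "non_standard", with any streaming "index" hoisted onto the wrapper. The block contents here are hypothetical.

# Hypothetical unknown block from the Responses API, carrying a streaming index.
block = {"type": "code_interpreter_call", "id": "ci_123", "index": 2}

new_block = {"type": "non_standard", "value": block}
if "index" in new_block["value"]:
    # Hoist the index onto the wrapper, mirroring the diff above.
    new_block["index"] = new_block["value"].pop("index")

assert new_block == {
    "type": "non_standard",
    "index": 2,
    "value": {"type": "code_interpreter_call", "id": "ci_123"},
}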

View File

@@ -3803,7 +3803,7 @@ def _construct_lc_result_from_responses_api(
             )
         if image_generation_call.output_format:
             mime_type = f"image/{image_generation_call.output_format}"
-            for block in message.content:
+            for block in message.beta_content:  # type: ignore[assignment]
                 # OK to mutate output message
                 if (
                     block.get("type") == "image"
@@ -4009,7 +4009,7 @@ def _convert_responses_chunk_to_generation_chunk(
                     }
                 )
             else:
-                block = {"type": "reasoning", "reasoning": ""}
+                block: dict = {"type": "reasoning", "reasoning": ""}
                 if chunk.summary_index > 0:
                     _advance(chunk.output_index, chunk.summary_index)
                 block["id"] = chunk.item_id
@@ -4050,7 +4050,7 @@ def _convert_responses_chunk_to_generation_chunk(
             _convert_to_v03_ai_message(message, has_reasoning=has_reasoning),
         )
     elif output_version == "v1":
-        message = _convert_to_v1_from_responses(message)
+        message = cast(AIMessageChunk, _convert_to_v1_from_responses(message))
     else:
         pass
     return (
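The cast added above only narrows the static type: _convert_to_v1_from_responses is annotated as returning AIMessage, while the streaming path needs an AIMessageChunk. A minimal sketch of the pattern; the helper here is a simplified stand-in.

from typing import cast

from langchain_core.messages import AIMessage, AIMessageChunk


def _convert(message: AIMessage) -> AIMessage:
    """Simplified stand-in annotated with the broader return type."""
    return message


chunk = AIMessageChunk(content="")
# cast() has no runtime effect; it just tells the type checker the value is
# the narrower AIMessageChunk rather than the declared AIMessage.
narrowed = cast(AIMessageChunk, _convert(chunk))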

View File

@@ -472,6 +472,7 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
         "content": "Write and run code to answer the question: what is 3^3?",
     }
     response = llm_with_tools.invoke([input_message])
+    assert isinstance(response, AIMessage)
     _check_response(response)
     if output_version == "v0":
         tool_outputs = [
@@ -481,12 +482,16 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
         ]
     elif output_version == "responses/v1":
         tool_outputs = [
-            item for item in response.content if item["type"] == "code_interpreter_call"
+            item
+            for item in response.content
+            if isinstance(item, dict) and item["type"] == "code_interpreter_call"
         ]
     else:
         # v1
         tool_outputs = [
-            item["value"] for item in response.content if item["type"] == "non_standard"
+            item["value"]
+            for item in response.beta_content
+            if item["type"] == "non_standard"
         ]
     assert tool_outputs[0]["type"] == "code_interpreter_call"
     assert len(tool_outputs) == 1
@@ -511,11 +516,15 @@ def test_code_interpreter(output_version: Literal["v0", "responses/v1", "v1"]) -
         ]
     elif output_version == "responses/v1":
         tool_outputs = [
-            item for item in response.content if item["type"] == "code_interpreter_call"
+            item
+            for item in response.content
+            if isinstance(item, dict) and item["type"] == "code_interpreter_call"
         ]
     else:
         tool_outputs = [
-            item["value"] for item in response.content if item["type"] == "non_standard"
+            item["value"]
+            for item in response.beta_content
+            if item["type"] == "non_standard"
         ]
     assert tool_outputs[0]["type"] == "code_interpreter_call"
     assert tool_outputs
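The isinstance(item, dict) guards added in these tests exist for the type checker: message content items are typed as a union of str and dict, so indexing item["type"] without narrowing fails mypy. A small illustration with hypothetical content.

from typing import Any, Union

content: list[Union[str, dict[str, Any]]] = [
    "some plain text",
    {"type": "code_interpreter_call", "outputs": []},  # hypothetical block
]

# The isinstance() check narrows each item to dict before it is indexed.
calls = [
    item
    for item in content
    if isinstance(item, dict) and item["type"] == "code_interpreter_call"
]
assert len(calls) == 1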
@@ -675,14 +684,16 @@ def test_image_generation_streaming(output_version: str) -> None:
         tool_output = next(
             block
             for block in complete_ai_message.content
-            if block["type"] == "image_generation_call"
+            if isinstance(block, dict) and block["type"] == "image_generation_call"
         )
         assert set(tool_output.keys()).issubset(expected_keys)
     else:
         # v1
         standard_keys = {"type", "source_type", "data", "id", "status", "index"}
         tool_output = next(
-            block for block in complete_ai_message.content if block["type"] == "image"
+            block
+            for block in complete_ai_message.beta_content
+            if block["type"] == "image"
         )
         assert set(standard_keys).issubset(tool_output.keys())
@@ -711,6 +722,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
         {"role": "user", "content": "Draw a random short word in green font."}
     ]
     ai_message = llm_with_tools.invoke(chat_history)
+    assert isinstance(ai_message, AIMessage)
     _check_response(ai_message)
 
     expected_keys = {
@@ -732,13 +744,13 @@ def test_image_generation_multi_turn(output_version: str) -> None:
         tool_output = next(
             block
             for block in ai_message.content
-            if block["type"] == "image_generation_call"
+            if isinstance(block, dict) and block["type"] == "image_generation_call"
         )
         assert set(tool_output.keys()).issubset(expected_keys)
     else:
         standard_keys = {"type", "source_type", "data", "id", "status"}
         tool_output = next(
-            block for block in ai_message.content if block["type"] == "image"
+            block for block in ai_message.beta_content if block["type"] == "image"
         )
         assert set(standard_keys).issubset(tool_output.keys())
@@ -774,6 +786,7 @@ def test_image_generation_multi_turn(output_version: str) -> None:
     )
     ai_message2 = llm_with_tools.invoke(chat_history)
+    assert isinstance(ai_message2, AIMessage)
     _check_response(ai_message2)
 
     if output_version == "v0":
@@ -783,12 +796,12 @@ def test_image_generation_multi_turn(output_version: str) -> None:
         tool_output = next(
             block
             for block in ai_message2.content
-            if block["type"] == "image_generation_call"
+            if isinstance(block, dict) and block["type"] == "image_generation_call"
         )
         assert set(tool_output.keys()).issubset(expected_keys)
     else:
         standard_keys = {"type", "source_type", "data", "id", "status"}
         tool_output = next(
-            block for block in ai_message2.content if block["type"] == "image"
+            block for block in ai_message2.beta_content if block["type"] == "image"
         )
         assert set(standard_keys).issubset(tool_output.keys())