fix(openai): update file index key name (#33350)

This commit is contained in:
ccurme
2025-10-09 09:15:27 -04:00
committed by GitHub
parent a3e4f4c2e3
commit c27271f3ae
4 changed files with 76 additions and 11 deletions

View File

@@ -187,15 +187,19 @@ def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
new_ann["title"] = annotation["title"]
new_ann["type"] = "url_citation"
new_ann["url"] = annotation["url"]
if extra_fields := annotation.get("extras"):
new_ann.update(dict(extra_fields.items()))
else:
# Document citation
new_ann["type"] = "file_citation"
if extra_fields := annotation.get("extras"):
new_ann.update(dict(extra_fields.items()))
if "title" in annotation:
new_ann["filename"] = annotation["title"]
if extra_fields := annotation.get("extras"):
new_ann.update(dict(extra_fields.items()))
return new_ann
if annotation["type"] == "non_standard_annotation":

View File

@@ -3589,6 +3589,24 @@ def _construct_responses_api_payload(
return payload
def _format_annotation_to_lc(annotation: dict[str, Any]) -> dict[str, Any]:
# langchain-core reserves the `"index"` key for streaming aggregation.
# Here we re-name.
if annotation.get("type") == "file_citation" and "index" in annotation:
new_annotation = annotation.copy()
new_annotation["file_index"] = new_annotation.pop("index")
return new_annotation
return annotation
def _format_annotation_from_lc(annotation: dict[str, Any]) -> dict[str, Any]:
if annotation.get("type") == "file_citation" and "file_index" in annotation:
new_annotation = annotation.copy()
new_annotation["index"] = new_annotation.pop("file_index")
return new_annotation
return annotation
def _convert_chat_completions_blocks_to_responses(
block: dict[str, Any],
) -> dict[str, Any]:
@@ -3775,7 +3793,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
new_block = {
"type": "output_text",
"text": block["text"],
"annotations": block.get("annotations") or [],
"annotations": [
_format_annotation_from_lc(annotation)
for annotation in block.get("annotations") or []
],
}
elif block_type == "refusal":
new_block = {
@@ -3951,7 +3972,7 @@ def _construct_lc_result_from_responses_api(
"type": "text",
"text": content.text,
"annotations": [
annotation.model_dump()
_format_annotation_to_lc(annotation.model_dump())
for annotation in content.annotations
]
if isinstance(content.annotations, list)
@@ -4142,7 +4163,11 @@ def _convert_responses_chunk_to_generation_chunk(
annotation = chunk.annotation.model_dump(exclude_none=True, mode="json")
content.append(
{"type": "text", "annotations": [annotation], "index": current_index}
{
"type": "text",
"annotations": [_format_annotation_to_lc(annotation)],
"index": current_index,
}
)
elif chunk.type == "response.output_text.done":
content.append({"type": "text", "id": chunk.item_id, "index": current_index})

View File

@@ -34,7 +34,7 @@ def _check_response(response: BaseMessage | None) -> None:
if annotation["type"] == "file_citation":
assert all(
key in annotation
for key in ["file_id", "filename", "index", "type"]
for key in ["file_id", "filename", "file_index", "type"]
)
elif annotation["type"] == "web_search":
assert all(
@@ -374,9 +374,17 @@ def test_computer_calls() -> None:
assert response.additional_kwargs["tool_outputs"]
def test_file_search() -> None:
pytest.skip() # TODO: set up infra
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
@pytest.mark.default_cassette("test_file_search.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_file_search(
output_version: Literal["responses/v1", "v1"],
) -> None:
llm = ChatOpenAI(
model=MODEL_NAME,
use_responses_api=True,
output_version=output_version,
)
tool = {
"type": "file_search",
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
@@ -386,16 +394,44 @@ def test_file_search() -> None:
response = llm.invoke([input_message], tools=[tool])
_check_response(response)
full: BaseMessageChunk | None = None
if output_version == "v1":
assert [block["type"] for block in response.content] == [ # type: ignore[index]
"server_tool_call",
"server_tool_result",
"text",
]
else:
assert [block["type"] for block in response.content] == [ # type: ignore[index]
"file_search_call",
"text",
]
full: AIMessageChunk | None = None
for chunk in llm.stream([input_message], tools=[tool]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
_check_response(full)
if output_version == "v1":
assert [block["type"] for block in full.content] == [ # type: ignore[index]
"server_tool_call",
"server_tool_result",
"text",
]
else:
assert [block["type"] for block in full.content] == ["file_search_call", "text"] # type: ignore[index]
next_message = {"role": "user", "content": "Thank you."}
_ = llm.invoke([input_message, full, next_message])
for message in [response, full]:
assert [block["type"] for block in message.content_blocks] == [
"server_tool_call",
"server_tool_result",
"text",
]
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
@pytest.mark.vcr