diff --git a/libs/partners/openai/langchain_openai/chat_models/_compat.py b/libs/partners/openai/langchain_openai/chat_models/_compat.py index 6ff3b932b9b..0c664ac0bad 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_compat.py +++ b/libs/partners/openai/langchain_openai/chat_models/_compat.py @@ -187,15 +187,19 @@ def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]: new_ann["title"] = annotation["title"] new_ann["type"] = "url_citation" new_ann["url"] = annotation["url"] + + if extra_fields := annotation.get("extras"): + new_ann.update(dict(extra_fields.items())) else: # Document citation new_ann["type"] = "file_citation" + + if extra_fields := annotation.get("extras"): + new_ann.update(dict(extra_fields.items())) + if "title" in annotation: new_ann["filename"] = annotation["title"] - if extra_fields := annotation.get("extras"): - new_ann.update(dict(extra_fields.items())) - return new_ann if annotation["type"] == "non_standard_annotation": diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 2d4cd988bbf..8fffd34e6ae 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -3589,6 +3589,24 @@ def _construct_responses_api_payload( return payload +def _format_annotation_to_lc(annotation: dict[str, Any]) -> dict[str, Any]: + # langchain-core reserves the `"index"` key for streaming aggregation. + # Here we re-name. + if annotation.get("type") == "file_citation" and "index" in annotation: + new_annotation = annotation.copy() + new_annotation["file_index"] = new_annotation.pop("index") + return new_annotation + return annotation + + +def _format_annotation_from_lc(annotation: dict[str, Any]) -> dict[str, Any]: + if annotation.get("type") == "file_citation" and "file_index" in annotation: + new_annotation = annotation.copy() + new_annotation["index"] = new_annotation.pop("file_index") + return new_annotation + return annotation + + def _convert_chat_completions_blocks_to_responses( block: dict[str, Any], ) -> dict[str, Any]: @@ -3775,7 +3793,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list: new_block = { "type": "output_text", "text": block["text"], - "annotations": block.get("annotations") or [], + "annotations": [ + _format_annotation_from_lc(annotation) + for annotation in block.get("annotations") or [] + ], } elif block_type == "refusal": new_block = { @@ -3951,7 +3972,7 @@ def _construct_lc_result_from_responses_api( "type": "text", "text": content.text, "annotations": [ - annotation.model_dump() + _format_annotation_to_lc(annotation.model_dump()) for annotation in content.annotations ] if isinstance(content.annotations, list) @@ -4142,7 +4163,11 @@ def _convert_responses_chunk_to_generation_chunk( annotation = chunk.annotation.model_dump(exclude_none=True, mode="json") content.append( - {"type": "text", "annotations": [annotation], "index": current_index} + { + "type": "text", + "annotations": [_format_annotation_to_lc(annotation)], + "index": current_index, + } ) elif chunk.type == "response.output_text.done": content.append({"type": "text", "id": chunk.item_id, "index": current_index}) diff --git a/libs/partners/openai/tests/cassettes/test_file_search.yaml.gz b/libs/partners/openai/tests/cassettes/test_file_search.yaml.gz new file mode 100644 index 00000000000..4c896356533 Binary files /dev/null and b/libs/partners/openai/tests/cassettes/test_file_search.yaml.gz differ diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py index a7df4ef7ef7..8310adc6ed8 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_responses_api.py @@ -34,7 +34,7 @@ def _check_response(response: BaseMessage | None) -> None: if annotation["type"] == "file_citation": assert all( key in annotation - for key in ["file_id", "filename", "index", "type"] + for key in ["file_id", "filename", "file_index", "type"] ) elif annotation["type"] == "web_search": assert all( @@ -374,9 +374,17 @@ def test_computer_calls() -> None: assert response.additional_kwargs["tool_outputs"] -def test_file_search() -> None: - pytest.skip() # TODO: set up infra - llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True) +@pytest.mark.default_cassette("test_file_search.yaml.gz") +@pytest.mark.vcr +@pytest.mark.parametrize("output_version", ["responses/v1", "v1"]) +def test_file_search( + output_version: Literal["responses/v1", "v1"], +) -> None: + llm = ChatOpenAI( + model=MODEL_NAME, + use_responses_api=True, + output_version=output_version, + ) tool = { "type": "file_search", "vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]], @@ -386,16 +394,44 @@ def test_file_search() -> None: response = llm.invoke([input_message], tools=[tool]) _check_response(response) - full: BaseMessageChunk | None = None + if output_version == "v1": + assert [block["type"] for block in response.content] == [ # type: ignore[index] + "server_tool_call", + "server_tool_result", + "text", + ] + else: + assert [block["type"] for block in response.content] == [ # type: ignore[index] + "file_search_call", + "text", + ] + + full: AIMessageChunk | None = None for chunk in llm.stream([input_message], tools=[tool]): assert isinstance(chunk, AIMessageChunk) full = chunk if full is None else full + chunk assert isinstance(full, AIMessageChunk) _check_response(full) + if output_version == "v1": + assert [block["type"] for block in full.content] == [ # type: ignore[index] + "server_tool_call", + "server_tool_result", + "text", + ] + else: + assert [block["type"] for block in full.content] == ["file_search_call", "text"] # type: ignore[index] + next_message = {"role": "user", "content": "Thank you."} _ = llm.invoke([input_message, full, next_message]) + for message in [response, full]: + assert [block["type"] for block in message.content_blocks] == [ + "server_tool_call", + "server_tool_result", + "text", + ] + @pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz") @pytest.mark.vcr