mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(openai): update file index key name (#33350)
This commit is contained in:
@@ -187,15 +187,19 @@ def _convert_annotation_from_v1(annotation: types.Annotation) -> dict[str, Any]:
|
||||
new_ann["title"] = annotation["title"]
|
||||
new_ann["type"] = "url_citation"
|
||||
new_ann["url"] = annotation["url"]
|
||||
|
||||
if extra_fields := annotation.get("extras"):
|
||||
new_ann.update(dict(extra_fields.items()))
|
||||
else:
|
||||
# Document citation
|
||||
new_ann["type"] = "file_citation"
|
||||
|
||||
if extra_fields := annotation.get("extras"):
|
||||
new_ann.update(dict(extra_fields.items()))
|
||||
|
||||
if "title" in annotation:
|
||||
new_ann["filename"] = annotation["title"]
|
||||
|
||||
if extra_fields := annotation.get("extras"):
|
||||
new_ann.update(dict(extra_fields.items()))
|
||||
|
||||
return new_ann
|
||||
|
||||
if annotation["type"] == "non_standard_annotation":
|
||||
|
||||
@@ -3589,6 +3589,24 @@ def _construct_responses_api_payload(
|
||||
return payload
|
||||
|
||||
|
||||
def _format_annotation_to_lc(annotation: dict[str, Any]) -> dict[str, Any]:
|
||||
# langchain-core reserves the `"index"` key for streaming aggregation.
|
||||
# Here we re-name.
|
||||
if annotation.get("type") == "file_citation" and "index" in annotation:
|
||||
new_annotation = annotation.copy()
|
||||
new_annotation["file_index"] = new_annotation.pop("index")
|
||||
return new_annotation
|
||||
return annotation
|
||||
|
||||
|
||||
def _format_annotation_from_lc(annotation: dict[str, Any]) -> dict[str, Any]:
|
||||
if annotation.get("type") == "file_citation" and "file_index" in annotation:
|
||||
new_annotation = annotation.copy()
|
||||
new_annotation["index"] = new_annotation.pop("file_index")
|
||||
return new_annotation
|
||||
return annotation
|
||||
|
||||
|
||||
def _convert_chat_completions_blocks_to_responses(
|
||||
block: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
@@ -3775,7 +3793,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
new_block = {
|
||||
"type": "output_text",
|
||||
"text": block["text"],
|
||||
"annotations": block.get("annotations") or [],
|
||||
"annotations": [
|
||||
_format_annotation_from_lc(annotation)
|
||||
for annotation in block.get("annotations") or []
|
||||
],
|
||||
}
|
||||
elif block_type == "refusal":
|
||||
new_block = {
|
||||
@@ -3951,7 +3972,7 @@ def _construct_lc_result_from_responses_api(
|
||||
"type": "text",
|
||||
"text": content.text,
|
||||
"annotations": [
|
||||
annotation.model_dump()
|
||||
_format_annotation_to_lc(annotation.model_dump())
|
||||
for annotation in content.annotations
|
||||
]
|
||||
if isinstance(content.annotations, list)
|
||||
@@ -4142,7 +4163,11 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
annotation = chunk.annotation.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
content.append(
|
||||
{"type": "text", "annotations": [annotation], "index": current_index}
|
||||
{
|
||||
"type": "text",
|
||||
"annotations": [_format_annotation_to_lc(annotation)],
|
||||
"index": current_index,
|
||||
}
|
||||
)
|
||||
elif chunk.type == "response.output_text.done":
|
||||
content.append({"type": "text", "id": chunk.item_id, "index": current_index})
|
||||
|
||||
BIN
libs/partners/openai/tests/cassettes/test_file_search.yaml.gz
Normal file
BIN
libs/partners/openai/tests/cassettes/test_file_search.yaml.gz
Normal file
Binary file not shown.
@@ -34,7 +34,7 @@ def _check_response(response: BaseMessage | None) -> None:
|
||||
if annotation["type"] == "file_citation":
|
||||
assert all(
|
||||
key in annotation
|
||||
for key in ["file_id", "filename", "index", "type"]
|
||||
for key in ["file_id", "filename", "file_index", "type"]
|
||||
)
|
||||
elif annotation["type"] == "web_search":
|
||||
assert all(
|
||||
@@ -374,9 +374,17 @@ def test_computer_calls() -> None:
|
||||
assert response.additional_kwargs["tool_outputs"]
|
||||
|
||||
|
||||
def test_file_search() -> None:
|
||||
pytest.skip() # TODO: set up infra
|
||||
llm = ChatOpenAI(model=MODEL_NAME, use_responses_api=True)
|
||||
@pytest.mark.default_cassette("test_file_search.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
|
||||
def test_file_search(
|
||||
output_version: Literal["responses/v1", "v1"],
|
||||
) -> None:
|
||||
llm = ChatOpenAI(
|
||||
model=MODEL_NAME,
|
||||
use_responses_api=True,
|
||||
output_version=output_version,
|
||||
)
|
||||
tool = {
|
||||
"type": "file_search",
|
||||
"vector_store_ids": [os.environ["OPENAI_VECTOR_STORE_ID"]],
|
||||
@@ -386,16 +394,44 @@ def test_file_search() -> None:
|
||||
response = llm.invoke([input_message], tools=[tool])
|
||||
_check_response(response)
|
||||
|
||||
full: BaseMessageChunk | None = None
|
||||
if output_version == "v1":
|
||||
assert [block["type"] for block in response.content] == [ # type: ignore[index]
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
else:
|
||||
assert [block["type"] for block in response.content] == [ # type: ignore[index]
|
||||
"file_search_call",
|
||||
"text",
|
||||
]
|
||||
|
||||
full: AIMessageChunk | None = None
|
||||
for chunk in llm.stream([input_message], tools=[tool]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
_check_response(full)
|
||||
|
||||
if output_version == "v1":
|
||||
assert [block["type"] for block in full.content] == [ # type: ignore[index]
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
else:
|
||||
assert [block["type"] for block in full.content] == ["file_search_call", "text"] # type: ignore[index]
|
||||
|
||||
next_message = {"role": "user", "content": "Thank you."}
|
||||
_ = llm.invoke([input_message, full, next_message])
|
||||
|
||||
for message in [response, full]:
|
||||
assert [block["type"] for block in message.content_blocks] == [
|
||||
"server_tool_call",
|
||||
"server_tool_result",
|
||||
"text",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.default_cassette("test_stream_reasoning_summary.yaml.gz")
|
||||
@pytest.mark.vcr
|
||||
|
||||
Reference in New Issue
Block a user