mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-04 04:28:58 +00:00)
feat(openai): custom tools (#32449)
@@ -3582,6 +3582,20 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
     return computer_call_output


+def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict]:
+    custom_tool_output = None
+    for block in message.content:
+        if isinstance(block, dict) and block.get("type") == "custom_tool_call_output":
+            custom_tool_output = {
+                "type": "custom_tool_call_output",
+                "call_id": message.tool_call_id,
+                "output": block.get("output") or "",
+            }
+            break
+
+    return custom_tool_output
+
+
 def _pop_index_and_sub_index(block: dict) -> dict:
     """When streaming, langchain-core uses the ``index`` key to aggregate
     text blocks. OpenAI API does not support this key, so we need to remove it.
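For illustration, a minimal sketch of what the new helper returns; the message values here are hypothetical, not from the commit:

    from langchain_core.messages import ToolMessage

    # A tool result carrying a custom_tool_call_output content block.
    msg = ToolMessage(
        content=[{"type": "custom_tool_call_output", "output": "ok"}],
        tool_call_id="call_123",
    )

    # _make_custom_tool_output_from_message(msg) scans the content blocks
    # and would return:
    # {"type": "custom_tool_call_output", "call_id": "call_123", "output": "ok"}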
@@ -3608,7 +3622,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
             msg.pop("name")
         if msg["role"] == "tool":
             tool_output = msg["content"]
-            if lc_msg.additional_kwargs.get("type") == "computer_call_output":
+            custom_tool_output = _make_custom_tool_output_from_message(lc_msg)  # type: ignore[arg-type]
+            if custom_tool_output:
+                input_.append(custom_tool_output)
+            elif lc_msg.additional_kwargs.get("type") == "computer_call_output":
                 computer_call_output = _make_computer_call_output_from_message(
                     cast(ToolMessage, lc_msg)
                 )
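A sketch of how tool messages are now routed during input construction (example messages hypothetical): a message whose content includes a custom_tool_call_output block takes the new branch, while one flagged as computer_call_output in additional_kwargs still takes the computer-use branch.

    from langchain_core.messages import ToolMessage

    custom = ToolMessage(
        content=[{"type": "custom_tool_call_output", "output": "ok"}],
        tool_call_id="call_1",
    )
    computer = ToolMessage(
        content="screenshot result",
        tool_call_id="call_2",
        additional_kwargs={"type": "computer_call_output"},
    )
    # `custom`: _make_custom_tool_output_from_message returns a dict,
    # which is appended to input_.
    # `computer`: the helper returns None, so control falls through to the
    # elif and _make_computer_call_output_from_message handles it.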
@@ -3663,6 +3680,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                 "file_search_call",
                 "function_call",
                 "computer_call",
+                "custom_tool_call",
                 "code_interpreter_call",
                 "mcp_call",
                 "mcp_list_tools",
@@ -3690,7 +3708,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
         content_call_ids = {
             block["call_id"]
             for block in input_
-            if block.get("type") == "function_call" and "call_id" in block
+            if block.get("type") in ("function_call", "custom_tool_call")
+            and "call_id" in block
         }
         for tool_call in tool_calls:
             if tool_call["id"] not in content_call_ids:
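An illustrative check of the widened set comprehension (data hypothetical): call IDs are now collected from custom_tool_call blocks as well as function_call blocks, so already-represented tool calls are not re-emitted.

    input_ = [
        {"type": "function_call", "call_id": "call_a"},
        {"type": "custom_tool_call", "call_id": "call_b"},
        {"type": "reasoning"},  # no call_id; skipped
    ]
    content_call_ids = {
        block["call_id"]
        for block in input_
        if block.get("type") in ("function_call", "custom_tool_call")
        and "call_id" in block
    }
    assert content_call_ids == {"call_a", "call_b"}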
@@ -3841,6 +3860,15 @@ def _construct_lc_result_from_responses_api(
                 "error": error,
             }
             invalid_tool_calls.append(tool_call)
+        elif output.type == "custom_tool_call":
+            content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
+            tool_call = {
+                "type": "tool_call",
+                "name": output.name,
+                "args": {"__arg1": output.input},
+                "id": output.call_id,
+            }
+            tool_calls.append(tool_call)
         elif output.type in (
             "reasoning",
             "web_search_call",
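Because a custom tool receives a single free-form string rather than JSON arguments, the parsed result wraps that string under the synthetic "__arg1" key. A sketch of the resulting tool_call dict (tool name and values hypothetical):

    tool_call = {
        "type": "tool_call",
        "name": "execute_sql",            # hypothetical custom tool
        "args": {"__arg1": "SELECT 1;"},  # raw string input, not parsed JSON
        "id": "call_abc",
    }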
@@ -4044,6 +4072,23 @@ def _convert_responses_chunk_to_generation_chunk(
             tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
             tool_output["index"] = current_index
             content.append(tool_output)
+        elif (
+            chunk.type == "response.output_item.done"
+            and chunk.item.type == "custom_tool_call"
+        ):
+            _advance(chunk.output_index)
+            tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
+            tool_output["index"] = current_index
+            content.append(tool_output)
+            tool_call_chunks.append(
+                {
+                    "type": "tool_call_chunk",
+                    "name": chunk.item.name,
+                    "args": json.dumps({"__arg1": chunk.item.input}),
+                    "id": chunk.item.call_id,
+                    "index": current_index,
+                }
+            )
         elif chunk.type == "response.function_call_arguments.delta":
             _advance(chunk.output_index)
             tool_call_chunks.append(
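In streaming mode the same "__arg1" convention applies, except the args are JSON-serialized strings so that langchain-core can merge tool_call_chunks across deltas. A sketch of one appended chunk (values hypothetical):

    import json

    tool_call_chunk = {
        "type": "tool_call_chunk",
        "name": "execute_sql",
        "args": json.dumps({"__arg1": "SELECT 1;"}),
        "id": "call_abc",
        "index": 2,
    }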