mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
oai assistant multiple actions (#13068)
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema.agent import AgentAction, AgentFinish
|
||||
@@ -212,8 +212,12 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
will return OpenAI types
|
||||
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
|
||||
"""
|
||||
input = self._parse_input(input)
|
||||
if "thread_id" not in input:
|
||||
# Being run within AgentExecutor and there are tool outputs to submit.
|
||||
if self.as_agent and input.get("intermediate_steps"):
|
||||
tool_outputs = self._parse_intermediate_steps(input["intermediate_steps"])
|
||||
run = self.client.beta.threads.runs.submit_tool_outputs(**tool_outputs)
|
||||
# Starting a new thread and a new run.
|
||||
elif "thread_id" not in input:
|
||||
thread = {
|
||||
"messages": [
|
||||
{
|
||||
@@ -226,6 +230,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
"metadata": input.get("thread_metadata"),
|
||||
}
|
||||
run = self._create_thread_and_run(input, thread)
|
||||
# Starting a new run in an existing thread.
|
||||
elif "run_id" not in input:
|
||||
_ = self.client.beta.threads.messages.create(
|
||||
input["thread_id"],
|
||||
@@ -235,21 +240,31 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
metadata=input.get("message_metadata"),
|
||||
)
|
||||
run = self._create_run(input)
|
||||
# Submitting tool outputs to an existing run, outside the AgentExecutor
|
||||
# framework.
|
||||
else:
|
||||
run = self.client.beta.threads.runs.submit_tool_outputs(**input)
|
||||
return self._get_response(run.id, run.thread_id)
|
||||
|
||||
def _parse_input(self, input: dict) -> dict:
|
||||
if self.as_agent and input.get("intermediate_steps"):
|
||||
last_action, last_output = input["intermediate_steps"][-1]
|
||||
input = {
|
||||
"tool_outputs": [
|
||||
{"output": last_output, "tool_call_id": last_action.tool_call_id}
|
||||
],
|
||||
"run_id": last_action.run_id,
|
||||
"thread_id": last_action.thread_id,
|
||||
}
|
||||
return input
|
||||
def _parse_intermediate_steps(
|
||||
self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
|
||||
) -> dict:
|
||||
last_action, last_output = intermediate_steps[-1]
|
||||
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
|
||||
required_tool_call_ids = {
|
||||
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
|
||||
}
|
||||
tool_outputs = [
|
||||
{"output": output, "tool_call_id": action.tool_call_id}
|
||||
for action, output in intermediate_steps
|
||||
if action.tool_call_id in required_tool_call_ids
|
||||
]
|
||||
submit_tool_outputs = {
|
||||
"tool_outputs": tool_outputs,
|
||||
"run_id": last_action.run_id,
|
||||
"thread_id": last_action.thread_id,
|
||||
}
|
||||
return submit_tool_outputs
|
||||
|
||||
def _create_run(self, input: dict) -> Any:
|
||||
params = {
|
||||
@@ -307,6 +322,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
||||
function = tool_call.function
|
||||
args = json.loads(function.arguments)
|
||||
if len(args) == 1 and "__arg1" in args:
|
||||
args = args["__arg1"]
|
||||
actions.append(
|
||||
OpenAIAssistantAction(
|
||||
tool=function.name,
|
||||
@@ -321,7 +338,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
else:
|
||||
run_info = json.dumps(run.dict(), indent=2)
|
||||
raise ValueError(
|
||||
f"Unknown run status {run.status}. Full run info:\n\n{run_info})"
|
||||
f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
|
||||
)
|
||||
|
||||
def _wait_for_run(self, run_id: str, thread_id: str) -> Any:
|
||||
|
Reference in New Issue
Block a user