oai assistant multiple actions (#13068)

This commit is contained in:
Bagatur
2023-11-08 08:25:37 -08:00
committed by GitHub
parent a9b70baef9
commit 55aeff6777
2 changed files with 78 additions and 90 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.pydantic_v1 import Field
from langchain.schema.agent import AgentAction, AgentFinish
@@ -212,8 +212,12 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
will return OpenAI types
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
"""
input = self._parse_input(input)
if "thread_id" not in input:
# Being run within AgentExecutor and there are tool outputs to submit.
if self.as_agent and input.get("intermediate_steps"):
tool_outputs = self._parse_intermediate_steps(input["intermediate_steps"])
run = self.client.beta.threads.runs.submit_tool_outputs(**tool_outputs)
# Starting a new thread and a new run.
elif "thread_id" not in input:
thread = {
"messages": [
{
@@ -226,6 +230,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
"metadata": input.get("thread_metadata"),
}
run = self._create_thread_and_run(input, thread)
# Starting a new run in an existing thread.
elif "run_id" not in input:
_ = self.client.beta.threads.messages.create(
input["thread_id"],
@@ -235,21 +240,31 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
metadata=input.get("message_metadata"),
)
run = self._create_run(input)
# Submitting tool outputs to an existing run, outside the AgentExecutor
# framework.
else:
run = self.client.beta.threads.runs.submit_tool_outputs(**input)
return self._get_response(run.id, run.thread_id)
def _parse_input(self, input: dict) -> dict:
if self.as_agent and input.get("intermediate_steps"):
last_action, last_output = input["intermediate_steps"][-1]
input = {
"tool_outputs": [
{"output": last_output, "tool_call_id": last_action.tool_call_id}
],
"run_id": last_action.run_id,
"thread_id": last_action.thread_id,
}
return input
def _parse_intermediate_steps(
self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
required_tool_call_ids = {
tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
}
tool_outputs = [
{"output": output, "tool_call_id": action.tool_call_id}
for action, output in intermediate_steps
if action.tool_call_id in required_tool_call_ids
]
submit_tool_outputs = {
"tool_outputs": tool_outputs,
"run_id": last_action.run_id,
"thread_id": last_action.thread_id,
}
return submit_tool_outputs
def _create_run(self, input: dict) -> Any:
params = {
@@ -307,6 +322,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
function = tool_call.function
args = json.loads(function.arguments)
if len(args) == 1 and "__arg1" in args:
args = args["__arg1"]
actions.append(
OpenAIAssistantAction(
tool=function.name,
@@ -321,7 +338,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
else:
run_info = json.dumps(run.dict(), indent=2)
raise ValueError(
f"Unknown run status {run.status}. Full run info:\n\n{run_info})"
f"Unexpected run status: {run.status}. Full run info:\n\n{run_info})"
)
def _wait_for_run(self, run_id: str, thread_id: str) -> Any: