[experimental]: minor fix to open assistants code (#24682)

Isaac Francisco 2024-08-15 10:50:57 -07:00 committed by GitHub
parent 2b4fbcb4b4
commit 5150ec3a04


@@ -272,7 +272,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             instructions=instructions,
             tools=[_get_assistants_tool(tool) for tool in tools],  # type: ignore
             model=model,
-            file_ids=kwargs.get("file_ids"),
         )
         return cls(assistant_id=assistant.id, client=client, **kwargs)
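Every file_ids removal in this diff follows the same reasoning: the parameter comes from the original (v1) Assistants API and is not accepted by current openai client releases on assistants.create or threads.messages.create, which is presumably why it is dropped here. As a rough sketch of how files are attached instead (the tool_resources and attachments shapes below are assumptions about the current OpenAI SDK, not part of this commit):

# Sketch only, assuming a v2-style openai client. The `tool_resources` and
# `attachments` parameters are assumptions about the OpenAI SDK and are not
# touched by this commit; the ids and names are hypothetical example values.
from openai import OpenAI

client = OpenAI()

assistant = client.beta.assistants.create(
    name="file-aware-assistant",
    instructions="Answer using the attached files.",
    model="gpt-4o-mini",
    tools=[{"type": "file_search"}],
    tool_resources={"file_search": {"vector_store_ids": ["vs_example"]}},
)

client.beta.threads.messages.create(
    thread_id="thread_example",
    role="user",
    content="Summarize the attached report.",
    attachments=[{"file_id": "file_example", "tools": [{"type": "file_search"}]}],
)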
@@ -287,7 +286,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                thread_id: Existing thread to use.
                run_id: Existing run to use. Should only be supplied when providing
                    the tool output for a required action after an initial invocation.
-               file_ids: File ids to include in new run. Used for retrieval.
                message_metadata: Metadata to associate with new message.
                thread_metadata: Metadata to associate with new thread. Only relevant
                    when new thread being created.
@@ -327,7 +325,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                         {
                             "role": "user",
                             "content": input["content"],
-                            "file_ids": input.get("file_ids", []),
                             "metadata": input.get("message_metadata"),
                         }
                     ],
@@ -340,7 +337,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                     input["thread_id"],
                     content=input["content"],
                     role="user",
-                    file_ids=input.get("file_ids", []),
                     metadata=input.get("message_metadata"),
                 )
                 run = self._create_run(input)
@@ -394,7 +390,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             instructions=instructions,
             tools=openai_tools,  # type: ignore
             model=model,
-            file_ids=kwargs.get("file_ids"),
         )
         return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
@@ -409,7 +404,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                thread_id: Existing thread to use.
                run_id: Existing run to use. Should only be supplied when providing
                    the tool output for a required action after an initial invocation.
-               file_ids: File ids to include in new run. Used for retrieval.
                message_metadata: Metadata to associate with a new message.
                thread_metadata: Metadata to associate with new thread. Only relevant
                    when a new thread is created.
@@ -439,7 +433,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
         try:
             # Being run within AgentExecutor and there are tool outputs to submit.
             if self.as_agent and input.get("intermediate_steps"):
-                tool_outputs = self._parse_intermediate_steps(
+                tool_outputs = await self._aparse_intermediate_steps(
                     input["intermediate_steps"]
                 )
                 run = await self.async_client.beta.threads.runs.submit_tool_outputs(
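The one-line change above is the behavioural fix in ainvoke: the async path previously called the synchronous _parse_intermediate_steps helper, which polls the run through the synchronous client and blocks the event loop; it now awaits the async counterpart _aparse_intermediate_steps. A minimal sketch of the corrected call pattern (the wrapper name _handle_intermediate_steps is hypothetical; only the awaited calls mirror this class):

# Simplified sketch of the fixed async branch; not the complete ainvoke body.
# `_handle_intermediate_steps` is a hypothetical wrapper used for illustration.
async def _handle_intermediate_steps(self, input: dict):
    # Await the async parser so run polling goes through async_client
    # instead of the blocking synchronous helper.
    tool_outputs = await self._aparse_intermediate_steps(
        input["intermediate_steps"]
    )
    return await self.async_client.beta.threads.runs.submit_tool_outputs(
        **tool_outputs
    )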
@@ -452,7 +446,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                         {
                             "role": "user",
                             "content": input["content"],
-                            "file_ids": input.get("file_ids", []),
                             "metadata": input.get("message_metadata"),
                         }
                     ],
@@ -465,7 +458,6 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
                     input["thread_id"],
                     content=input["content"],
                     role="user",
-                    file_ids=input.get("file_ids", []),
                     metadata=input.get("message_metadata"),
                 )
                 run = await self._acreate_run(input)
@@ -493,6 +485,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
     ) -> dict:
         last_action, last_output = intermediate_steps[-1]
         run = self._wait_for_run(last_action.run_id, last_action.thread_id)
-        required_tool_call_ids = {
-            tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
-        }
+        required_tool_call_ids = set()
+        if run.required_action:
+            required_tool_call_ids = {
+                tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
+            }
@@ -621,6 +615,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
     ) -> dict:
         last_action, last_output = intermediate_steps[-1]
         run = await self._wait_for_run(last_action.run_id, last_action.thread_id)
-        required_tool_call_ids = {
-            tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
-        }
+        required_tool_call_ids = set()
+        if run.required_action:
+            required_tool_call_ids = {
+                tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
+            }
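The final two hunks apply the same guard to the sync and async intermediate-step parsers: run.required_action is None whenever the polled run is not actually waiting on tool output (for example, when it has already completed or expired), so the old unconditional set comprehension raised AttributeError. With the fix, required_tool_call_ids simply stays empty in that case. A standalone sketch of the guard, using a hypothetical stand-in for the OpenAI Run object:

# Standalone sketch of the guard added in this commit. `FakeRun` is a
# hypothetical stand-in for the OpenAI Run object; only the attributes the
# guard reads are modeled.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToolCall:
    id: str


@dataclass
class SubmitToolOutputs:
    tool_calls: List[ToolCall] = field(default_factory=list)


@dataclass
class RequiredAction:
    submit_tool_outputs: SubmitToolOutputs = field(default_factory=SubmitToolOutputs)


@dataclass
class FakeRun:
    required_action: Optional[RequiredAction] = None


def collect_required_tool_call_ids(run: FakeRun) -> set:
    # Default to an empty set and only read tool_calls when the run is
    # actually waiting on tool output, mirroring the fixed code above.
    required_tool_call_ids = set()
    if run.required_action:
        required_tool_call_ids = {
            tc.id for tc in run.required_action.submit_tool_outputs.tool_calls
        }
    return required_tool_call_ids


# A run without required_action (e.g. already completed) no longer raises:
assert collect_required_tool_call_ids(FakeRun()) == set()
assert collect_required_tool_call_ids(
    FakeRun(RequiredAction(SubmitToolOutputs([ToolCall("call_1")])))
) == {"call_1"}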