community[patch]: callbacks mypy fixes (#17058)

Related to #17048
This commit is contained in:
Bagatur
2024-02-05 12:37:27 -08:00
committed by GitHub
parent 75b6fa1134
commit af5ae24af2
12 changed files with 49 additions and 59 deletions

View File

@@ -1,6 +1,6 @@
import os
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
@@ -269,8 +269,8 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
str(kwargs["run_id"])
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
List, self.prompts.get(str(kwargs["run_id"]), [])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
@@ -283,10 +283,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"response": output["text"].strip(),
},
}
for prompt, output in zip(
prompts, # type: ignore
chain_output_val,
)
for prompt, output in zip(prompts, chain_output_val)
]
)
else:
@@ -295,7 +292,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
records=[
{
"fields": {
"prompt": " ".join(prompts), # type: ignore
"prompt": " ".join(prompts),
"response": chain_output_val.strip(),
},
}