mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
@@ -269,8 +269,8 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
|
||||
):
|
||||
return
|
||||
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
|
||||
str(kwargs["run_id"])
|
||||
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
|
||||
List, self.prompts.get(str(kwargs["run_id"]), [])
|
||||
)
|
||||
for chain_output_key, chain_output_val in outputs.items():
|
||||
if isinstance(chain_output_val, list):
|
||||
@@ -283,10 +283,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
"response": output["text"].strip(),
|
||||
},
|
||||
}
|
||||
for prompt, output in zip(
|
||||
prompts, # type: ignore
|
||||
chain_output_val,
|
||||
)
|
||||
for prompt, output in zip(prompts, chain_output_val)
|
||||
]
|
||||
)
|
||||
else:
|
||||
@@ -295,7 +292,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
records=[
|
||||
{
|
||||
"fields": {
|
||||
"prompt": " ".join(prompts), # type: ignore
|
||||
"prompt": " ".join(prompts),
|
||||
"response": chain_output_val.strip(),
|
||||
},
|
||||
}
|
||||
|
Reference in New Issue
Block a user