Fix ClearML callback (#11472)

Handle different field names in dicts/dataframes, fixing the ClearML callback.

Co-authored-by: Bagatur <baskaryan@gmail.com>

parent 7ae8b7f065
commit 4ba2c8ba75
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import tempfile
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
@@ -15,6 +17,9 @@ from langchain.callbacks.utils import (
 )
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 def import_clearml() -> Any:
     """Import the clearml python package and raise an error if it is not installed."""
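As a side note, a minimal sketch of the import pattern this hunk adopts (the helper names here are illustrative, not the callback's real module layout): pandas is imported only for type checking, while the runtime import stays behind a helper, so annotating methods with pd.DataFrame does not turn pandas into a hard import-time dependency.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import pandas as pd  # seen by type checkers only, never executed at runtime


def import_pandas() -> Any:
    # runtime import with a clearer error when the optional dependency is missing
    try:
        import pandas

        return pandas
    except ImportError:
        raise ImportError("This sketch assumes pandas is installed: pip install pandas")


def records_to_frame(records: list) -> pd.DataFrame:
    # with `from __future__ import annotations` the annotation above stays lazy,
    # so pandas is only touched when the function is actually called
    pd = import_pandas()
    return pd.DataFrame(records)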
@@ -173,7 +178,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         resp.update(flatten_dict(serialized))
         resp.update(self.get_custom_callback_meta())
 
-        chain_input = inputs["input"]
+        chain_input = inputs.get("input", inputs.get("human_input"))
 
         if isinstance(chain_input, str):
             input_resp = deepcopy(resp)
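A rough illustration of why the lookup above was loosened, using hypothetical chain inputs: chains whose prompt variable is named human_input (as in conversation/memory-style prompts) previously raised KeyError in on_chain_start.

inputs_a = {"input": "What is ClearML?"}
inputs_b = {"human_input": "What is ClearML?", "history": ""}

for inputs in (inputs_a, inputs_b):
    # old behaviour: inputs["input"] -> KeyError for inputs_b
    chain_input = inputs.get("input", inputs.get("human_input"))
    print(chain_input)  # "What is ClearML?" in both cases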
@@ -200,7 +205,12 @@
         self.ends += 1
 
         resp = self._init_resp()
-        resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
+        resp.update(
+            {
+                "action": "on_chain_end",
+                "outputs": outputs.get("output", outputs.get("text")),
+            }
+        )
         resp.update(self.get_custom_callback_meta())
 
         self.on_chain_end_records.append(resp)
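The on_chain_end side gets the same treatment; a small sketch with made-up outputs dicts, since some chains return their result under text rather than output:

outputs_a = {"output": "ClearML is an MLOps platform."}
outputs_b = {"text": "ClearML is an MLOps platform."}

for outputs in (outputs_a, outputs_b):
    resp = {"action": "on_chain_end"}
    # falls back to "text"; yields None (rather than raising) if neither key exists
    resp.update({"outputs": outputs.get("output", outputs.get("text"))})
    print(resp["outputs"])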
@@ -372,16 +382,31 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
 
         return resp
 
+    @staticmethod
+    def _build_llm_df(
+        base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
+    ) -> pd.DataFrame:
+        base_df_fields = [field for field in base_df_fields if field in base_df]
+        rename_map = {
+            map_entry_k: map_entry_v
+            for map_entry_k, map_entry_v in rename_map.items()
+            if map_entry_k in base_df_fields
+        }
+        llm_df = base_df[base_df_fields].dropna(axis=1)
+        if rename_map:
+            llm_df = llm_df.rename(rename_map, axis=1)
+        return llm_df
+
     def _create_session_analysis_df(self) -> Any:
         """Create a dataframe with all the information from the session."""
         pd = import_pandas()
-        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
         on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
 
-        llm_input_prompts_df = (
-            on_llm_start_records_df[["step", "prompts", "name"]]
-            .dropna(axis=1)
-            .rename({"step": "prompt_step"}, axis=1)
-        )
+        llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
+            base_df=on_llm_end_records_df,
+            base_df_fields=["step", "prompts"]
+            + (["name"] if "name" in on_llm_end_records_df else ["id"]),
+            rename_map={"step": "prompt_step"},
+        )
         complexity_metrics_columns = []
         visualizations_columns: List = []
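A standalone sketch of what the new _build_llm_df helper does, using a tiny hand-made record (the data is fabricated for illustration): requested fields that are missing from the frame are skipped, and the rename map is trimmed to the columns that survive, so differently shaped runs no longer raise KeyError.

from typing import Mapping, Sequence

import pandas as pd


def build_llm_df(
    base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
    # keep only the requested columns that actually exist in this run's records
    base_df_fields = [field for field in base_df_fields if field in base_df]
    rename_map = {k: v for k, v in rename_map.items() if k in base_df_fields}
    llm_df = base_df[base_df_fields].dropna(axis=1)
    if rename_map:
        llm_df = llm_df.rename(rename_map, axis=1)
    return llm_df


records = [{"step": 3, "prompts": "Tell me a joke", "name": "OpenAI"}]
df = build_llm_df(
    pd.DataFrame(records),
    ["step", "prompts", "name", "token_usage_total_tokens"],  # last field absent
    {"step": "prompt_step"},
)
print(list(df.columns))  # ['prompt_step', 'prompts', 'name'] - no KeyError raised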
@@ -406,30 +431,20 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
                 "osman",
             ]
 
-        llm_outputs_df = (
-            on_llm_end_records_df[
-                [
-                    "step",
-                    "text",
-                    "token_usage_total_tokens",
-                    "token_usage_prompt_tokens",
-                    "token_usage_completion_tokens",
-                ]
-                + complexity_metrics_columns
-                + visualizations_columns
-            ]
-            .dropna(axis=1)
-            .rename({"step": "output_step", "text": "output"}, axis=1)
-        )
+        llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
+            on_llm_end_records_df,
+            [
+                "step",
+                "text",
+                "token_usage_total_tokens",
+                "token_usage_prompt_tokens",
+                "token_usage_completion_tokens",
+            ]
+            + complexity_metrics_columns
+            + visualizations_columns,
+            {"step": "output_step", "text": "output"},
+        )
         session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
-        # session_analysis_df["chat_html"] = session_analysis_df[
-        #     ["prompts", "output"]
-        # ].apply(
-        #     lambda row: construct_html_from_prompt_and_generation(
-        #         row["prompts"], row["output"]
-        #     ),
-        #     axis=1,
-        # )
         return session_analysis_df
 
     def flush_tracker(
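Finally, a rough end-to-end sketch of how the prompt-side and output-side frames end up combined into the session analysis frame (records are fabricated for illustration; the real handler builds both frames via _build_llm_df):

import pandas as pd

start_records = [{"step": 1, "prompts": "Tell me a joke", "name": "OpenAI"}]
end_records = [{"step": 2, "text": "Why did the chicken cross the road?"}]

llm_input_prompts_df = (
    pd.DataFrame(start_records)[["step", "prompts", "name"]]
    .dropna(axis=1)
    .rename({"step": "prompt_step"}, axis=1)
)
llm_outputs_df = (
    pd.DataFrame(end_records)[["step", "text"]]
    .dropna(axis=1)
    .rename({"step": "output_step", "text": "output"}, axis=1)
)

# column-wise concat, aligned on the row index
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
print(list(session_analysis_df.columns))
# ['prompt_step', 'prompts', 'name', 'output_step', 'output']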