mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
Fix ClearML callback (#11472)
Handle different field names in dicts/dataframes, fixing the ClearML callback. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
7ae8b7f065
commit
4ba2c8ba75
@ -1,7 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.utils import (
|
from langchain.callbacks.utils import (
|
||||||
@ -15,6 +17,9 @@ from langchain.callbacks.utils import (
|
|||||||
)
|
)
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
def import_clearml() -> Any:
|
def import_clearml() -> Any:
|
||||||
"""Import the clearml python package and raise an error if it is not installed."""
|
"""Import the clearml python package and raise an error if it is not installed."""
|
||||||
@ -173,7 +178,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
resp.update(flatten_dict(serialized))
|
resp.update(flatten_dict(serialized))
|
||||||
resp.update(self.get_custom_callback_meta())
|
resp.update(self.get_custom_callback_meta())
|
||||||
|
|
||||||
chain_input = inputs["input"]
|
chain_input = inputs.get("input", inputs.get("human_input"))
|
||||||
|
|
||||||
if isinstance(chain_input, str):
|
if isinstance(chain_input, str):
|
||||||
input_resp = deepcopy(resp)
|
input_resp = deepcopy(resp)
|
||||||
@ -200,7 +205,12 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
resp = self._init_resp()
|
resp = self._init_resp()
|
||||||
resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
|
resp.update(
|
||||||
|
{
|
||||||
|
"action": "on_chain_end",
|
||||||
|
"outputs": outputs.get("output", outputs.get("text")),
|
||||||
|
}
|
||||||
|
)
|
||||||
resp.update(self.get_custom_callback_meta())
|
resp.update(self.get_custom_callback_meta())
|
||||||
|
|
||||||
self.on_chain_end_records.append(resp)
|
self.on_chain_end_records.append(resp)
|
||||||
@ -372,16 +382,31 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_llm_df(
|
||||||
|
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
base_df_fields = [field for field in base_df_fields if field in base_df]
|
||||||
|
rename_map = {
|
||||||
|
map_entry_k: map_entry_v
|
||||||
|
for map_entry_k, map_entry_v in rename_map.items()
|
||||||
|
if map_entry_k in base_df_fields
|
||||||
|
}
|
||||||
|
llm_df = base_df[base_df_fields].dropna(axis=1)
|
||||||
|
if rename_map:
|
||||||
|
llm_df = llm_df.rename(rename_map, axis=1)
|
||||||
|
return llm_df
|
||||||
|
|
||||||
def _create_session_analysis_df(self) -> Any:
|
def _create_session_analysis_df(self) -> Any:
|
||||||
"""Create a dataframe with all the information from the session."""
|
"""Create a dataframe with all the information from the session."""
|
||||||
pd = import_pandas()
|
pd = import_pandas()
|
||||||
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
|
|
||||||
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
||||||
|
|
||||||
llm_input_prompts_df = (
|
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
|
||||||
on_llm_start_records_df[["step", "prompts", "name"]]
|
base_df=on_llm_end_records_df,
|
||||||
.dropna(axis=1)
|
base_df_fields=["step", "prompts"]
|
||||||
.rename({"step": "prompt_step"}, axis=1)
|
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
|
||||||
|
rename_map={"step": "prompt_step"},
|
||||||
)
|
)
|
||||||
complexity_metrics_columns = []
|
complexity_metrics_columns = []
|
||||||
visualizations_columns: List = []
|
visualizations_columns: List = []
|
||||||
@ -406,8 +431,8 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
"osman",
|
"osman",
|
||||||
]
|
]
|
||||||
|
|
||||||
llm_outputs_df = (
|
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
|
||||||
on_llm_end_records_df[
|
on_llm_end_records_df,
|
||||||
[
|
[
|
||||||
"step",
|
"step",
|
||||||
"text",
|
"text",
|
||||||
@ -416,20 +441,10 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
"token_usage_completion_tokens",
|
"token_usage_completion_tokens",
|
||||||
]
|
]
|
||||||
+ complexity_metrics_columns
|
+ complexity_metrics_columns
|
||||||
+ visualizations_columns
|
+ visualizations_columns,
|
||||||
]
|
{"step": "output_step", "text": "output"},
|
||||||
.dropna(axis=1)
|
|
||||||
.rename({"step": "output_step", "text": "output"}, axis=1)
|
|
||||||
)
|
)
|
||||||
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
|
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
|
||||||
# session_analysis_df["chat_html"] = session_analysis_df[
|
|
||||||
# ["prompts", "output"]
|
|
||||||
# ].apply(
|
|
||||||
# lambda row: construct_html_from_prompt_and_generation(
|
|
||||||
# row["prompts"], row["output"]
|
|
||||||
# ),
|
|
||||||
# axis=1,
|
|
||||||
# )
|
|
||||||
return session_analysis_df
|
return session_analysis_df
|
||||||
|
|
||||||
def flush_tracker(
|
def flush_tracker(
|
||||||
|
Loading…
Reference in New Issue
Block a user