mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 00:29:57 +00:00
community[minor]: Improve mlflow callback (#15691)
- **Description:** Allow passing `run_id` to `MlflowCallbackHandler` to resume an existing run instead of creating a new one. Support recording retriever-related metrics. Refactor the code to fix some bugs. --------- Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
This commit is contained in:
parent
92e6a641fd
commit
90f5a1c40e
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
@ -5,10 +6,11 @@ import tempfile
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
@ -21,6 +23,8 @@ from langchain_community.callbacks.utils import (
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_mlflow() -> Any:
|
||||
"""Import the mlflow python package and raise an error if it is not installed."""
|
||||
@ -34,6 +38,47 @@ def import_mlflow() -> Any:
|
||||
return mlflow
|
||||
|
||||
|
||||
def mlflow_callback_metrics() -> List[str]:
    """Return the names of the counters kept by the MLflow callback handler.

    A new list is built on every call, so callers may mutate the result.
    The order is: general counters first, then per-component counters.
    """
    general_counters = ["step", "starts", "ends", "errors", "text_ctr"]
    component_counters = [
        "chain_starts",
        "chain_ends",
        "llm_starts",
        "llm_ends",
        "llm_streams",
        "tool_starts",
        "tool_ends",
        "agent_ends",
        "retriever_starts",
        "retriever_ends",
    ]
    return general_counters + component_counters
|
||||
|
||||
|
||||
def get_text_complexity_metrics() -> List[str]:
    """Return the names of the text-complexity metrics reported per text.

    Each name is looked up as an attribute on the ``textstat`` module by
    ``analyze_text``. ``text_standard`` is intentionally left out of the list.
    """
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        # "text_standard" is deliberately excluded.
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    return list(metric_names)
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: str,
|
||||
nlp: Any = None,
|
||||
@ -52,22 +97,7 @@ def analyze_text(
|
||||
textstat = import_textstat()
|
||||
spacy = import_spacy()
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(text),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
# "text_standard": textstat.text_standard(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
|
||||
}
|
||||
resp.update({"text_complexity_metrics": text_complexity_metrics})
|
||||
resp.update(text_complexity_metrics)
|
||||
@ -140,58 +170,64 @@ class MlflowLogger:
|
||||
)
|
||||
self.mlflow.set_tracking_uri(tracking_uri)
|
||||
|
||||
# User can set other env variables described here
|
||||
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
|
||||
|
||||
experiment_name = get_from_dict_or_env(
|
||||
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
|
||||
)
|
||||
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
|
||||
if self.mlf_exp is not None:
|
||||
self.mlf_expid = self.mlf_exp.experiment_id
|
||||
if run_id := kwargs.get("run_id"):
|
||||
self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
|
||||
else:
|
||||
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
|
||||
# User can set other env variables described here
|
||||
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
|
||||
|
||||
self.start_run(kwargs["run_name"], kwargs["run_tags"])
|
||||
experiment_name = get_from_dict_or_env(
|
||||
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
|
||||
)
|
||||
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
|
||||
if self.mlf_exp is not None:
|
||||
self.mlf_expid = self.mlf_exp.experiment_id
|
||||
else:
|
||||
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
|
||||
|
||||
def start_run(self, name: str, tags: Dict[str, str]) -> None:
|
||||
"""To start a new run, auto generates the random suffix for name"""
|
||||
if name.endswith("-%"):
|
||||
rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
|
||||
name = name.replace("%", rname)
|
||||
self.run = self.mlflow.MlflowClient().create_run(
|
||||
self.mlf_expid, run_name=name, tags=tags
|
||||
self.start_run(
|
||||
kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
|
||||
)
|
||||
self.dir = kwargs.get("artifacts_dir", "")
|
||||
|
||||
def start_run(
    self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
) -> None:
    """Bind this logger to an MLflow run.

    When ``run_id`` is given, that run is reused and ``name`` / ``tags`` are
    not consulted. Otherwise a new run is created in ``self.mlf_expid``; a
    ``name`` ending in ``"-%"`` has the ``%`` replaced by a random suffix of
    seven uppercase letters and digits.
    """
    if run_id is not None:
        # Resume path: nothing to create, just remember the id.
        self.run_id = run_id
        return

    if name.endswith("-%"):
        alphabet = string.ascii_uppercase + string.digits
        suffix = "".join(random.choices(alphabet, k=7))
        name = f"{name[:-1]}{suffix}"

    client = self.mlflow.MlflowClient()
    created = client.create_run(self.mlf_expid, run_name=name, tags=tags)
    self.run_id = created.info.run_id
|
||||
|
||||
def finish_run(self) -> None:
|
||||
"""To finish the run."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.end_run()
|
||||
self.mlflow.end_run()
|
||||
|
||||
def metric(self, key: str, value: float) -> None:
|
||||
"""To log metric to mlflow server."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_metric(key, value)
|
||||
self.mlflow.log_metric(key, value, run_id=self.run_id)
|
||||
|
||||
def metrics(
    self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
) -> None:
    """Log every metric in ``data`` to the run identified by ``self.run_id``.

    Args:
        data: Mapping of metric name to numeric value.
        step: Step at which the metrics are recorded. Defaults to 0.
    """
    # Forward ``step``: callers pass the handler's running step counter, and
    # dropping it would record every batch of metrics at the same step.
    self.mlflow.log_metrics(data, step=step, run_id=self.run_id)
|
||||
|
||||
def jsonf(self, data: Dict[str, Any], filename: str) -> None:
    """Log ``data`` as the JSON artifact ``<filename>.json``.

    The artifact is written under ``self.dir`` inside the run identified by
    ``self.run_id``.

    Args:
        data: Dictionary to store.
        filename: Artifact file name without the ``.json`` extension.
    """
    import posixpath

    # ``self.dir`` is None when the handler forwards its ``artifacts_dir``
    # default unchanged; fall back to the run root instead of raising.
    # posixpath keeps the run-relative artifact path in POSIX form on every OS.
    artifact_file = posixpath.join(self.dir or "", f"{filename}.json")
    self.mlflow.log_dict(data, artifact_file, run_id=self.run_id)
|
||||
|
||||
def table(self, name: str, dataframe) -> None: # type: ignore
|
||||
"""To log the input pandas dataframe as a html table"""
|
||||
@ -199,30 +235,22 @@ class MlflowLogger:
|
||||
|
||||
def html(self, html: str, filename: str) -> None:
    """Log the ``html`` string as the artifact ``<filename>.html``.

    The artifact is written under ``self.dir`` inside the run identified by
    ``self.run_id``.

    Args:
        html: HTML markup to store.
        filename: Artifact file name without the ``.html`` extension.
    """
    import posixpath

    # ``self.dir`` is None when the handler forwards its ``artifacts_dir``
    # default unchanged; fall back to the run root instead of raising.
    # posixpath keeps the run-relative artifact path in POSIX form on every OS.
    artifact_file = posixpath.join(self.dir or "", f"{filename}.html")
    self.mlflow.log_text(html, artifact_file, run_id=self.run_id)
|
||||
|
||||
def text(self, text: str, filename: str) -> None:
    """Log the ``text`` string as the artifact ``<filename>.txt``.

    The artifact is written under ``self.dir`` inside the run identified by
    ``self.run_id``.

    Args:
        text: Text content to store.
        filename: Artifact file name without the ``.txt`` extension.
    """
    import posixpath

    # ``self.dir`` is None when the handler forwards its ``artifacts_dir``
    # default unchanged; fall back to the run root instead of raising.
    # posixpath keeps the run-relative artifact path in POSIX form on every OS.
    artifact_file = posixpath.join(self.dir or "", f"{filename}.txt")
    self.mlflow.log_text(text, artifact_file, run_id=self.run_id)
|
||||
|
||||
def artifact(self, path: str) -> None:
|
||||
"""To upload the file from given path as artifact."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_artifact(path)
|
||||
self.mlflow.log_artifact(path, run_id=self.run_id)
|
||||
|
||||
def langchain_artifact(self, chain: Any) -> None:
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.langchain.log_model(chain, "langchain-model")
|
||||
self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)
|
||||
|
||||
|
||||
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
@ -246,6 +274,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
experiment: Optional[str] = "langchain",
|
||||
tags: Optional[Dict] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
run_id: Optional[str] = None,
|
||||
artifacts_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
import_pandas()
|
||||
@ -258,6 +288,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.experiment = experiment
|
||||
self.tags = tags or {}
|
||||
self.tracking_uri = tracking_uri
|
||||
self.run_id = run_id
|
||||
self.artifacts_dir = artifacts_dir
|
||||
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
@ -266,26 +298,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
experiment_name=self.experiment,
|
||||
run_name=self.name,
|
||||
run_tags=self.tags,
|
||||
run_id=self.run_id,
|
||||
artifacts_dir=self.artifacts_dir,
|
||||
)
|
||||
|
||||
self.action_records: list = []
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
try:
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
except OSError:
|
||||
logger.warning(
|
||||
"Run `python -m spacy download en_core_web_sm` "
|
||||
"to download en_core_web_sm model for text visualization."
|
||||
)
|
||||
self.nlp = None
|
||||
|
||||
self.metrics = {
|
||||
"step": 0,
|
||||
"starts": 0,
|
||||
"ends": 0,
|
||||
"errors": 0,
|
||||
"text_ctr": 0,
|
||||
"chain_starts": 0,
|
||||
"chain_ends": 0,
|
||||
"llm_starts": 0,
|
||||
"llm_ends": 0,
|
||||
"llm_streams": 0,
|
||||
"tool_starts": 0,
|
||||
"tool_ends": 0,
|
||||
"agent_ends": 0,
|
||||
}
|
||||
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
|
||||
|
||||
self.records: Dict[str, Any] = {
|
||||
"on_llm_start_records": [],
|
||||
@ -298,6 +325,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"on_text_records": [],
|
||||
"on_agent_finish_records": [],
|
||||
"on_agent_action_records": [],
|
||||
"on_retriever_start_records": [],
|
||||
"on_retriever_end_records": [],
|
||||
"action_records": [],
|
||||
}
|
||||
|
||||
@ -383,10 +412,14 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.records["on_llm_end_records"].append(generation_resp)
|
||||
self.records["action_records"].append(generation_resp)
|
||||
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
|
||||
dependency_tree = generation_resp["dependency_tree"]
|
||||
entities = generation_resp["entities"]
|
||||
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
|
||||
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
||||
if "dependency_tree" in generation_resp:
|
||||
dependency_tree = generation_resp["dependency_tree"]
|
||||
self.mlflg.html(
|
||||
dependency_tree, "dep-" + hash_string(generation.text)
|
||||
)
|
||||
if "entities" in generation_resp:
|
||||
entities = generation_resp["entities"]
|
||||
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
@ -410,14 +443,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
|
||||
if isinstance(inputs, dict):
|
||||
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
|
||||
elif isinstance(inputs, list):
|
||||
chain_input = ",".join([str(input) for input in inputs])
|
||||
else:
|
||||
chain_input = str(inputs)
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["inputs"] = chain_input
|
||||
self.records["on_chain_start_records"].append(input_resp)
|
||||
self.records["action_records"].append(input_resp)
|
||||
self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_ends"] += 1
|
||||
@ -426,7 +466,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
chain_ends = self.metrics["chain_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
|
||||
if isinstance(outputs, dict):
|
||||
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
|
||||
elif isinstance(outputs, list):
|
||||
chain_output = ",".join(map(str, outputs))
|
||||
else:
|
||||
chain_output = str(outputs)
|
||||
resp.update({"action": "on_chain_end", "outputs": chain_output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
@ -487,7 +532,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
Run when text is received.
|
||||
"""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["text_ctr"] += 1
|
||||
@ -549,6 +594,69 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
|
||||
|
||||
def on_retriever_start(
    self,
    serialized: Dict[str, Any],
    query: str,
    **kwargs: Any,
) -> Any:
    """Run when Retriever starts running.

    Bumps the step / start counters, logs the counters to MLflow, and stores
    a record holding the query, the flattened ``serialized`` payload, and a
    snapshot of the counters.
    """
    for counter in ("step", "retriever_starts", "starts"):
        self.metrics[counter] += 1
    start_index = self.metrics["retriever_starts"]

    # Later entries win on key clashes, same as successive dict.update calls.
    record: Dict[str, Any] = {
        "action": "on_retriever_start",
        "query": query,
        **flatten_dict(serialized),
        **self.metrics,
    }

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for bucket in ("on_retriever_start_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"retriever_start_{start_index}")
|
||||
|
||||
def on_retriever_end(
    self,
    documents: Sequence[Document],
    **kwargs: Any,
) -> Any:
    """Run when Retriever ends running.

    Bumps the step / end counters, logs the counters to MLflow, and stores a
    record holding the retrieved documents with stringified metadata.
    """
    for counter in ("step", "retriever_ends", "ends"):
        self.metrics[counter] += 1
    end_index = self.metrics["retriever_ends"]

    serialized_docs = []
    for doc in documents:
        metadata = {}
        for key, value in doc.metadata.items():
            # Lists are joined element-wise; everything else is str()-ed.
            if isinstance(value, list):
                metadata[key] = ",".join(str(item) for item in value)
            else:
                metadata[key] = str(value)
        serialized_docs.append(
            {"page_content": doc.page_content, "metadata": metadata}
        )

    record: Dict[str, Any] = {
        "action": "on_retriever_end",
        "documents": serialized_docs,
        **self.metrics,
    }

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for bucket in ("on_retriever_end_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"retriever_end_{end_index}")
|
||||
|
||||
def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any:
    """Run when Retriever errors.

    Only the in-memory ``step`` and ``errors`` counters are advanced.
    """
    for counter in ("step", "errors"):
        self.metrics[counter] += 1
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
@ -570,39 +678,27 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "prompt_step"}, axis=1)
|
||||
)
|
||||
complexity_metrics_columns = []
|
||||
visualizations_columns = []
|
||||
complexity_metrics_columns = get_text_complexity_metrics()
|
||||
visualizations_columns = (
|
||||
["dependency_tree", "entities"] if self.nlp is not None else []
|
||||
)
|
||||
|
||||
complexity_metrics_columns = [
|
||||
"flesch_reading_ease",
|
||||
"flesch_kincaid_grade",
|
||||
"smog_index",
|
||||
"coleman_liau_index",
|
||||
"automated_readability_index",
|
||||
"dale_chall_readability_score",
|
||||
"difficult_words",
|
||||
"linsear_write_formula",
|
||||
"gunning_fog",
|
||||
# "text_standard",
|
||||
"fernandez_huerta",
|
||||
"szigriszt_pazos",
|
||||
"gutierrez_polini",
|
||||
"crawford",
|
||||
"gulpease_index",
|
||||
"osman",
|
||||
token_usage_columns = [
|
||||
"token_usage_total_tokens",
|
||||
"token_usage_prompt_tokens",
|
||||
"token_usage_completion_tokens",
|
||||
]
|
||||
token_usage_columns = [
|
||||
x for x in token_usage_columns if x in on_llm_end_records_df.columns
|
||||
]
|
||||
|
||||
visualizations_columns = ["dependency_tree", "entities"]
|
||||
|
||||
llm_outputs_df = (
|
||||
on_llm_end_records_df[
|
||||
[
|
||||
"step",
|
||||
"text",
|
||||
"token_usage_total_tokens",
|
||||
"token_usage_prompt_tokens",
|
||||
"token_usage_completion_tokens",
|
||||
]
|
||||
+ token_usage_columns
|
||||
+ complexity_metrics_columns
|
||||
+ visualizations_columns
|
||||
]
|
||||
@ -620,14 +716,18 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
)
|
||||
return session_analysis_df
|
||||
|
||||
def _contain_llm_records(self) -> bool:
    """Return True when at least one ``on_llm_start`` record was collected."""
    return bool(self.records["on_llm_start_records"])
|
||||
|
||||
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
|
||||
pd = import_pandas()
|
||||
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
|
||||
session_analysis_df = self._create_session_analysis_df()
|
||||
chat_html = session_analysis_df.pop("chat_html")
|
||||
chat_html = chat_html.replace("\n", "", regex=True)
|
||||
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
|
||||
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
|
||||
if self._contain_llm_records():
|
||||
session_analysis_df = self._create_session_analysis_df()
|
||||
chat_html = session_analysis_df.pop("chat_html")
|
||||
chat_html = chat_html.replace("\n", "", regex=True)
|
||||
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
|
||||
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
|
||||
|
||||
if langchain_asset:
|
||||
# To avoid circular import error
|
||||
|
Loading…
Reference in New Issue
Block a user