community[minor]: Improve mlflow callback (#15691)

- **Description:** Allow passing run_id to MLflowCallbackHandler to
resume a run instead of creating a new run. Support recording retriever
relevant metrics. Refactor the code to fix some bugs.
---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
This commit is contained in:
Serena Ruan 2024-01-23 18:16:51 -08:00 committed by GitHub
parent 92e6a641fd
commit 90f5a1c40e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,4 @@
import logging
import os import os
import random import random
import string import string
@ -5,10 +6,11 @@ import tempfile
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
@ -21,6 +23,8 @@ from langchain_community.callbacks.utils import (
import_textstat, import_textstat,
) )
logger = logging.getLogger(__name__)
def import_mlflow() -> Any: def import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed.""" """Import the mlflow python package and raise an error if it is not installed."""
@ -34,6 +38,47 @@ def import_mlflow() -> Any:
return mlflow return mlflow
def mlflow_callback_metrics() -> List[str]:
return [
"step",
"starts",
"ends",
"errors",
"text_ctr",
"chain_starts",
"chain_ends",
"llm_starts",
"llm_ends",
"llm_streams",
"tool_starts",
"tool_ends",
"agent_ends",
"retriever_starts",
"retriever_ends",
]
def get_text_complexity_metrics() -> List[str]:
return [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
# "text_standard"
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
def analyze_text( def analyze_text(
text: str, text: str,
nlp: Any = None, nlp: Any = None,
@ -52,22 +97,7 @@ def analyze_text(
textstat = import_textstat() textstat = import_textstat()
spacy = import_spacy() spacy = import_spacy()
text_complexity_metrics = { text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text), key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
# "text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
} }
resp.update({"text_complexity_metrics": text_complexity_metrics}) resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics) resp.update(text_complexity_metrics)
@ -140,58 +170,64 @@ class MlflowLogger:
) )
self.mlflow.set_tracking_uri(tracking_uri) self.mlflow.set_tracking_uri(tracking_uri)
# User can set other env variables described here if run_id := kwargs.get("run_id"):
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
)
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id
else: else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name) # User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
self.start_run(kwargs["run_name"], kwargs["run_tags"]) experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
)
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id
else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
def start_run(self, name: str, tags: Dict[str, str]) -> None: self.start_run(
"""To start a new run, auto generates the random suffix for name""" kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
if name.endswith("-%"):
rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
name = name.replace("%", rname)
self.run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=name, tags=tags
) )
self.dir = kwargs.get("artifacts_dir", "")
def start_run(
self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
) -> None:
"""
If run_id is provided, it will reuse the run with the given run_id.
Otherwise, it starts a new run, auto generates the random suffix for name.
"""
if run_id is None:
if name.endswith("-%"):
rname = "".join(
random.choices(string.ascii_uppercase + string.digits, k=7)
)
name = name[:-1] + rname
run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=name, tags=tags
)
run_id = run.info.run_id
self.run_id = run_id
def finish_run(self) -> None: def finish_run(self) -> None:
"""To finish the run.""" """To finish the run."""
with self.mlflow.start_run( self.mlflow.end_run()
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.end_run()
def metric(self, key: str, value: float) -> None: def metric(self, key: str, value: float) -> None:
"""To log metric to mlflow server.""" """To log metric to mlflow server."""
with self.mlflow.start_run( self.mlflow.log_metric(key, value, run_id=self.run_id)
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metric(key, value)
def metrics( def metrics(
self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0 self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
) -> None: ) -> None:
"""To log all metrics in the input dict.""" """To log all metrics in the input dict."""
with self.mlflow.start_run( self.mlflow.log_metrics(data, run_id=self.run_id)
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metrics(data)
def jsonf(self, data: Dict[str, Any], filename: str) -> None: def jsonf(self, data: Dict[str, Any], filename: str) -> None:
"""To log the input data as json file artifact.""" """To log the input data as json file artifact."""
with self.mlflow.start_run( self.mlflow.log_dict(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
): )
self.mlflow.log_dict(data, f"{filename}.json")
def table(self, name: str, dataframe) -> None: # type: ignore def table(self, name: str, dataframe) -> None: # type: ignore
"""To log the input pandas dataframe as a html table""" """To log the input pandas dataframe as a html table"""
@ -199,30 +235,22 @@ class MlflowLogger:
def html(self, html: str, filename: str) -> None: def html(self, html: str, filename: str) -> None:
"""To log the input html string as html file artifact.""" """To log the input html string as html file artifact."""
with self.mlflow.start_run( self.mlflow.log_text(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid html, os.path.join(self.dir, f"{filename}.html"), run_id=self.run_id
): )
self.mlflow.log_text(html, f"{filename}.html")
def text(self, text: str, filename: str) -> None: def text(self, text: str, filename: str) -> None:
"""To log the input text as text file artifact.""" """To log the input text as text file artifact."""
with self.mlflow.start_run( self.mlflow.log_text(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid text, os.path.join(self.dir, f"{filename}.txt"), run_id=self.run_id
): )
self.mlflow.log_text(text, f"{filename}.txt")
def artifact(self, path: str) -> None: def artifact(self, path: str) -> None:
"""To upload the file from given path as artifact.""" """To upload the file from given path as artifact."""
with self.mlflow.start_run( self.mlflow.log_artifact(path, run_id=self.run_id)
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_artifact(path)
def langchain_artifact(self, chain: Any) -> None: def langchain_artifact(self, chain: Any) -> None:
with self.mlflow.start_run( self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.langchain.log_model(chain, "langchain-model")
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
@ -246,6 +274,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
experiment: Optional[str] = "langchain", experiment: Optional[str] = "langchain",
tags: Optional[Dict] = None, tags: Optional[Dict] = None,
tracking_uri: Optional[str] = None, tracking_uri: Optional[str] = None,
run_id: Optional[str] = None,
artifacts_dir: Optional[str] = None,
) -> None: ) -> None:
"""Initialize callback handler.""" """Initialize callback handler."""
import_pandas() import_pandas()
@ -258,6 +288,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.experiment = experiment self.experiment = experiment
self.tags = tags or {} self.tags = tags or {}
self.tracking_uri = tracking_uri self.tracking_uri = tracking_uri
self.run_id = run_id
self.artifacts_dir = artifacts_dir
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
@ -266,26 +298,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
experiment_name=self.experiment, experiment_name=self.experiment,
run_name=self.name, run_name=self.name,
run_tags=self.tags, run_tags=self.tags,
run_id=self.run_id,
artifacts_dir=self.artifacts_dir,
) )
self.action_records: list = [] self.action_records: list = []
self.nlp = spacy.load("en_core_web_sm") try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
self.nlp = None
self.metrics = { self.metrics = {key: 0 for key in mlflow_callback_metrics()}
"step": 0,
"starts": 0,
"ends": 0,
"errors": 0,
"text_ctr": 0,
"chain_starts": 0,
"chain_ends": 0,
"llm_starts": 0,
"llm_ends": 0,
"llm_streams": 0,
"tool_starts": 0,
"tool_ends": 0,
"agent_ends": 0,
}
self.records: Dict[str, Any] = { self.records: Dict[str, Any] = {
"on_llm_start_records": [], "on_llm_start_records": [],
@ -298,6 +325,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"on_text_records": [], "on_text_records": [],
"on_agent_finish_records": [], "on_agent_finish_records": [],
"on_agent_action_records": [], "on_agent_action_records": [],
"on_retriever_start_records": [],
"on_retriever_end_records": [],
"action_records": [], "action_records": [],
} }
@ -383,10 +412,14 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["on_llm_end_records"].append(generation_resp) self.records["on_llm_end_records"].append(generation_resp)
self.records["action_records"].append(generation_resp) self.records["action_records"].append(generation_resp)
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}") self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
dependency_tree = generation_resp["dependency_tree"] if "dependency_tree" in generation_resp:
entities = generation_resp["entities"] dependency_tree = generation_resp["dependency_tree"]
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text)) self.mlflg.html(
self.mlflg.html(entities, "ent-" + hash_string(generation.text)) dependency_tree, "dep-" + hash_string(generation.text)
)
if "entities" in generation_resp:
entities = generation_resp["entities"]
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
@ -410,14 +443,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.mlflg.metrics(self.metrics, step=self.metrics["step"]) self.mlflg.metrics(self.metrics, step=self.metrics["step"])
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()]) if isinstance(inputs, dict):
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
elif isinstance(inputs, list):
chain_input = ",".join([str(input) for input in inputs])
else:
chain_input = str(inputs)
input_resp = deepcopy(resp) input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input input_resp["inputs"] = chain_input
self.records["on_chain_start_records"].append(input_resp) self.records["on_chain_start_records"].append(input_resp)
self.records["action_records"].append(input_resp) self.records["action_records"].append(input_resp)
self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}") self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(
self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
) -> None:
"""Run when chain ends running.""" """Run when chain ends running."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["chain_ends"] += 1 self.metrics["chain_ends"] += 1
@ -426,7 +466,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
chain_ends = self.metrics["chain_ends"] chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {} resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()]) if isinstance(outputs, dict):
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
elif isinstance(outputs, list):
chain_output = ",".join(map(str, outputs))
else:
chain_output = str(outputs)
resp.update({"action": "on_chain_end", "outputs": chain_output}) resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics) resp.update(self.metrics)
@ -487,7 +532,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
""" """
Run when agent is ending. Run when text is received.
""" """
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["text_ctr"] += 1 self.metrics["text_ctr"] += 1
@ -549,6 +594,69 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["action_records"].append(resp) self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"agent_action_{tool_starts}") self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
self.metrics["step"] += 1
self.metrics["retriever_starts"] += 1
self.metrics["starts"] += 1
retriever_starts = self.metrics["retriever_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_retriever_start", "query": query})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_retriever_start_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"retriever_start_{retriever_starts}")
def on_retriever_end(
self,
documents: Sequence[Document],
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
self.metrics["step"] += 1
self.metrics["retriever_ends"] += 1
self.metrics["ends"] += 1
retriever_ends = self.metrics["retriever_ends"]
resp: Dict[str, Any] = {}
retriever_documents = [
{
"page_content": doc.page_content,
"metadata": {
k: str(v)
if not isinstance(v, list)
else ",".join(str(x) for x in v)
for k, v in doc.metadata.items()
},
}
for doc in documents
]
resp.update({"action": "on_retriever_end", "documents": retriever_documents})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_retriever_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"retriever_end_{retriever_ends}")
def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any:
"""Run when Retriever errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def _create_session_analysis_df(self) -> Any: def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session.""" """Create a dataframe with all the information from the session."""
pd = import_pandas() pd = import_pandas()
@ -570,39 +678,27 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
.dropna(axis=1) .dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1) .rename({"step": "prompt_step"}, axis=1)
) )
complexity_metrics_columns = [] complexity_metrics_columns = get_text_complexity_metrics()
visualizations_columns = [] visualizations_columns = (
["dependency_tree", "entities"] if self.nlp is not None else []
)
complexity_metrics_columns = [ token_usage_columns = [
"flesch_reading_ease", "token_usage_total_tokens",
"flesch_kincaid_grade", "token_usage_prompt_tokens",
"smog_index", "token_usage_completion_tokens",
"coleman_liau_index", ]
"automated_readability_index", token_usage_columns = [
"dale_chall_readability_score", x for x in token_usage_columns if x in on_llm_end_records_df.columns
"difficult_words",
"linsear_write_formula",
"gunning_fog",
# "text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
] ]
visualizations_columns = ["dependency_tree", "entities"]
llm_outputs_df = ( llm_outputs_df = (
on_llm_end_records_df[ on_llm_end_records_df[
[ [
"step", "step",
"text", "text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
] ]
+ token_usage_columns
+ complexity_metrics_columns + complexity_metrics_columns
+ visualizations_columns + visualizations_columns
] ]
@ -620,14 +716,18 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
return session_analysis_df return session_analysis_df
def _contain_llm_records(self):
return bool(self.records["on_llm_start_records"])
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None: def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
pd = import_pandas() pd = import_pandas()
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"])) self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
session_analysis_df = self._create_session_analysis_df() if self._contain_llm_records():
chat_html = session_analysis_df.pop("chat_html") session_analysis_df = self._create_session_analysis_df()
chat_html = chat_html.replace("\n", "", regex=True) chat_html = session_analysis_df.pop("chat_html")
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df)) chat_html = chat_html.replace("\n", "", regex=True)
self.mlflg.html("".join(chat_html.tolist()), "chat_html") self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
if langchain_asset: if langchain_asset:
# To avoid circular import error # To avoid circular import error