community[minor]: Improve mlflow callback (#15691)

- **Description:** Allow passing run_id to MLflowCallbackHandler to
resume a run instead of creating a new run. Support recording retriever
relevant metrics. Refactor the code to fix some bugs.
---------

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
This commit is contained in:
Serena Ruan 2024-01-23 18:16:51 -08:00 committed by GitHub
parent 92e6a641fd
commit 90f5a1c40e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,4 @@
import logging
import os
import random
import string
@ -5,10 +6,11 @@ import tempfile
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.outputs import LLMResult
from langchain_core.utils import get_from_dict_or_env
@ -21,6 +23,8 @@ from langchain_community.callbacks.utils import (
import_textstat,
)
logger = logging.getLogger(__name__)
def import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed."""
@ -34,6 +38,47 @@ def import_mlflow() -> Any:
return mlflow
def mlflow_callback_metrics() -> List[str]:
return [
"step",
"starts",
"ends",
"errors",
"text_ctr",
"chain_starts",
"chain_ends",
"llm_starts",
"llm_ends",
"llm_streams",
"tool_starts",
"tool_ends",
"agent_ends",
"retriever_starts",
"retriever_ends",
]
def get_text_complexity_metrics() -> List[str]:
return [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
# "text_standard"
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
def analyze_text(
text: str,
nlp: Any = None,
@ -52,22 +97,7 @@ def analyze_text(
textstat = import_textstat()
spacy = import_spacy()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
# "text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
@ -140,58 +170,64 @@ class MlflowLogger:
)
self.mlflow.set_tracking_uri(tracking_uri)
# User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
)
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id
if run_id := kwargs.get("run_id"):
self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
# User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
self.start_run(kwargs["run_name"], kwargs["run_tags"])
experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
)
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id
else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
def start_run(self, name: str, tags: Dict[str, str]) -> None:
"""To start a new run, auto generates the random suffix for name"""
if name.endswith("-%"):
rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
name = name.replace("%", rname)
self.run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=name, tags=tags
self.start_run(
kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
)
self.dir = kwargs.get("artifacts_dir", "")
def start_run(
self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
) -> None:
"""
If run_id is provided, it will reuse the run with the given run_id.
Otherwise, it starts a new run, auto generates the random suffix for name.
"""
if run_id is None:
if name.endswith("-%"):
rname = "".join(
random.choices(string.ascii_uppercase + string.digits, k=7)
)
name = name[:-1] + rname
run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=name, tags=tags
)
run_id = run.info.run_id
self.run_id = run_id
def finish_run(self) -> None:
"""To finish the run."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.end_run()
self.mlflow.end_run()
def metric(self, key: str, value: float) -> None:
"""To log metric to mlflow server."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metric(key, value)
self.mlflow.log_metric(key, value, run_id=self.run_id)
def metrics(
self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
) -> None:
"""To log all metrics in the input dict."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_metrics(data)
self.mlflow.log_metrics(data, run_id=self.run_id)
def jsonf(self, data: Dict[str, Any], filename: str) -> None:
"""To log the input data as json file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_dict(data, f"{filename}.json")
self.mlflow.log_dict(
data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
)
def table(self, name: str, dataframe) -> None: # type: ignore
"""To log the input pandas dataframe as a html table"""
@ -199,30 +235,22 @@ class MlflowLogger:
def html(self, html: str, filename: str) -> None:
"""To log the input html string as html file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_text(html, f"{filename}.html")
self.mlflow.log_text(
html, os.path.join(self.dir, f"{filename}.html"), run_id=self.run_id
)
def text(self, text: str, filename: str) -> None:
"""To log the input text as text file artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_text(text, f"{filename}.txt")
self.mlflow.log_text(
text, os.path.join(self.dir, f"{filename}.txt"), run_id=self.run_id
)
def artifact(self, path: str) -> None:
"""To upload the file from given path as artifact."""
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.log_artifact(path)
self.mlflow.log_artifact(path, run_id=self.run_id)
def langchain_artifact(self, chain: Any) -> None:
with self.mlflow.start_run(
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
):
self.mlflow.langchain.log_model(chain, "langchain-model")
self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
@ -246,6 +274,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
experiment: Optional[str] = "langchain",
tags: Optional[Dict] = None,
tracking_uri: Optional[str] = None,
run_id: Optional[str] = None,
artifacts_dir: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
import_pandas()
@ -258,6 +288,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.experiment = experiment
self.tags = tags or {}
self.tracking_uri = tracking_uri
self.run_id = run_id
self.artifacts_dir = artifacts_dir
self.temp_dir = tempfile.TemporaryDirectory()
@ -266,26 +298,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
experiment_name=self.experiment,
run_name=self.name,
run_tags=self.tags,
run_id=self.run_id,
artifacts_dir=self.artifacts_dir,
)
self.action_records: list = []
self.nlp = spacy.load("en_core_web_sm")
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
self.nlp = None
self.metrics = {
"step": 0,
"starts": 0,
"ends": 0,
"errors": 0,
"text_ctr": 0,
"chain_starts": 0,
"chain_ends": 0,
"llm_starts": 0,
"llm_ends": 0,
"llm_streams": 0,
"tool_starts": 0,
"tool_ends": 0,
"agent_ends": 0,
}
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
self.records: Dict[str, Any] = {
"on_llm_start_records": [],
@ -298,6 +325,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"on_text_records": [],
"on_agent_finish_records": [],
"on_agent_action_records": [],
"on_retriever_start_records": [],
"on_retriever_end_records": [],
"action_records": [],
}
@ -383,10 +412,14 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["on_llm_end_records"].append(generation_resp)
self.records["action_records"].append(generation_resp)
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
dependency_tree = generation_resp["dependency_tree"]
entities = generation_resp["entities"]
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
if "dependency_tree" in generation_resp:
dependency_tree = generation_resp["dependency_tree"]
self.mlflg.html(
dependency_tree, "dep-" + hash_string(generation.text)
)
if "entities" in generation_resp:
entities = generation_resp["entities"]
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
@ -410,14 +443,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
if isinstance(inputs, dict):
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
elif isinstance(inputs, list):
chain_input = ",".join([str(input) for input in inputs])
else:
chain_input = str(inputs)
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.records["on_chain_start_records"].append(input_resp)
self.records["action_records"].append(input_resp)
self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(
self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
) -> None:
"""Run when chain ends running."""
self.metrics["step"] += 1
self.metrics["chain_ends"] += 1
@ -426,7 +466,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
if isinstance(outputs, dict):
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
elif isinstance(outputs, list):
chain_output = ",".join(map(str, outputs))
else:
chain_output = str(outputs)
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
@ -487,7 +532,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
Run when text is received.
"""
self.metrics["step"] += 1
self.metrics["text_ctr"] += 1
@ -549,6 +594,69 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
self.metrics["step"] += 1
self.metrics["retriever_starts"] += 1
self.metrics["starts"] += 1
retriever_starts = self.metrics["retriever_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_retriever_start", "query": query})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_retriever_start_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"retriever_start_{retriever_starts}")
def on_retriever_end(
self,
documents: Sequence[Document],
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
self.metrics["step"] += 1
self.metrics["retriever_ends"] += 1
self.metrics["ends"] += 1
retriever_ends = self.metrics["retriever_ends"]
resp: Dict[str, Any] = {}
retriever_documents = [
{
"page_content": doc.page_content,
"metadata": {
k: str(v)
if not isinstance(v, list)
else ",".join(str(x) for x in v)
for k, v in doc.metadata.items()
},
}
for doc in documents
]
resp.update({"action": "on_retriever_end", "documents": retriever_documents})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_retriever_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"retriever_end_{retriever_ends}")
def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any:
"""Run when Retriever errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
@ -570,39 +678,27 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
)
complexity_metrics_columns = []
visualizations_columns = []
complexity_metrics_columns = get_text_complexity_metrics()
visualizations_columns = (
["dependency_tree", "entities"] if self.nlp is not None else []
)
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
# "text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
token_usage_columns = [
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
token_usage_columns = [
x for x in token_usage_columns if x in on_llm_end_records_df.columns
]
visualizations_columns = ["dependency_tree", "entities"]
llm_outputs_df = (
on_llm_end_records_df[
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ token_usage_columns
+ complexity_metrics_columns
+ visualizations_columns
]
@ -620,14 +716,18 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
)
return session_analysis_df
def _contain_llm_records(self):
return bool(self.records["on_llm_start_records"])
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
pd = import_pandas()
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
session_analysis_df = self._create_session_analysis_df()
chat_html = session_analysis_df.pop("chat_html")
chat_html = chat_html.replace("\n", "", regex=True)
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
if self._contain_llm_records():
session_analysis_df = self._create_session_analysis_df()
chat_html = session_analysis_df.pop("chat_html")
chat_html = chat_html.replace("\n", "", regex=True)
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
if langchain_asset:
# To avoid circular import error