From 90f5a1c40e1bd8c4bbb40485dc46ac76af89fee1 Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Tue, 23 Jan 2024 18:16:51 -0800
Subject: [PATCH] community[minor]: Improve mlflow callback (#15691)

- **Description:** Allow passing a `run_id` to `MlflowCallbackHandler` so it
  resumes an existing run instead of creating a new one. Add support for
  recording retriever-related metrics. Refactor the code to fix several bugs.
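
  A minimal usage sketch of the new argument (the `run_id` parameter is what
  this PR adds; the rest is standard MLflow/LangChain usage, and the handler
  still needs its optional dependencies: mlflow, pandas, textstat, spacy):

```python
import mlflow

from langchain_community.callbacks import MlflowCallbackHandler

with mlflow.start_run() as run:
    # Resume the already-active run instead of letting the handler create one.
    mlflow_callback = MlflowCallbackHandler(run_id=run.info.run_id)
    # ... invoke an LLM or chain with callbacks=[mlflow_callback] ...
    mlflow_callback.flush_tracker()
```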

---------

Signed-off-by: Serena Ruan
---
 .../callbacks/mlflow_callback.py | 340 +++++++++++-------
 1 file changed, 220 insertions(+), 120 deletions(-)

diff --git a/libs/community/langchain_community/callbacks/mlflow_callback.py b/libs/community/langchain_community/callbacks/mlflow_callback.py
index 6d93125d564..577532a3616 100644
--- a/libs/community/langchain_community/callbacks/mlflow_callback.py
+++ b/libs/community/langchain_community/callbacks/mlflow_callback.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import random
 import string
@@ -5,10 +6,11 @@ import tempfile
 import traceback
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.documents import Document
 from langchain_core.outputs import LLMResult
 from langchain_core.utils import get_from_dict_or_env
 
@@ -21,6 +23,8 @@ from langchain_community.callbacks.utils import (
     import_textstat,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def import_mlflow() -> Any:
     """Import the mlflow python package and raise an error if it is not installed."""
@@ -34,6 +38,47 @@ def import_mlflow() -> Any:
     return mlflow
 
 
+def mlflow_callback_metrics() -> List[str]:
+    return [
+        "step",
+        "starts",
+        "ends",
+        "errors",
+        "text_ctr",
+        "chain_starts",
+        "chain_ends",
+        "llm_starts",
+        "llm_ends",
+        "llm_streams",
+        "tool_starts",
+        "tool_ends",
+        "agent_ends",
+        "retriever_starts",
+        "retriever_ends",
+    ]
+
+
+def get_text_complexity_metrics() -> List[str]:
+    return [
+        "flesch_reading_ease",
+        "flesch_kincaid_grade",
+        "smog_index",
+        "coleman_liau_index",
+        "automated_readability_index",
+        "dale_chall_readability_score",
+        "difficult_words",
+        "linsear_write_formula",
+        "gunning_fog",
+        # "text_standard"
+        "fernandez_huerta",
+        "szigriszt_pazos",
+        "gutierrez_polini",
+        "crawford",
+        "gulpease_index",
+        "osman",
+    ]
+
+
 def analyze_text(
     text: str,
     nlp: Any = None,
@@ -52,22 +97,7 @@ def analyze_text(
     textstat = import_textstat()
     spacy = import_spacy()
     text_complexity_metrics = {
-        "flesch_reading_ease": textstat.flesch_reading_ease(text),
-        "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
-        "smog_index": textstat.smog_index(text),
-        "coleman_liau_index": textstat.coleman_liau_index(text),
-        "automated_readability_index": textstat.automated_readability_index(text),
-        "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
-        "difficult_words": textstat.difficult_words(text),
-        "linsear_write_formula": textstat.linsear_write_formula(text),
-        "gunning_fog": textstat.gunning_fog(text),
-        # "text_standard": textstat.text_standard(text),
-        "fernandez_huerta": textstat.fernandez_huerta(text),
-        "szigriszt_pazos": textstat.szigriszt_pazos(text),
-        "gutierrez_polini": textstat.gutierrez_polini(text),
-        "crawford": textstat.crawford(text),
-        "gulpease_index": textstat.gulpease_index(text),
-        "osman": textstat.osman(text),
+        key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
     }
     resp.update({"text_complexity_metrics": text_complexity_metrics})
     resp.update(text_complexity_metrics)
@@ -140,58 +170,64 @@ class MlflowLogger:
             )
             self.mlflow.set_tracking_uri(tracking_uri)
 
-        # User can set other env variables described here
-        # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
-
-        experiment_name = get_from_dict_or_env(
-            kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
-        )
-        self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
-        if self.mlf_exp is not None:
-            self.mlf_expid = self.mlf_exp.experiment_id
+        if run_id := kwargs.get("run_id"):
+            self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
         else:
-            self.mlf_expid = self.mlflow.create_experiment(experiment_name)
+            # User can set other env variables described here
+            # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
 
-        self.start_run(kwargs["run_name"], kwargs["run_tags"])
+            experiment_name = get_from_dict_or_env(
+                kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
+            )
+            self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
+            if self.mlf_exp is not None:
+                self.mlf_expid = self.mlf_exp.experiment_id
+            else:
+                self.mlf_expid = self.mlflow.create_experiment(experiment_name)
 
-    def start_run(self, name: str, tags: Dict[str, str]) -> None:
-        """To start a new run, auto generates the random suffix for name"""
-        if name.endswith("-%"):
-            rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
-            name = name.replace("%", rname)
-        self.run = self.mlflow.MlflowClient().create_run(
-            self.mlf_expid, run_name=name, tags=tags
+        self.start_run(
+            kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
        )
+        self.dir = kwargs.get("artifacts_dir", "")
+
+    def start_run(
+        self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
+    ) -> None:
+        """
+        If run_id is provided, it will reuse the run with the given run_id.
+        Otherwise, it starts a new run, auto-generating a random suffix for the name.
+        """
+        if run_id is None:
+            if name.endswith("-%"):
+                rname = "".join(
+                    random.choices(string.ascii_uppercase + string.digits, k=7)
+                )
+                name = name[:-1] + rname
+            run = self.mlflow.MlflowClient().create_run(
+                self.mlf_expid, run_name=name, tags=tags
+            )
+            run_id = run.info.run_id
+        self.run_id = run_id
 
     def finish_run(self) -> None:
         """To finish the run."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.end_run()
+        self.mlflow.end_run()
 
     def metric(self, key: str, value: float) -> None:
         """To log metric to mlflow server."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.log_metric(key, value)
+        self.mlflow.log_metric(key, value, run_id=self.run_id)
 
     def metrics(
         self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
     ) -> None:
         """To log all metrics in the input dict."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.log_metrics(data)
+        self.mlflow.log_metrics(data, step=step, run_id=self.run_id)
 
     def jsonf(self, data: Dict[str, Any], filename: str) -> None:
         """To log the input data as json file artifact."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.log_dict(data, f"{filename}.json")
+        self.mlflow.log_dict(
+            data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
+        )
 
     def table(self, name: str, dataframe) -> None:  # type: ignore
         """To log the input pandas dataframe as a html table"""
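
A note on the pattern above: the rewritten logger methods assume an MLflow
release whose fluent logging APIs accept `run_id` directly (newer MLflow
versions do; the package's minimum `mlflow` pin should reflect that). A small
before/after sketch of the change in logging style:

```python
import mlflow

client = mlflow.MlflowClient()
run = client.create_run(experiment_id="0")  # "0" is the default experiment

# Old pattern: re-enter the run for every logging call.
with mlflow.start_run(run_id=run.info.run_id):
    mlflow.log_metric("llm_starts", 1)

# New pattern: hand run_id straight to the fluent API; no context manager.
mlflow.log_metric("llm_starts", 2, run_id=run.info.run_id)
```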
@@ -199,30 +235,22 @@ class MlflowLogger:
 
     def html(self, html: str, filename: str) -> None:
         """To log the input html string as html file artifact."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.log_text(html, f"{filename}.html")
+        self.mlflow.log_text(
+            html, os.path.join(self.dir, f"{filename}.html"), run_id=self.run_id
+        )
 
     def text(self, text: str, filename: str) -> None:
         """To log the input text as text file artifact."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.log_text(text, f"{filename}.txt")
+        self.mlflow.log_text(
+            text, os.path.join(self.dir, f"{filename}.txt"), run_id=self.run_id
+        )
 
     def artifact(self, path: str) -> None:
         """To upload the file from given path as artifact."""
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.log_artifact(path)
+        self.mlflow.log_artifact(path, run_id=self.run_id)
 
     def langchain_artifact(self, chain: Any) -> None:
-        with self.mlflow.start_run(
-            run_id=self.run.info.run_id, experiment_id=self.mlf_expid
-        ):
-            self.mlflow.langchain.log_model(chain, "langchain-model")
+        self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)
 
 
 class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
@@ -246,6 +274,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         experiment: Optional[str] = "langchain",
         tags: Optional[Dict] = None,
         tracking_uri: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifacts_dir: Optional[str] = None,
     ) -> None:
         """Initialize callback handler."""
         import_pandas()
@@ -258,6 +288,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self.experiment = experiment
         self.tags = tags or {}
         self.tracking_uri = tracking_uri
+        self.run_id = run_id
+        self.artifacts_dir = artifacts_dir
         self.temp_dir = tempfile.TemporaryDirectory()
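
A small check on the artifact paths above: with the default `artifacts_dir=""`,
`os.path.join` leaves artifact names untouched, so existing layouts are
unaffected; a non-empty value nests them under a subdirectory (POSIX separators
assumed):

```python
import os

# Default artifacts_dir="": names are unchanged.
assert os.path.join("", "chain_start_1.json") == "chain_start_1.json"
# artifacts_dir="langchain": artifacts nest under a subdirectory of the run.
assert os.path.join("langchain", "chain_start_1.json") == "langchain/chain_start_1.json"
```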
@@ -266,26 +298,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
             experiment_name=self.experiment,
             run_name=self.name,
             run_tags=self.tags,
+            run_id=self.run_id,
+            artifacts_dir=self.artifacts_dir,
         )
 
         self.action_records: list = []
-        self.nlp = spacy.load("en_core_web_sm")
+        try:
+            self.nlp = spacy.load("en_core_web_sm")
+        except OSError:
+            logger.warning(
+                "Run `python -m spacy download en_core_web_sm` "
+                "to download en_core_web_sm model for text visualization."
+            )
+            self.nlp = None
 
-        self.metrics = {
-            "step": 0,
-            "starts": 0,
-            "ends": 0,
-            "errors": 0,
-            "text_ctr": 0,
-            "chain_starts": 0,
-            "chain_ends": 0,
-            "llm_starts": 0,
-            "llm_ends": 0,
-            "llm_streams": 0,
-            "tool_starts": 0,
-            "tool_ends": 0,
-            "agent_ends": 0,
-        }
+        self.metrics = {key: 0 for key in mlflow_callback_metrics()}
 
         self.records: Dict[str, Any] = {
             "on_llm_start_records": [],
@@ -298,6 +325,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
             "on_text_records": [],
             "on_agent_finish_records": [],
             "on_agent_action_records": [],
+            "on_retriever_start_records": [],
+            "on_retriever_end_records": [],
             "action_records": [],
         }
 
@@ -383,10 +412,14 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
                 self.records["on_llm_end_records"].append(generation_resp)
                 self.records["action_records"].append(generation_resp)
                 self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
-                dependency_tree = generation_resp["dependency_tree"]
-                entities = generation_resp["entities"]
-                self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
-                self.mlflg.html(entities, "ent-" + hash_string(generation.text))
+                if "dependency_tree" in generation_resp:
+                    dependency_tree = generation_resp["dependency_tree"]
+                    self.mlflg.html(
+                        dependency_tree, "dep-" + hash_string(generation.text)
+                    )
+                if "entities" in generation_resp:
+                    entities = generation_resp["entities"]
+                    self.mlflg.html(entities, "ent-" + hash_string(generation.text))
 
     def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
         """Run when LLM errors."""
@@ -410,14 +443,21 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
 
         self.mlflg.metrics(self.metrics, step=self.metrics["step"])
 
-        chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
+        if isinstance(inputs, dict):
+            chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
+        elif isinstance(inputs, list):
+            chain_input = ",".join([str(input) for input in inputs])
+        else:
+            chain_input = str(inputs)
         input_resp = deepcopy(resp)
         input_resp["inputs"] = chain_input
         self.records["on_chain_start_records"].append(input_resp)
         self.records["action_records"].append(input_resp)
         self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
 
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(
+        self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
+    ) -> None:
         """Run when chain ends running."""
         self.metrics["step"] += 1
         self.metrics["chain_ends"] += 1
@@ -426,7 +466,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         chain_ends = self.metrics["chain_ends"]
 
         resp: Dict[str, Any] = {}
-        chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
+        if isinstance(outputs, dict):
+            chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
+        elif isinstance(outputs, list):
+            chain_output = ",".join(map(str, outputs))
+        else:
+            chain_output = str(outputs)
         resp.update({"action": "on_chain_end", "outputs": chain_output})
         resp.update(self.metrics)
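
The dict/list/fallback branches in `on_chain_start` and `on_chain_end` apply
the same flattening twice; as a standalone sketch (the helper name is
illustrative, not part of the patch):

```python
from typing import Any


def stringify_io(data: Any) -> str:
    """Flatten chain inputs/outputs to one string, mirroring the branches above."""
    if isinstance(data, dict):
        return ",".join(f"{k}={v}" for k, v in data.items())
    if isinstance(data, list):
        return ",".join(str(item) for item in data)
    return str(data)


assert stringify_io({"question": "hi"}) == "question=hi"
assert stringify_io(["a", "b"]) == "a,b"
assert stringify_io(42) == "42"
```

Extracting one shared helper like this would keep the two callbacks from
drifting apart.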
resp.update({"action": "on_chain_end", "outputs": chain_output}) resp.update(self.metrics) @@ -487,7 +532,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): def on_text(self, text: str, **kwargs: Any) -> None: """ - Run when agent is ending. + Run when text is received. """ self.metrics["step"] += 1 self.metrics["text_ctr"] += 1 @@ -549,6 +594,69 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.records["action_records"].append(resp) self.mlflg.jsonf(resp, f"agent_action_{tool_starts}") + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + **kwargs: Any, + ) -> Any: + """Run when Retriever starts running.""" + self.metrics["step"] += 1 + self.metrics["retriever_starts"] += 1 + self.metrics["starts"] += 1 + + retriever_starts = self.metrics["retriever_starts"] + + resp: Dict[str, Any] = {} + resp.update({"action": "on_retriever_start", "query": query}) + resp.update(flatten_dict(serialized)) + resp.update(self.metrics) + + self.mlflg.metrics(self.metrics, step=self.metrics["step"]) + + self.records["on_retriever_start_records"].append(resp) + self.records["action_records"].append(resp) + self.mlflg.jsonf(resp, f"retriever_start_{retriever_starts}") + + def on_retriever_end( + self, + documents: Sequence[Document], + **kwargs: Any, + ) -> Any: + """Run when Retriever ends running.""" + self.metrics["step"] += 1 + self.metrics["retriever_ends"] += 1 + self.metrics["ends"] += 1 + + retriever_ends = self.metrics["retriever_ends"] + + resp: Dict[str, Any] = {} + retriever_documents = [ + { + "page_content": doc.page_content, + "metadata": { + k: str(v) + if not isinstance(v, list) + else ",".join(str(x) for x in v) + for k, v in doc.metadata.items() + }, + } + for doc in documents + ] + resp.update({"action": "on_retriever_end", "documents": retriever_documents}) + resp.update(self.metrics) + + self.mlflg.metrics(self.metrics, step=self.metrics["step"]) + + self.records["on_retriever_end_records"].append(resp) + self.records["action_records"].append(resp) + self.mlflg.jsonf(resp, f"retriever_end_{retriever_ends}") + + def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any: + """Run when Retriever errors.""" + self.metrics["step"] += 1 + self.metrics["errors"] += 1 + def _create_session_analysis_df(self) -> Any: """Create a dataframe with all the information from the session.""" pd = import_pandas() @@ -570,39 +678,27 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): .dropna(axis=1) .rename({"step": "prompt_step"}, axis=1) ) - complexity_metrics_columns = [] - visualizations_columns = [] + complexity_metrics_columns = get_text_complexity_metrics() + visualizations_columns = ( + ["dependency_tree", "entities"] if self.nlp is not None else [] + ) - complexity_metrics_columns = [ - "flesch_reading_ease", - "flesch_kincaid_grade", - "smog_index", - "coleman_liau_index", - "automated_readability_index", - "dale_chall_readability_score", - "difficult_words", - "linsear_write_formula", - "gunning_fog", - # "text_standard", - "fernandez_huerta", - "szigriszt_pazos", - "gutierrez_polini", - "crawford", - "gulpease_index", - "osman", + token_usage_columns = [ + "token_usage_total_tokens", + "token_usage_prompt_tokens", + "token_usage_completion_tokens", + ] + token_usage_columns = [ + x for x in token_usage_columns if x in on_llm_end_records_df.columns ] - - visualizations_columns = ["dependency_tree", "entities"] llm_outputs_df = ( on_llm_end_records_df[ [ 
"step", "text", - "token_usage_total_tokens", - "token_usage_prompt_tokens", - "token_usage_completion_tokens", ] + + token_usage_columns + complexity_metrics_columns + visualizations_columns ] @@ -620,14 +716,18 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): ) return session_analysis_df + def _contain_llm_records(self): + return bool(self.records["on_llm_start_records"]) + def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None: pd = import_pandas() self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"])) - session_analysis_df = self._create_session_analysis_df() - chat_html = session_analysis_df.pop("chat_html") - chat_html = chat_html.replace("\n", "", regex=True) - self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df)) - self.mlflg.html("".join(chat_html.tolist()), "chat_html") + if self._contain_llm_records(): + session_analysis_df = self._create_session_analysis_df() + chat_html = session_analysis_df.pop("chat_html") + chat_html = chat_html.replace("\n", "", regex=True) + self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df)) + self.mlflg.html("".join(chat_html.tolist()), "chat_html") if langchain_asset: # To avoid circular import error