From 9b279ac127d26e90faf6b6a40bc9f4ef0cfb7d8b Mon Sep 17 00:00:00 2001 From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:46:46 -0800 Subject: [PATCH] community[patch]: MLflow callback update (#16687) Signed-off-by: Serena Ruan Co-authored-by: Bagatur --- .../callbacks/mlflow_callback.py | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/libs/community/langchain_community/callbacks/mlflow_callback.py b/libs/community/langchain_community/callbacks/mlflow_callback.py index 6bcfb3aa330..f631a5c51c2 100644 --- a/libs/community/langchain_community/callbacks/mlflow_callback.py +++ b/libs/community/langchain_community/callbacks/mlflow_callback.py @@ -82,23 +82,21 @@ def get_text_complexity_metrics() -> List[str]: def analyze_text( text: str, nlp: Any = None, + textstat: Any = None, ) -> dict: """Analyze text using textstat and spacy. Parameters: text (str): The text to analyze. nlp (spacy.lang): The spacy language model to use for visualization. + textstat: The textstat library to use for complexity metrics calculation. Returns: (dict): A dictionary containing the complexity metrics and visualization files serialized to HTML string. """ resp: Dict[str, Any] = {} - try: - textstat = import_textstat() - except ImportError: - pass - else: + if textstat is not None: text_complexity_metrics = { key: getattr(textstat, key)(text) for key in get_text_complexity_metrics() } @@ -304,8 +302,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.nlp = None try: spacy = import_spacy() - except ImportError: - pass + except ImportError as e: + logger.warning(e.msg) else: try: self.nlp = spacy.load("en_core_web_sm") @@ -315,6 +313,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): "to download en_core_web_sm model for text visualization." ) + try: + self.textstat = import_textstat() + except ImportError as e: + logger.warning(e.msg) + self.textstat = None + self.metrics = {key: 0 for key in mlflow_callback_metrics()} self.records: Dict[str, Any] = { @@ -403,15 +407,17 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): analyze_text( generation.text, nlp=self.nlp, + textstat=self.textstat, ) ) - complexity_metrics: Dict[str, float] = generation_resp.pop( - "text_complexity_metrics" - ) - self.mlflg.metrics( - complexity_metrics, - step=self.metrics["step"], - ) + if "text_complexity_metrics" in generation_resp: + complexity_metrics: Dict[str, float] = generation_resp.pop( + "text_complexity_metrics" + ) + self.mlflg.metrics( + complexity_metrics, + step=self.metrics["step"], + ) self.records["on_llm_end_records"].append(generation_resp) self.records["action_records"].append(generation_resp) self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}") @@ -681,7 +687,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): .dropna(axis=1) .rename({"step": "prompt_step"}, axis=1) ) - complexity_metrics_columns = get_text_complexity_metrics() + complexity_metrics_columns = ( + get_text_complexity_metrics() if self.textstat is not None else [] + ) visualizations_columns = ( ["dependency_tree", "entities"] if self.nlp is not None else [] )