community[patch]: MLflow callback update (#16687)

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by Serena Ruan on 2024-02-05 15:46:46 -08:00; committed by GitHub
commit 9b279ac127 (parent 3c4b24b69a)

@@ -82,23 +82,21 @@ def get_text_complexity_metrics() -> List[str]:
 def analyze_text(
     text: str,
     nlp: Any = None,
+    textstat: Any = None,
 ) -> dict:
     """Analyze text using textstat and spacy.
 
     Parameters:
         text (str): The text to analyze.
         nlp (spacy.lang): The spacy language model to use for visualization.
+        textstat: The textstat library to use for complexity metrics calculation.
 
     Returns:
         (dict): A dictionary containing the complexity metrics and visualization
             files serialized to HTML string.
     """
     resp: Dict[str, Any] = {}
-    try:
-        textstat = import_textstat()
-    except ImportError:
-        pass
-    else:
+    if textstat is not None:
         text_complexity_metrics = {
             key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
         }
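
Note: for orientation, a minimal sketch of how the reworked helper is now driven, assuming the textstat package is installed; the sample text is illustrative only.

import textstat  # callers import textstat themselves and pass the module in

# analyze_text no longer imports textstat internally: passing the module
# enables the complexity metrics, passing None (the default) skips them.
resp = analyze_text(
    "The quick brown fox jumps over the lazy dog.",
    nlp=None,           # no spacy model, so no visualization HTML is produced
    textstat=textstat,  # used as getattr(textstat, metric_name)(text)
)
# Complexity metrics come back under the "text_complexity_metrics" key.
print(resp["text_complexity_metrics"])
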
@@ -304,8 +302,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self.nlp = None
         try:
             spacy = import_spacy()
-        except ImportError:
-            pass
+        except ImportError as e:
+            logger.warning(e.msg)
         else:
             try:
                 self.nlp = spacy.load("en_core_web_sm")
@@ -315,6 +313,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
                 "to download en_core_web_sm model for text visualization."
             )
 
+        try:
+            self.textstat = import_textstat()
+        except ImportError as e:
+            logger.warning(e.msg)
+            self.textstat = None
+
         self.metrics = {key: 0 for key in mlflow_callback_metrics()}
 
         self.records: Dict[str, Any] = {
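
Note: the constructor now resolves the optional dependency once and degrades gracefully instead of failing. A standalone sketch of that pattern (not the handler itself), using only the standard logging module:

import logging

logger = logging.getLogger(__name__)

# Try the optional import once, warn instead of raising, and fall back to
# None so later steps can check availability before computing metrics.
try:
    import textstat
except ImportError as e:
    logger.warning(e.msg)
    textstat = None

if textstat is not None:
    print(textstat.flesch_reading_ease("Some text."))
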
@@ -403,15 +407,17 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
                     analyze_text(
                         generation.text,
                         nlp=self.nlp,
+                        textstat=self.textstat,
                     )
                 )
-                complexity_metrics: Dict[str, float] = generation_resp.pop(
-                    "text_complexity_metrics"
-                )
-                self.mlflg.metrics(
-                    complexity_metrics,
-                    step=self.metrics["step"],
-                )
+                if "text_complexity_metrics" in generation_resp:
+                    complexity_metrics: Dict[str, float] = generation_resp.pop(
+                        "text_complexity_metrics"
+                    )
+                    self.mlflg.metrics(
+                        complexity_metrics,
+                        step=self.metrics["step"],
+                    )
                 self.records["on_llm_end_records"].append(generation_resp)
                 self.records["action_records"].append(generation_resp)
                 self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
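
Note: in plain MLflow terms the guarded block above amounts to the sketch below, which uses the public mlflow.log_metrics API rather than the handler's internal mlflg helper and assumes an active tracking setup; the sample values are made up.

import mlflow

generation_resp = {"text": "hi", "text_complexity_metrics": {"flesch_reading_ease": 75.0}}
step = 0

with mlflow.start_run():
    # Complexity metrics are only popped and logged when analyze_text
    # actually produced them, i.e. when a textstat module was available.
    if "text_complexity_metrics" in generation_resp:
        complexity_metrics = generation_resp.pop("text_complexity_metrics")
        mlflow.log_metrics(complexity_metrics, step=step)
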
@@ -681,7 +687,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
             .dropna(axis=1)
             .rename({"step": "prompt_step"}, axis=1)
         )
-        complexity_metrics_columns = get_text_complexity_metrics()
+        complexity_metrics_columns = (
+            get_text_complexity_metrics() if self.textstat is not None else []
+        )
         visualizations_columns = (
             ["dependency_tree", "entities"] if self.nlp is not None else []
         )
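
Note: a small illustrative sketch (not the handler's code) of why the column list is now built conditionally: the session-analysis DataFrame should only select complexity columns that were actually produced. The DataFrame contents and metric names here are invented for the example.

import pandas as pd

textstat_available = False  # e.g. the import failed and self.textstat is None

records = pd.DataFrame([{"text": "hello", "prompt_step": 0}])
complexity_metrics_columns = (
    ["flesch_reading_ease", "flesch_kincaid_grade"] if textstat_available else []
)
visualizations_columns = []  # would be ["dependency_tree", "entities"] with spacy

# Selecting only columns that exist avoids a KeyError when textstat was missing.
subset = records[
    ["text", "prompt_step"] + complexity_metrics_columns + visualizations_columns
]
print(subset)
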