community[patch]: MLflow callback update (#16687)

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Serena Ruan 2024-02-05 15:46:46 -08:00 committed by GitHub
parent 3c4b24b69a
commit 9b279ac127
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -82,23 +82,21 @@ def get_text_complexity_metrics() -> List[str]:
def analyze_text(
text: str,
nlp: Any = None,
textstat: Any = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
nlp (spacy.lang): The spacy language model to use for visualization.
textstat: The textstat library to use for complexity metrics calculation.
Returns:
(dict): A dictionary containing the complexity metrics and visualization
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
try:
textstat = import_textstat()
except ImportError:
pass
else:
if textstat is not None:
text_complexity_metrics = {
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
@ -304,8 +302,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.nlp = None
try:
spacy = import_spacy()
except ImportError:
pass
except ImportError as e:
logger.warning(e.msg)
else:
try:
self.nlp = spacy.load("en_core_web_sm")
@ -315,6 +313,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"to download en_core_web_sm model for text visualization."
)
try:
self.textstat = import_textstat()
except ImportError as e:
logger.warning(e.msg)
self.textstat = None
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
self.records: Dict[str, Any] = {
@ -403,8 +407,10 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
analyze_text(
generation.text,
nlp=self.nlp,
textstat=self.textstat,
)
)
if "text_complexity_metrics" in generation_resp:
complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics"
)
@ -681,7 +687,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
)
complexity_metrics_columns = get_text_complexity_metrics()
complexity_metrics_columns = (
get_text_complexity_metrics() if self.textstat is not None else []
)
visualizations_columns = (
["dependency_tree", "entities"] if self.nlp is not None else []
)