community[patch]: MLflowCallbackHandler -- Move textstat and spacy as optional dependency (#16657)

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
This commit is contained in:
Serena Ruan 2024-01-27 16:15:07 -08:00 committed by GitHub
parent 508bde7f40
commit f01fb47597
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,15 +94,19 @@ def analyze_text(
files serialized to HTML string. files serialized to HTML string.
""" """
resp: Dict[str, Any] = {} resp: Dict[str, Any] = {}
textstat = import_textstat() try:
spacy = import_spacy() textstat = import_textstat()
text_complexity_metrics = { except ImportError:
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics() pass
} else:
resp.update({"text_complexity_metrics": text_complexity_metrics}) text_complexity_metrics = {
resp.update(text_complexity_metrics) key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None: if nlp is not None:
spacy = import_spacy()
doc = nlp(text) doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore dep_out = spacy.displacy.render( # type: ignore
@ -279,9 +283,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None: ) -> None:
"""Initialize callback handler.""" """Initialize callback handler."""
import_pandas() import_pandas()
import_textstat()
import_mlflow() import_mlflow()
spacy = import_spacy()
super().__init__() super().__init__()
self.name = name self.name = name
@ -303,14 +305,19 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
self.action_records: list = [] self.action_records: list = []
self.nlp = None
try: try:
self.nlp = spacy.load("en_core_web_sm") spacy = import_spacy()
except OSError: except ImportError:
logger.warning( pass
"Run `python -m spacy download en_core_web_sm` " else:
"to download en_core_web_sm model for text visualization." try:
) self.nlp = spacy.load("en_core_web_sm")
self.nlp = None except OSError:
logger.warning(
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
self.metrics = {key: 0 for key in mlflow_callback_metrics()} self.metrics = {key: 0 for key in mlflow_callback_metrics()}