community[patch]: MLflowCallbackHandler -- Move textstat and spacy as optional dependency (#16657)

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
This commit is contained in:
Serena Ruan 2024-01-27 16:15:07 -08:00 committed by GitHub
parent 508bde7f40
commit f01fb47597
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,8 +94,11 @@ def analyze_text(
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
try:
textstat = import_textstat()
spacy = import_spacy()
except ImportError:
pass
else:
text_complexity_metrics = {
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
@ -103,6 +106,7 @@ def analyze_text(
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore
@ -279,9 +283,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None:
"""Initialize callback handler."""
import_pandas()
import_textstat()
import_mlflow()
spacy = import_spacy()
super().__init__()
self.name = name
@ -303,6 +305,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
)
self.action_records: list = []
self.nlp = None
try:
spacy = import_spacy()
except ImportError:
pass
else:
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
@ -310,7 +318,6 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
self.nlp = None
self.metrics = {key: 0 for key in mlflow_callback_metrics()}