community[patch]: MLflowCallbackHandler -- Move textstat and spacy as optional dependency (#16657)

Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
This commit is contained in:
Serena Ruan 2024-01-27 16:15:07 -08:00 committed by GitHub
parent 508bde7f40
commit f01fb47597
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,8 +94,11 @@ def analyze_text(
files serialized to HTML string. files serialized to HTML string.
""" """
resp: Dict[str, Any] = {} resp: Dict[str, Any] = {}
try:
textstat = import_textstat() textstat = import_textstat()
spacy = import_spacy() except ImportError:
pass
else:
text_complexity_metrics = { text_complexity_metrics = {
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics() key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
} }
@ -103,6 +106,7 @@ def analyze_text(
resp.update(text_complexity_metrics) resp.update(text_complexity_metrics)
if nlp is not None: if nlp is not None:
spacy = import_spacy()
doc = nlp(text) doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore dep_out = spacy.displacy.render( # type: ignore
@ -279,9 +283,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) -> None: ) -> None:
"""Initialize callback handler.""" """Initialize callback handler."""
import_pandas() import_pandas()
import_textstat()
import_mlflow() import_mlflow()
spacy = import_spacy()
super().__init__() super().__init__()
self.name = name self.name = name
@ -303,6 +305,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
self.action_records: list = [] self.action_records: list = []
self.nlp = None
try:
spacy = import_spacy()
except ImportError:
pass
else:
try: try:
self.nlp = spacy.load("en_core_web_sm") self.nlp = spacy.load("en_core_web_sm")
except OSError: except OSError:
@ -310,7 +318,6 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"Run `python -m spacy download en_core_web_sm` " "Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization." "to download en_core_web_sm model for text visualization."
) )
self.nlp = None
self.metrics = {key: 0 for key in mlflow_callback_metrics()} self.metrics = {key: 0 for key in mlflow_callback_metrics()}