mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
community[patch]: MLflow callback update (#16687)
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3c4b24b69a
commit
9b279ac127
@ -82,23 +82,21 @@ def get_text_complexity_metrics() -> List[str]:
|
||||
def analyze_text(
|
||||
text: str,
|
||||
nlp: Any = None,
|
||||
textstat: Any = None,
|
||||
) -> dict:
|
||||
"""Analyze text using textstat and spacy.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
nlp (spacy.lang): The spacy language model to use for visualization.
|
||||
textstat: The textstat library to use for complexity metrics calculation.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics and visualization
|
||||
files serialized to HTML string.
|
||||
"""
|
||||
resp: Dict[str, Any] = {}
|
||||
try:
|
||||
textstat = import_textstat()
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if textstat is not None:
|
||||
text_complexity_metrics = {
|
||||
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
|
||||
}
|
||||
@ -304,8 +302,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
self.nlp = None
|
||||
try:
|
||||
spacy = import_spacy()
|
||||
except ImportError:
|
||||
pass
|
||||
except ImportError as e:
|
||||
logger.warning(e.msg)
|
||||
else:
|
||||
try:
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
@ -315,6 +313,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"to download en_core_web_sm model for text visualization."
|
||||
)
|
||||
|
||||
try:
|
||||
self.textstat = import_textstat()
|
||||
except ImportError as e:
|
||||
logger.warning(e.msg)
|
||||
self.textstat = None
|
||||
|
||||
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
|
||||
|
||||
self.records: Dict[str, Any] = {
|
||||
@ -403,15 +407,17 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
analyze_text(
|
||||
generation.text,
|
||||
nlp=self.nlp,
|
||||
textstat=self.textstat,
|
||||
)
|
||||
)
|
||||
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
||||
"text_complexity_metrics"
|
||||
)
|
||||
self.mlflg.metrics(
|
||||
complexity_metrics,
|
||||
step=self.metrics["step"],
|
||||
)
|
||||
if "text_complexity_metrics" in generation_resp:
|
||||
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
||||
"text_complexity_metrics"
|
||||
)
|
||||
self.mlflg.metrics(
|
||||
complexity_metrics,
|
||||
step=self.metrics["step"],
|
||||
)
|
||||
self.records["on_llm_end_records"].append(generation_resp)
|
||||
self.records["action_records"].append(generation_resp)
|
||||
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
|
||||
@ -681,7 +687,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "prompt_step"}, axis=1)
|
||||
)
|
||||
complexity_metrics_columns = get_text_complexity_metrics()
|
||||
complexity_metrics_columns = (
|
||||
get_text_complexity_metrics() if self.textstat is not None else []
|
||||
)
|
||||
visualizations_columns = (
|
||||
["dependency_tree", "entities"] if self.nlp is not None else []
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user