mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
community[patch]: MLflow callback update (#16687)
Signed-off-by: Serena Ruan <serena.rxy@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3c4b24b69a
commit
9b279ac127
@ -82,23 +82,21 @@ def get_text_complexity_metrics() -> List[str]:
|
|||||||
def analyze_text(
|
def analyze_text(
|
||||||
text: str,
|
text: str,
|
||||||
nlp: Any = None,
|
nlp: Any = None,
|
||||||
|
textstat: Any = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Analyze text using textstat and spacy.
|
"""Analyze text using textstat and spacy.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
text (str): The text to analyze.
|
text (str): The text to analyze.
|
||||||
nlp (spacy.lang): The spacy language model to use for visualization.
|
nlp (spacy.lang): The spacy language model to use for visualization.
|
||||||
|
textstat: The textstat library to use for complexity metrics calculation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(dict): A dictionary containing the complexity metrics and visualization
|
(dict): A dictionary containing the complexity metrics and visualization
|
||||||
files serialized to HTML string.
|
files serialized to HTML string.
|
||||||
"""
|
"""
|
||||||
resp: Dict[str, Any] = {}
|
resp: Dict[str, Any] = {}
|
||||||
try:
|
if textstat is not None:
|
||||||
textstat = import_textstat()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
text_complexity_metrics = {
|
text_complexity_metrics = {
|
||||||
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
|
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
|
||||||
}
|
}
|
||||||
@ -304,8 +302,8 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.nlp = None
|
self.nlp = None
|
||||||
try:
|
try:
|
||||||
spacy = import_spacy()
|
spacy = import_spacy()
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
pass
|
logger.warning(e.msg)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
self.nlp = spacy.load("en_core_web_sm")
|
self.nlp = spacy.load("en_core_web_sm")
|
||||||
@ -315,6 +313,12 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
"to download en_core_web_sm model for text visualization."
|
"to download en_core_web_sm model for text visualization."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.textstat = import_textstat()
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(e.msg)
|
||||||
|
self.textstat = None
|
||||||
|
|
||||||
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
|
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
|
||||||
|
|
||||||
self.records: Dict[str, Any] = {
|
self.records: Dict[str, Any] = {
|
||||||
@ -403,15 +407,17 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
analyze_text(
|
analyze_text(
|
||||||
generation.text,
|
generation.text,
|
||||||
nlp=self.nlp,
|
nlp=self.nlp,
|
||||||
|
textstat=self.textstat,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
if "text_complexity_metrics" in generation_resp:
|
||||||
"text_complexity_metrics"
|
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
||||||
)
|
"text_complexity_metrics"
|
||||||
self.mlflg.metrics(
|
)
|
||||||
complexity_metrics,
|
self.mlflg.metrics(
|
||||||
step=self.metrics["step"],
|
complexity_metrics,
|
||||||
)
|
step=self.metrics["step"],
|
||||||
|
)
|
||||||
self.records["on_llm_end_records"].append(generation_resp)
|
self.records["on_llm_end_records"].append(generation_resp)
|
||||||
self.records["action_records"].append(generation_resp)
|
self.records["action_records"].append(generation_resp)
|
||||||
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
|
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
|
||||||
@ -681,7 +687,9 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
.dropna(axis=1)
|
.dropna(axis=1)
|
||||||
.rename({"step": "prompt_step"}, axis=1)
|
.rename({"step": "prompt_step"}, axis=1)
|
||||||
)
|
)
|
||||||
complexity_metrics_columns = get_text_complexity_metrics()
|
complexity_metrics_columns = (
|
||||||
|
get_text_complexity_metrics() if self.textstat is not None else []
|
||||||
|
)
|
||||||
visualizations_columns = (
|
visualizations_columns = (
|
||||||
["dependency_tree", "entities"] if self.nlp is not None else []
|
["dependency_tree", "entities"] if self.nlp is not None else []
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user