mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-30 23:29:54 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			368 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			368 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """FlyteKit callback handler."""
 | |
| from __future__ import annotations
 | |
| 
 | |
| import logging
 | |
| from copy import deepcopy
 | |
| from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 | |
| 
 | |
| from langchain_core.agents import AgentAction, AgentFinish
 | |
| from langchain_core.callbacks import BaseCallbackHandler
 | |
| from langchain_core.outputs import LLMResult
 | |
| 
 | |
| from langchain_community.callbacks.utils import (
 | |
|     BaseMetadataCallbackHandler,
 | |
|     flatten_dict,
 | |
|     import_pandas,
 | |
|     import_spacy,
 | |
|     import_textstat,
 | |
| )
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     import flytekit
 | |
|     from flytekitplugins.deck import renderer
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
| def import_flytekit() -> Tuple[flytekit, renderer]:
 | |
|     """Import flytekit and flytekitplugins-deck-standard."""
 | |
|     try:
 | |
|         import flytekit  # noqa: F401
 | |
|         from flytekitplugins.deck import renderer  # noqa: F401
 | |
|     except ImportError:
 | |
|         raise ImportError(
 | |
|             "To use the flyte callback manager you need"
 | |
|             "to have the `flytekit` and `flytekitplugins-deck-standard`"
 | |
|             "packages installed. Please install them with `pip install flytekit`"
 | |
|             "and `pip install flytekitplugins-deck-standard`."
 | |
|         )
 | |
|     return flytekit, renderer
 | |
| 
 | |
| 
 | |
| def analyze_text(
 | |
|     text: str,
 | |
|     nlp: Any = None,
 | |
|     textstat: Any = None,
 | |
| ) -> dict:
 | |
|     """Analyze text using textstat and spacy.
 | |
| 
 | |
|     Parameters:
 | |
|         text (str): The text to analyze.
 | |
|         nlp (spacy.lang): The spacy language model to use for visualization.
 | |
| 
 | |
|     Returns:
 | |
|         (dict): A dictionary containing the complexity metrics and visualization
 | |
|             files serialized to HTML string.
 | |
|     """
 | |
|     resp: Dict[str, Any] = {}
 | |
|     if textstat is not None:
 | |
|         text_complexity_metrics = {
 | |
|             "flesch_reading_ease": textstat.flesch_reading_ease(text),
 | |
|             "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
 | |
|             "smog_index": textstat.smog_index(text),
 | |
|             "coleman_liau_index": textstat.coleman_liau_index(text),
 | |
|             "automated_readability_index": textstat.automated_readability_index(text),
 | |
|             "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
 | |
|             "difficult_words": textstat.difficult_words(text),
 | |
|             "linsear_write_formula": textstat.linsear_write_formula(text),
 | |
|             "gunning_fog": textstat.gunning_fog(text),
 | |
|             "fernandez_huerta": textstat.fernandez_huerta(text),
 | |
|             "szigriszt_pazos": textstat.szigriszt_pazos(text),
 | |
|             "gutierrez_polini": textstat.gutierrez_polini(text),
 | |
|             "crawford": textstat.crawford(text),
 | |
|             "gulpease_index": textstat.gulpease_index(text),
 | |
|             "osman": textstat.osman(text),
 | |
|         }
 | |
|         resp.update({"text_complexity_metrics": text_complexity_metrics})
 | |
|         resp.update(text_complexity_metrics)
 | |
| 
 | |
|     if nlp is not None:
 | |
|         spacy = import_spacy()
 | |
|         doc = nlp(text)
 | |
|         dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
 | |
|         ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
 | |
|         text_visualizations = {
 | |
|             "dependency_tree": dep_out,
 | |
|             "entities": ent_out,
 | |
|         }
 | |
|         resp.update(text_visualizations)
 | |
| 
 | |
|     return resp
 | |
| 
 | |
| 
 | |
| class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
 | |
|     """This callback handler that is used within a Flyte task."""
 | |
| 
 | |
|     def __init__(self) -> None:
 | |
|         """Initialize callback handler."""
 | |
|         flytekit, renderer = import_flytekit()
 | |
|         self.pandas = import_pandas()
 | |
| 
 | |
|         self.textstat = None
 | |
|         try:
 | |
|             self.textstat = import_textstat()
 | |
|         except ImportError:
 | |
|             logger.warning(
 | |
|                 "Textstat library is not installed. \
 | |
|                 It may result in the inability to log \
 | |
|                 certain metrics that can be captured with Textstat."
 | |
|             )
 | |
| 
 | |
|         spacy = None
 | |
|         try:
 | |
|             spacy = import_spacy()
 | |
|         except ImportError:
 | |
|             logger.warning(
 | |
|                 "Spacy library is not installed. \
 | |
|                 It may result in the inability to log \
 | |
|                 certain metrics that can be captured with Spacy."
 | |
|             )
 | |
| 
 | |
|         super().__init__()
 | |
| 
 | |
|         self.nlp = None
 | |
|         if spacy:
 | |
|             try:
 | |
|                 self.nlp = spacy.load("en_core_web_sm")
 | |
|             except OSError:
 | |
|                 logger.warning(
 | |
|                     "FlyteCallbackHandler uses spacy's en_core_web_sm model"
 | |
|                     " for certain metrics. To download,"
 | |
|                     " run the following command in your terminal:"
 | |
|                     " `python -m spacy download en_core_web_sm`"
 | |
|                 )
 | |
| 
 | |
|         self.table_renderer = renderer.TableRenderer
 | |
|         self.markdown_renderer = renderer.MarkdownRenderer
 | |
| 
 | |
|         self.deck = flytekit.Deck(
 | |
|             "LangChain Metrics",
 | |
|             self.markdown_renderer().to_html("## LangChain Metrics"),
 | |
|         )
 | |
| 
 | |
|     def on_llm_start(
 | |
|         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
 | |
|     ) -> None:
 | |
|         """Run when LLM starts."""
 | |
| 
 | |
|         self.step += 1
 | |
|         self.llm_starts += 1
 | |
|         self.starts += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update({"action": "on_llm_start"})
 | |
|         resp.update(flatten_dict(serialized))
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         prompt_responses = []
 | |
|         for prompt in prompts:
 | |
|             prompt_responses.append(prompt)
 | |
| 
 | |
|         resp.update({"prompts": prompt_responses})
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### LLM Start"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
 | |
|         """Run when LLM generates a new token."""
 | |
| 
 | |
|     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
 | |
|         """Run when LLM ends running."""
 | |
|         self.step += 1
 | |
|         self.llm_ends += 1
 | |
|         self.ends += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update({"action": "on_llm_end"})
 | |
|         resp.update(flatten_dict(response.llm_output or {}))
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### LLM End"))
 | |
|         self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp])))
 | |
| 
 | |
|         for generations in response.generations:
 | |
|             for generation in generations:
 | |
|                 generation_resp = deepcopy(resp)
 | |
|                 generation_resp.update(flatten_dict(generation.dict()))
 | |
|                 if self.nlp or self.textstat:
 | |
|                     generation_resp.update(
 | |
|                         analyze_text(
 | |
|                             generation.text, nlp=self.nlp, textstat=self.textstat
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|                     complexity_metrics: Dict[str, float] = generation_resp.pop(
 | |
|                         "text_complexity_metrics"
 | |
|                     )
 | |
|                     self.deck.append(
 | |
|                         self.markdown_renderer().to_html("#### Text Complexity Metrics")
 | |
|                     )
 | |
|                     self.deck.append(
 | |
|                         self.table_renderer().to_html(
 | |
|                             self.pandas.DataFrame([complexity_metrics])
 | |
|                         )
 | |
|                         + "\n"
 | |
|                     )
 | |
| 
 | |
|                     dependency_tree = generation_resp["dependency_tree"]
 | |
|                     self.deck.append(
 | |
|                         self.markdown_renderer().to_html("#### Dependency Tree")
 | |
|                     )
 | |
|                     self.deck.append(dependency_tree)
 | |
| 
 | |
|                     entities = generation_resp["entities"]
 | |
|                     self.deck.append(self.markdown_renderer().to_html("#### Entities"))
 | |
|                     self.deck.append(entities)
 | |
|                 else:
 | |
|                     self.deck.append(
 | |
|                         self.markdown_renderer().to_html("#### Generated Response")
 | |
|                     )
 | |
|                     self.deck.append(self.markdown_renderer().to_html(generation.text))
 | |
| 
 | |
|     def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
 | |
|         """Run when LLM errors."""
 | |
|         self.step += 1
 | |
|         self.errors += 1
 | |
| 
 | |
|     def on_chain_start(
 | |
|         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
 | |
|     ) -> None:
 | |
|         """Run when chain starts running."""
 | |
|         self.step += 1
 | |
|         self.chain_starts += 1
 | |
|         self.starts += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update({"action": "on_chain_start"})
 | |
|         resp.update(flatten_dict(serialized))
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
 | |
|         input_resp = deepcopy(resp)
 | |
|         input_resp["inputs"] = chain_input
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### Chain Start"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([input_resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
 | |
|         """Run when chain ends running."""
 | |
|         self.step += 1
 | |
|         self.chain_ends += 1
 | |
|         self.ends += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
 | |
|         resp.update({"action": "on_chain_end", "outputs": chain_output})
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### Chain End"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
 | |
|         """Run when chain errors."""
 | |
|         self.step += 1
 | |
|         self.errors += 1
 | |
| 
 | |
|     def on_tool_start(
 | |
|         self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
 | |
|     ) -> None:
 | |
|         """Run when tool starts running."""
 | |
|         self.step += 1
 | |
|         self.tool_starts += 1
 | |
|         self.starts += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update({"action": "on_tool_start", "input_str": input_str})
 | |
|         resp.update(flatten_dict(serialized))
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### Tool Start"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_tool_end(self, output: str, **kwargs: Any) -> None:
 | |
|         """Run when tool ends running."""
 | |
|         self.step += 1
 | |
|         self.tool_ends += 1
 | |
|         self.ends += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update({"action": "on_tool_end", "output": output})
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### Tool End"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
 | |
|         """Run when tool errors."""
 | |
|         self.step += 1
 | |
|         self.errors += 1
 | |
| 
 | |
|     def on_text(self, text: str, **kwargs: Any) -> None:
 | |
|         """
 | |
|         Run when agent is ending.
 | |
|         """
 | |
|         self.step += 1
 | |
|         self.text_ctr += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update({"action": "on_text", "text": text})
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### On Text"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
 | |
|         """Run when agent ends running."""
 | |
|         self.step += 1
 | |
|         self.agent_ends += 1
 | |
|         self.ends += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update(
 | |
|             {
 | |
|                 "action": "on_agent_finish",
 | |
|                 "output": finish.return_values["output"],
 | |
|                 "log": finish.log,
 | |
|             }
 | |
|         )
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### Agent Finish"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 | |
| 
 | |
|     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
 | |
|         """Run on agent action."""
 | |
|         self.step += 1
 | |
|         self.tool_starts += 1
 | |
|         self.starts += 1
 | |
| 
 | |
|         resp: Dict[str, Any] = {}
 | |
|         resp.update(
 | |
|             {
 | |
|                 "action": "on_agent_action",
 | |
|                 "tool": action.tool,
 | |
|                 "tool_input": action.tool_input,
 | |
|                 "log": action.log,
 | |
|             }
 | |
|         )
 | |
|         resp.update(self.get_custom_callback_meta())
 | |
| 
 | |
|         self.deck.append(self.markdown_renderer().to_html("### Agent Action"))
 | |
|         self.deck.append(
 | |
|             self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
 | |
|         )
 |