Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-24 20:09:01 +00:00)
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463)
Moved the following modules to the new package langchain-community in a backwards compatible fashion:

```
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
```

Moved the following to core:

```
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py
```

See .scripts/community_split/script_integrations.sh for all changes.
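Because the move is backwards compatible, existing imports from `langchain` keep resolving while the canonical home of the integrations becomes `langchain_community`. A minimal sketch of the import change for one of the relocated callback handlers (the old path continues to work through compatibility shims; the handler chosen here is just an example):

```python
# Old import path, still functional for backwards compatibility:
from langchain.callbacks import AimCallbackHandler

# New canonical import path introduced by this change:
from langchain_community.callbacks import AimCallbackHandler
```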
libs/community/langchain_community/callbacks/__init__.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""**Callback handlers** allow listening to events in LangChain.

**Class hierarchy:**

.. code-block::

    BaseCallbackHandler --> <name>CallbackHandler  # Example: AimCallbackHandler
"""

from langchain_community.callbacks.aim_callback import AimCallbackHandler
from langchain_community.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain_community.callbacks.arize_callback import ArizeCallbackHandler
from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler
from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler
from langchain_community.callbacks.context_callback import ContextCallbackHandler
from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
from langchain_community.callbacks.human import HumanApprovalCallbackHandler
from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
from langchain_community.callbacks.labelstudio_callback import (
    LabelStudioCallbackHandler,
)
from langchain_community.callbacks.llmonitor_callback import LLMonitorCallbackHandler
from langchain_community.callbacks.manager import (
    get_openai_callback,
    wandb_tracing_enabled,
)
from langchain_community.callbacks.mlflow_callback import MlflowCallbackHandler
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.promptlayer_callback import (
    PromptLayerCallbackHandler,
)
from langchain_community.callbacks.sagemaker_callback import SageMakerCallbackHandler
from langchain_community.callbacks.streamlit import (
    LLMThoughtLabeler,
    StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import TrubricsCallbackHandler
from langchain_community.callbacks.wandb_callback import WandbCallbackHandler
from langchain_community.callbacks.whylabs_callback import WhyLabsCallbackHandler

__all__ = [
    "AimCallbackHandler",
    "ArgillaCallbackHandler",
    "ArizeCallbackHandler",
    "PromptLayerCallbackHandler",
    "ArthurCallbackHandler",
    "ClearMLCallbackHandler",
    "CometCallbackHandler",
    "ContextCallbackHandler",
    "HumanApprovalCallbackHandler",
    "InfinoCallbackHandler",
    "MlflowCallbackHandler",
    "LLMonitorCallbackHandler",
    "OpenAICallbackHandler",
    "LLMThoughtLabeler",
    "StreamlitCallbackHandler",
    "WandbCallbackHandler",
    "WhyLabsCallbackHandler",
    "get_openai_callback",
    "wandb_tracing_enabled",
    "FlyteCallbackHandler",
    "SageMakerCallbackHandler",
    "LabelStudioCallbackHandler",
    "TrubricsCallbackHandler",
]
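The package `__init__` above re-exports every handler, so integrations can be imported from `langchain_community.callbacks` directly. A small usage sketch (assuming `langchain-community` is installed and an OpenAI-backed model is configured; the model choice and prompt are illustrative):

```python
from langchain_community.callbacks import get_openai_callback
from langchain_community.llms import OpenAI  # illustrative model choice

llm = OpenAI(temperature=0)

# Track token usage and cost for any OpenAI calls made inside the block.
with get_openai_callback() as cb:
    llm.invoke("Tell me a joke")
    print(cb.total_tokens, cb.total_cost)
```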
libs/community/langchain_community/callbacks/aim_callback.py (new file, 430 lines)
@@ -0,0 +1,430 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
def import_aim() -> Any:
|
||||
"""Import the aim python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import aim
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the Aim callback manager you need to have the"
|
||||
" `aim` python package installed."
|
||||
"Please install it with `pip install aim`"
|
||||
)
|
||||
return aim
|
||||
|
||||
|
||||
class BaseMetadataCallbackHandler:
|
||||
"""This class handles the metadata and associated function states for callbacks.
|
||||
|
||||
Attributes:
|
||||
step (int): The current step.
|
||||
starts (int): The number of times the start method has been called.
|
||||
ends (int): The number of times the end method has been called.
|
||||
errors (int): The number of times the error method has been called.
|
||||
text_ctr (int): The number of times the text method has been called.
|
||||
ignore_llm_ (bool): Whether to ignore llm callbacks.
|
||||
ignore_chain_ (bool): Whether to ignore chain callbacks.
|
||||
ignore_agent_ (bool): Whether to ignore agent callbacks.
|
||||
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
|
||||
always_verbose_ (bool): Whether to always be verbose.
|
||||
chain_starts (int): The number of times the chain start method has been called.
|
||||
chain_ends (int): The number of times the chain end method has been called.
|
||||
llm_starts (int): The number of times the llm start method has been called.
|
||||
llm_ends (int): The number of times the llm end method has been called.
|
||||
llm_streams (int): The number of times the text method has been called.
|
||||
tool_starts (int): The number of times the tool start method has been called.
|
||||
tool_ends (int): The number of times the tool end method has been called.
|
||||
agent_ends (int): The number of times the agent end method has been called.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.step = 0
|
||||
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
self.text_ctr = 0
|
||||
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.ignore_retriever_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
self.chain_ends = 0
|
||||
|
||||
self.llm_starts = 0
|
||||
self.llm_ends = 0
|
||||
self.llm_streams = 0
|
||||
|
||||
self.tool_starts = 0
|
||||
self.tool_ends = 0
|
||||
|
||||
self.agent_ends = 0
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return self.always_verbose_
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
def get_custom_callback_meta(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step": self.step,
|
||||
"starts": self.starts,
|
||||
"ends": self.ends,
|
||||
"errors": self.errors,
|
||||
"text_ctr": self.text_ctr,
|
||||
"chain_starts": self.chain_starts,
|
||||
"chain_ends": self.chain_ends,
|
||||
"llm_starts": self.llm_starts,
|
||||
"llm_ends": self.llm_ends,
|
||||
"llm_streams": self.llm_streams,
|
||||
"tool_starts": self.tool_starts,
|
||||
"tool_ends": self.tool_ends,
|
||||
"agent_ends": self.agent_ends,
|
||||
}
|
||||
|
||||
def reset_callback_meta(self) -> None:
|
||||
"""Reset the callback metadata."""
|
||||
self.step = 0
|
||||
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
self.text_ctr = 0
|
||||
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
self.chain_ends = 0
|
||||
|
||||
self.llm_starts = 0
|
||||
self.llm_ends = 0
|
||||
self.llm_streams = 0
|
||||
|
||||
self.tool_starts = 0
|
||||
self.tool_ends = 0
|
||||
|
||||
self.agent_ends = 0
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Aim.
|
||||
|
||||
Parameters:
|
||||
repo (:obj:`str`, optional): Aim repository path or Repo object to which
|
||||
Run object is bound. If skipped, default Repo is used.
|
||||
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
|
||||
'default' if not specified. Can be used later to query runs/sequences.
|
||||
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
|
||||
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
|
||||
to disable system metrics tracking.
|
||||
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
|
||||
params such as installed packages, git info, environment variables, etc.
|
||||
|
||||
This handler will utilize the associated callback method called and formats
|
||||
the input of each callback function with metadata regarding the state of LLM run
|
||||
and then logs the response to Aim.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: Optional[str] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
system_tracking_interval: Optional[int] = 10,
|
||||
log_system_params: bool = True,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
super().__init__()
|
||||
|
||||
aim = import_aim()
|
||||
self.repo = repo
|
||||
self.experiment_name = experiment_name
|
||||
self.system_tracking_interval = system_tracking_interval
|
||||
self.log_system_params = log_system_params
|
||||
self._run = aim.Run(
|
||||
repo=self.repo,
|
||||
experiment=self.experiment_name,
|
||||
system_tracking_interval=self.system_tracking_interval,
|
||||
log_system_params=self.log_system_params,
|
||||
)
|
||||
self._run_hash = self._run.hash
|
||||
self.action_records: list = []
|
||||
|
||||
def setup(self, **kwargs: Any) -> None:
|
||||
aim = import_aim()
|
||||
|
||||
if not self._run:
|
||||
if self._run_hash:
|
||||
self._run = aim.Run(
|
||||
self._run_hash,
|
||||
repo=self.repo,
|
||||
system_tracking_interval=self.system_tracking_interval,
|
||||
)
|
||||
else:
|
||||
self._run = aim.Run(
|
||||
repo=self.repo,
|
||||
experiment=self.experiment_name,
|
||||
system_tracking_interval=self.system_tracking_interval,
|
||||
log_system_params=self.log_system_params,
|
||||
)
|
||||
self._run_hash = self._run.hash
|
||||
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
self._run.set(key, value, strict=False)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
aim = import_aim()
|
||||
|
||||
self.step += 1
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = {"action": "on_llm_start"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
prompts_res = deepcopy(prompts)
|
||||
|
||||
self._run.track(
|
||||
[aim.Text(prompt) for prompt in prompts_res],
|
||||
name="on_llm_start",
|
||||
context=resp,
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = {"action": "on_llm_end"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
response_res = deepcopy(response)
|
||||
|
||||
generated = [
|
||||
aim.Text(generation.text)
|
||||
for generations in response_res.generations
|
||||
for generation in generations
|
||||
]
|
||||
self._run.track(
|
||||
generated,
|
||||
name="on_llm_end",
|
||||
context=resp,
|
||||
)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.step += 1
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = {"action": "on_chain_start"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
inputs_res = deepcopy(inputs)
|
||||
|
||||
self._run.track(
|
||||
aim.Text(inputs_res["input"]), name="on_chain_start", context=resp
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = {"action": "on_chain_end"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
outputs_res = deepcopy(outputs)
|
||||
|
||||
self._run.track(
|
||||
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = {"action": "on_tool_start"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = {"action": "on_tool_end"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.step += 1
|
||||
self.text_ctr += 1
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = {"action": "on_agent_finish"}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
finish_res = deepcopy(finish)
|
||||
|
||||
text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
|
||||
finish_res.return_values["output"], finish_res.log
|
||||
)
|
||||
self._run.track(aim.Text(text), name="on_agent_finish", context=resp)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
aim = import_aim()
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = {
|
||||
"action": "on_agent_action",
|
||||
"tool": action.tool,
|
||||
}
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
action_res = deepcopy(action)
|
||||
|
||||
text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
|
||||
action_res.tool_input, action_res.log
|
||||
)
|
||||
self._run.track(aim.Text(text), name="on_agent_action", context=resp)
|
||||
|
||||
def flush_tracker(
|
||||
self,
|
||||
repo: Optional[str] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
system_tracking_interval: Optional[int] = 10,
|
||||
log_system_params: bool = True,
|
||||
langchain_asset: Any = None,
|
||||
reset: bool = True,
|
||||
finish: bool = False,
|
||||
) -> None:
|
||||
"""Flush the tracker and reset the session.
|
||||
|
||||
Args:
|
||||
repo (:obj:`str`, optional): Aim repository path or Repo object to which
|
||||
Run object is bound. If skipped, default Repo is used.
|
||||
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
|
||||
'default' if not specified. Can be used later to query runs/sequences.
|
||||
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
|
||||
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
|
||||
to disable system metrics tracking.
|
||||
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
|
||||
params such as installed packages, git info, environment variables, etc.
|
||||
langchain_asset: The langchain asset to save.
|
||||
reset: Whether to reset the session.
|
||||
finish: Whether to finish the run.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if langchain_asset:
|
||||
try:
|
||||
for key, value in langchain_asset.dict().items():
|
||||
self._run.set(key, value, strict=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if finish or reset:
|
||||
self._run.close()
|
||||
self.reset_callback_meta()
|
||||
if reset:
|
||||
self.__init__( # type: ignore
|
||||
repo=repo if repo else self.repo,
|
||||
experiment_name=experiment_name
|
||||
if experiment_name
|
||||
else self.experiment_name,
|
||||
system_tracking_interval=system_tracking_interval
|
||||
if system_tracking_interval
|
||||
else self.system_tracking_interval,
|
||||
log_system_params=log_system_params
|
||||
if log_system_params
|
||||
else self.log_system_params,
|
||||
)
|
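For reference, a minimal usage sketch of the handler defined above (assuming the `aim` package is installed and an OpenAI key is configured; the model class, repo path, and prompts are illustrative):

```python
from langchain_community.callbacks import AimCallbackHandler
from langchain_community.llms import OpenAI  # illustrative model choice

aim_callback = AimCallbackHandler(
    repo=".",  # local Aim repository; the default Repo is used if omitted
    experiment_name="scenario 1: OpenAI LLM",
)
llm = OpenAI(temperature=0, callbacks=[aim_callback])
llm.generate(["Tell me a joke", "Tell me a poem"])

# Persist the collected run data and reset the session for the next experiment.
aim_callback.flush_tracker(langchain_asset=llm, reset=True)
```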
libs/community/langchain_community/callbacks/argilla_callback.py (new file, 352 lines)
@@ -0,0 +1,352 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from packaging.version import parse
|
||||
|
||||
|
||||
class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs into Argilla.
|
||||
|
||||
Args:
|
||||
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
||||
exist in advance. If you need help on how to create a `FeedbackDataset` in
|
||||
Argilla, please visit
|
||||
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
|
||||
workspace_name: name of the workspace in Argilla where the specified
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
||||
default workspace will be used.
|
||||
api_url: URL of the Argilla Server that we want to use, and where the
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
||||
`ARGILLA_API_URL` environment variable or the default will be used.
|
||||
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
||||
means that either `ARGILLA_API_KEY` environment variable or the default
|
||||
will be used.
|
||||
|
||||
Raises:
|
||||
ImportError: if the `argilla` package is not installed.
|
||||
ConnectionError: if the connection to Argilla fails.
|
||||
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
|
||||
|
||||
Examples:
|
||||
>>> from langchain_community.llms import OpenAI
|
||||
>>> from langchain_community.callbacks import ArgillaCallbackHandler
|
||||
>>> argilla_callback = ArgillaCallbackHandler(
|
||||
... dataset_name="my-dataset",
|
||||
... workspace_name="my-workspace",
|
||||
... api_url="http://localhost:6900",
|
||||
... api_key="argilla.apikey",
|
||||
... )
|
||||
>>> llm = OpenAI(
|
||||
... temperature=0,
|
||||
... callbacks=[argilla_callback],
|
||||
... verbose=True,
|
||||
... openai_api_key="API_KEY_HERE",
|
||||
... )
|
||||
>>> llm.generate([
|
||||
... "What is the best NLP-annotation tool out there? (no bias at all)",
|
||||
... ])
|
||||
"Argilla, no doubt about it."
|
||||
"""
|
||||
|
||||
REPO_URL: str = "https://github.com/argilla-io/argilla"
|
||||
ISSUES_URL: str = f"{REPO_URL}/issues"
|
||||
BLOG_URL: str = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
|
||||
|
||||
DEFAULT_API_URL: str = "http://localhost:6900"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name: str,
|
||||
workspace_name: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initializes the `ArgillaCallbackHandler`.
|
||||
|
||||
Args:
|
||||
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
||||
exist in advance. If you need help on how to create a `FeedbackDataset`
|
||||
in Argilla, please visit
|
||||
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
|
||||
workspace_name: name of the workspace in Argilla where the specified
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
||||
default workspace will be used.
|
||||
api_url: URL of the Argilla Server that we want to use, and where the
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
||||
`ARGILLA_API_URL` environment variable or the default will be used.
|
||||
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
||||
means that either `ARGILLA_API_KEY` environment variable or the default
|
||||
will be used.
|
||||
|
||||
Raises:
|
||||
ImportError: if the `argilla` package is not installed.
|
||||
ConnectionError: if the connection to Argilla fails.
|
||||
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
|
||||
try:
|
||||
import argilla as rg # noqa: F401
|
||||
|
||||
self.ARGILLA_VERSION = rg.__version__
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the Argilla callback manager you need to have the `argilla` "
|
||||
"Python package installed. Please install it with `pip install argilla`"
|
||||
)
|
||||
|
||||
# Check whether the Argilla version is compatible
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
|
||||
raise ImportError(
|
||||
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
|
||||
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
|
||||
"upgrade `argilla` with `pip install --upgrade argilla`."
|
||||
)
|
||||
|
||||
# Show a warning message if Argilla will assume the default values will be used
|
||||
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
|
||||
warnings.warn(
|
||||
(
|
||||
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
|
||||
f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
|
||||
" default API URL in Argilla Quickstart."
|
||||
),
|
||||
)
|
||||
api_url = self.DEFAULT_API_URL
|
||||
|
||||
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
|
||||
self.DEFAULT_API_KEY = (
|
||||
"admin.apikey"
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.11.0")
|
||||
else "owner.apikey"
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
|
||||
f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
|
||||
" default API key in Argilla Quickstart."
|
||||
),
|
||||
)
|
||||
api_url = self.DEFAULT_API_URL
|
||||
|
||||
# Connect to Argilla with the provided credentials, if applicable
|
||||
try:
|
||||
rg.init(api_key=api_key, api_url=api_url)
|
||||
except Exception as e:
|
||||
raise ConnectionError(
|
||||
f"Could not connect to Argilla with exception: '{e}'.\n"
|
||||
"Please check your `api_key` and `api_url`, and make sure that "
|
||||
"the Argilla server is up and running. If the problem persists "
|
||||
f"please report it to {self.ISSUES_URL} as an `integration` issue."
|
||||
) from e
|
||||
|
||||
# Set the Argilla variables
|
||||
self.dataset_name = dataset_name
|
||||
self.workspace_name = workspace_name or rg.get_workspace()
|
||||
|
||||
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
|
||||
try:
|
||||
extra_args = {}
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||
warnings.warn(
|
||||
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
|
||||
" higher is recommended.",
|
||||
UserWarning,
|
||||
)
|
||||
extra_args = {"with_records": False}
|
||||
self.dataset = rg.FeedbackDataset.from_argilla(
|
||||
name=self.dataset_name,
|
||||
workspace=self.workspace_name,
|
||||
**extra_args,
|
||||
)
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(
|
||||
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
|
||||
f"\nPlease check that the dataset with name={self.dataset_name} in the"
|
||||
f" workspace={self.workspace_name} exists in advance. If you need help"
|
||||
" on how to create a `langchain`-compatible `FeedbackDataset` in"
|
||||
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
|
||||
f" please report it to {self.ISSUES_URL} as an `integration` issue."
|
||||
) from e
|
||||
|
||||
supported_fields = ["prompt", "response"]
|
||||
if supported_fields != [field.name for field in self.dataset.fields]:
|
||||
raise ValueError(
|
||||
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
|
||||
f"{self.workspace_name} had fields that are not supported yet for the"
|
||||
f"`langchain` integration. Supported fields are: {supported_fields},"
|
||||
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
|
||||
" For more information on how to create a `langchain`-compatible"
|
||||
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
|
||||
)
|
||||
|
||||
self.prompts: Dict[str, List[str]] = {}
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
|
||||
" change based on updates to `langchain`. Please report any issues to"
|
||||
f" {self.ISSUES_URL} as an `integration` issue."
|
||||
),
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Save the prompts in memory when an LLM starts."""
|
||||
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing when a new token is generated."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Log records to Argilla when an LLM ends."""
|
||||
# Do nothing if there's a parent_run_id, since we will log the records when
|
||||
# the chain ends
|
||||
if kwargs["parent_run_id"]:
|
||||
return
|
||||
|
||||
# Creates the records and adds them to the `FeedbackDataset`
|
||||
prompts = self.prompts[str(kwargs["run_id"])]
|
||||
for prompt, generations in zip(prompts, response.generations):
|
||||
self.dataset.add_records(
|
||||
records=[
|
||||
{
|
||||
"fields": {
|
||||
"prompt": prompt,
|
||||
"response": generation.text.strip(),
|
||||
},
|
||||
}
|
||||
for generation in generations
|
||||
]
|
||||
)
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
self.prompts.pop(str(kwargs["run_id"]))
|
||||
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
|
||||
either the `parent_run_id` or the `run_id` as the key. This is done so that
|
||||
we don't log the same input prompt twice, once when the LLM starts and once
|
||||
when the chain starts.
|
||||
"""
|
||||
if "input" in inputs:
|
||||
self.prompts.update(
|
||||
{
|
||||
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
|
||||
inputs["input"]
|
||||
if isinstance(inputs["input"], list)
|
||||
else [inputs["input"]]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
|
||||
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
|
||||
differs if the output is a list or not.
|
||||
"""
|
||||
if not any(
|
||||
key in self.prompts
|
||||
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
|
||||
):
|
||||
return
|
||||
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
|
||||
str(kwargs["run_id"])
|
||||
)
|
||||
for chain_output_key, chain_output_val in outputs.items():
|
||||
if isinstance(chain_output_val, list):
|
||||
# Creates the records and adds them to the `FeedbackDataset`
|
||||
self.dataset.add_records(
|
||||
records=[
|
||||
{
|
||||
"fields": {
|
||||
"prompt": prompt,
|
||||
"response": output["text"].strip(),
|
||||
},
|
||||
}
|
||||
for prompt, output in zip(
|
||||
prompts, # type: ignore
|
||||
chain_output_val,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
# Creates the records and adds them to the `FeedbackDataset`
|
||||
self.dataset.add_records(
|
||||
records=[
|
||||
{
|
||||
"fields": {
|
||||
"prompt": " ".join(prompts), # type: ignore
|
||||
"response": chain_output_val.strip(),
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
if str(kwargs["parent_run_id"]) in self.prompts:
|
||||
self.prompts.pop(str(kwargs["parent_run_id"]))
|
||||
if str(kwargs["run_id"]) in self.prompts:
|
||||
self.prompts.pop(str(kwargs["run_id"]))
|
||||
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool starts."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing when agent takes a specific action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
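The handler requires the `FeedbackDataset` to already exist in Argilla with exactly the `prompt` and `response` text fields checked in `__init__`. A rough sketch of creating such a dataset with the Argilla 1.x client (the question definition, dataset name, and workspace name are illustrative, and exact signatures may vary between Argilla versions):

```python
import argilla as rg

rg.init(api_url="http://localhost:6900", api_key="argilla.apikey")

# Fields must be named exactly "prompt" and "response" to pass the
# supported_fields check in ArgillaCallbackHandler.__init__.
dataset = rg.FeedbackDataset(
    fields=[
        rg.TextField(name="prompt"),
        rg.TextField(name="response"),
    ],
    questions=[
        rg.RatingQuestion(name="response-rating", values=[1, 2, 3, 4, 5]),
    ],
)
dataset.push_to_argilla(name="my-dataset", workspace="my-workspace")
```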
libs/community/langchain_community/callbacks/arize_callback.py (new file, 213 lines)
@@ -0,0 +1,213 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import import_pandas
|
||||
|
||||
|
||||
class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Arize."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: Optional[str] = None,
|
||||
model_version: Optional[str] = None,
|
||||
SPACE_KEY: Optional[str] = None,
|
||||
API_KEY: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
super().__init__()
|
||||
self.model_id = model_id
|
||||
self.model_version = model_version
|
||||
self.space_key = SPACE_KEY
|
||||
self.api_key = API_KEY
|
||||
self.prompt_records: List[str] = []
|
||||
self.response_records: List[str] = []
|
||||
self.prediction_ids: List[str] = []
|
||||
self.pred_timestamps: List[int] = []
|
||||
self.response_embeddings: List[float] = []
|
||||
self.prompt_embeddings: List[float] = []
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.step = 0
|
||||
|
||||
from arize.pandas.embeddings import EmbeddingGenerator, UseCases
|
||||
from arize.pandas.logger import Client
|
||||
|
||||
self.generator = EmbeddingGenerator.from_use_case(
|
||||
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
|
||||
model_name="distilbert-base-uncased",
|
||||
tokenizer_max_length=512,
|
||||
batch_size=256,
|
||||
)
|
||||
self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
|
||||
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
|
||||
raise ValueError("❌ CHANGE SPACE AND API KEYS")
|
||||
else:
|
||||
print("✅ Arize client setup done! Now you can start using Arize!")
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
for prompt in prompts:
|
||||
self.prompt_records.append(prompt.replace("\n", ""))
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
pd = import_pandas()
|
||||
from arize.utils.types import (
|
||||
EmbeddingColumnNames,
|
||||
Environments,
|
||||
ModelTypes,
|
||||
Schema,
|
||||
)
|
||||
|
||||
# Safe check if 'llm_output' and 'token_usage' exist
|
||||
if response.llm_output and "token_usage" in response.llm_output:
|
||||
self.prompt_tokens = response.llm_output["token_usage"].get(
|
||||
"prompt_tokens", 0
|
||||
)
|
||||
self.total_tokens = response.llm_output["token_usage"].get(
|
||||
"total_tokens", 0
|
||||
)
|
||||
self.completion_tokens = response.llm_output["token_usage"].get(
|
||||
"completion_tokens", 0
|
||||
)
|
||||
else:
|
||||
self.prompt_tokens = (
|
||||
self.total_tokens
|
||||
) = self.completion_tokens = 0 # assign default value
|
||||
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
prompt = self.prompt_records[self.step]
|
||||
self.step = self.step + 1
|
||||
prompt_embedding = pd.Series(
|
||||
self.generator.generate_embeddings(
|
||||
text_col=pd.Series(prompt.replace("\n", " "))
|
||||
).reset_index(drop=True)
|
||||
)
|
||||
|
||||
# Assigning text to response_text instead of response
|
||||
response_text = generation.text.replace("\n", " ")
|
||||
response_embedding = pd.Series(
|
||||
self.generator.generate_embeddings(
|
||||
text_col=pd.Series(generation.text.replace("\n", " "))
|
||||
).reset_index(drop=True)
|
||||
)
|
||||
pred_timestamp = datetime.now().timestamp()
|
||||
|
||||
# Define the columns and data
|
||||
columns = [
|
||||
"prediction_ts",
|
||||
"response",
|
||||
"prompt",
|
||||
"response_vector",
|
||||
"prompt_vector",
|
||||
"prompt_token",
|
||||
"completion_token",
|
||||
"total_token",
|
||||
]
|
||||
data = [
|
||||
[
|
||||
pred_timestamp,
|
||||
response_text,
|
||||
prompt,
|
||||
response_embedding[0],
|
||||
prompt_embedding[0],
|
||||
self.prompt_tokens,
|
||||
self.total_tokens,
|
||||
self.completion_tokens,
|
||||
]
|
||||
]
|
||||
|
||||
# Create the DataFrame
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
|
||||
# Declare prompt and response columns
|
||||
prompt_columns = EmbeddingColumnNames(
|
||||
vector_column_name="prompt_vector", data_column_name="prompt"
|
||||
)
|
||||
|
||||
response_columns = EmbeddingColumnNames(
|
||||
vector_column_name="response_vector", data_column_name="response"
|
||||
)
|
||||
|
||||
schema = Schema(
|
||||
timestamp_column_name="prediction_ts",
|
||||
tag_column_names=[
|
||||
"prompt_token",
|
||||
"completion_token",
|
||||
"total_token",
|
||||
],
|
||||
prompt_column_names=prompt_columns,
|
||||
response_column_names=response_columns,
|
||||
)
|
||||
|
||||
response_from_arize = self.arize_client.log(
|
||||
dataframe=df,
|
||||
schema=schema,
|
||||
model_id=self.model_id,
|
||||
model_version=self.model_version,
|
||||
model_type=ModelTypes.GENERATIVE_LLM,
|
||||
environment=Environments.PRODUCTION,
|
||||
)
|
||||
if response_from_arize.status_code == 200:
|
||||
print("✅ Successfully logged data to Arize!")
|
||||
else:
|
||||
print(f'❌ Logging failed "{response_from_arize.text}"')
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
pass
|
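A short usage sketch of the Arize handler defined above (the credentials and model identifiers are placeholders, and the model call is illustrative):

```python
from langchain_community.callbacks import ArizeCallbackHandler
from langchain_community.llms import OpenAI  # illustrative model choice

arize_callback = ArizeCallbackHandler(
    model_id="llm-langchain-demo",
    model_version="1.0",
    SPACE_KEY="<your-arize-space-key>",
    API_KEY="<your-arize-api-key>",
)
llm = OpenAI(temperature=0, callbacks=[arize_callback])
llm.generate(["What is the capital of France?"])
```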
libs/community/langchain_community/callbacks/arthur_callback.py (new file, 296 lines)
@@ -0,0 +1,296 @@
|
||||
"""ArthurAI's Callback Handler."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import arthurai
|
||||
from arthurai.core.models import ArthurModel
|
||||
|
||||
PROMPT_TOKENS = "prompt_tokens"
|
||||
COMPLETION_TOKENS = "completion_tokens"
|
||||
TOKEN_USAGE = "token_usage"
|
||||
FINISH_REASON = "finish_reason"
|
||||
DURATION = "duration"
|
||||
|
||||
|
||||
def _lazy_load_arthur() -> arthurai:
|
||||
"""Lazy load Arthur."""
|
||||
try:
|
||||
import arthurai
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To use the ArthurCallbackHandler you need the"
|
||||
" `arthurai` package. Please install it with"
|
||||
" `pip install arthurai`.",
|
||||
e,
|
||||
)
|
||||
|
||||
return arthurai
|
||||
|
||||
|
||||
class ArthurCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Arthur platform.
|
||||
|
||||
Arthur helps enterprise teams optimize model operations
|
||||
and performance at scale. The Arthur API tracks model
|
||||
performance, explainability, and fairness across tabular,
|
||||
NLP, and CV models. Our API is model- and platform-agnostic,
|
||||
and continuously scales with complex and dynamic enterprise needs.
|
||||
To learn more about Arthur, visit our website at
|
||||
https://www.arthur.ai/ or read the Arthur docs at
|
||||
https://docs.arthur.ai/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arthur_model: ArthurModel,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
super().__init__()
|
||||
arthurai = _lazy_load_arthur()
|
||||
Stage = arthurai.common.constants.Stage
|
||||
ValueType = arthurai.common.constants.ValueType
|
||||
self.arthur_model = arthur_model
|
||||
# save the attributes of this model to be used when preparing
|
||||
# inferences to log to Arthur in on_llm_end()
|
||||
self.attr_names = set([a.name for a in self.arthur_model.get_attributes()])
|
||||
self.input_attr = [
|
||||
x
|
||||
for x in self.arthur_model.get_attributes()
|
||||
if x.stage == Stage.ModelPipelineInput
|
||||
and x.value_type == ValueType.Unstructured_Text
|
||||
][0].name
|
||||
self.output_attr = [
|
||||
x
|
||||
for x in self.arthur_model.get_attributes()
|
||||
if x.stage == Stage.PredictedValue
|
||||
and x.value_type == ValueType.Unstructured_Text
|
||||
][0].name
|
||||
self.token_likelihood_attr = None
|
||||
if (
|
||||
len(
|
||||
[
|
||||
x
|
||||
for x in self.arthur_model.get_attributes()
|
||||
if x.value_type == ValueType.TokenLikelihoods
|
||||
]
|
||||
)
|
||||
> 0
|
||||
):
|
||||
self.token_likelihood_attr = [
|
||||
x
|
||||
for x in self.arthur_model.get_attributes()
|
||||
if x.value_type == ValueType.TokenLikelihoods
|
||||
][0].name
|
||||
|
||||
self.run_map: DefaultDict[str, Any] = defaultdict(dict)
|
||||
|
||||
@classmethod
|
||||
def from_credentials(
|
||||
cls,
|
||||
model_id: str,
|
||||
arthur_url: Optional[str] = "https://app.arthur.ai",
|
||||
arthur_login: Optional[str] = None,
|
||||
arthur_password: Optional[str] = None,
|
||||
) -> ArthurCallbackHandler:
|
||||
"""Initialize callback handler from Arthur credentials.
|
||||
|
||||
Args:
|
||||
model_id (str): The ID of the arthur model to log to.
|
||||
arthur_url (str, optional): The URL of the Arthur instance to log to.
|
||||
Defaults to "https://app.arthur.ai".
|
||||
arthur_login (str, optional): The login to use to connect to Arthur.
|
||||
Defaults to None.
|
||||
arthur_password (str, optional): The password to use to connect to
|
||||
Arthur. Defaults to None.
|
||||
|
||||
Returns:
|
||||
ArthurCallbackHandler: The initialized callback handler.
|
||||
"""
|
||||
arthurai = _lazy_load_arthur()
|
||||
ArthurAI = arthurai.ArthurAI
|
||||
ResponseClientError = arthurai.common.exceptions.ResponseClientError
|
||||
|
||||
# connect to Arthur
|
||||
if arthur_login is None:
|
||||
try:
|
||||
arthur_api_key = os.environ["ARTHUR_API_KEY"]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"No Arthur authentication provided. Either give"
|
||||
" a login to the ArthurCallbackHandler"
|
||||
" or set an ARTHUR_API_KEY as an environment variable."
|
||||
)
|
||||
arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
|
||||
else:
|
||||
if arthur_password is None:
|
||||
arthur = ArthurAI(url=arthur_url, login=arthur_login)
|
||||
else:
|
||||
arthur = ArthurAI(
|
||||
url=arthur_url, login=arthur_login, password=arthur_password
|
||||
)
|
||||
# get model from Arthur by the provided model ID
|
||||
try:
|
||||
arthur_model = arthur.get_model(model_id)
|
||||
except ResponseClientError:
|
||||
raise ValueError(
|
||||
f"Was unable to retrieve model with id {model_id} from Arthur."
|
||||
" Make sure the ID corresponds to a model that is currently"
|
||||
" registered with your Arthur account."
|
||||
)
|
||||
return cls(arthur_model)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""On LLM start, save the input prompts"""
|
||||
run_id = kwargs["run_id"]
|
||||
self.run_map[run_id]["input_texts"] = prompts
|
||||
self.run_map[run_id]["start_time"] = time()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""On LLM end, send data to Arthur."""
|
||||
try:
|
||||
import pytz # type: ignore[import]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import pytz. Please install it with 'pip install pytz'."
|
||||
) from e
|
||||
|
||||
run_id = kwargs["run_id"]
|
||||
|
||||
# get the run params from this run ID,
|
||||
# or raise an error if this run ID has no corresponding metadata in self.run_map
|
||||
try:
|
||||
run_map_data = self.run_map[run_id]
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
"This function has been called with a run_id"
|
||||
" that was never registered in on_llm_start()."
|
||||
" Restart and try running the LLM again"
|
||||
) from e
|
||||
|
||||
# mark the duration time between on_llm_start() and on_llm_end()
|
||||
time_from_start_to_end = time() - run_map_data["start_time"]
|
||||
|
||||
# create inferences to log to Arthur
|
||||
inferences = []
|
||||
for i, generations in enumerate(response.generations):
|
||||
for generation in generations:
|
||||
inference = {
|
||||
"partner_inference_id": str(uuid.uuid4()),
|
||||
"inference_timestamp": datetime.now(tz=pytz.UTC),
|
||||
self.input_attr: run_map_data["input_texts"][i],
|
||||
self.output_attr: generation.text,
|
||||
}
|
||||
|
||||
if generation.generation_info is not None:
|
||||
# add finish reason to the inference
|
||||
# if generation info contains a finish reason and
|
||||
# if the ArthurModel was registered to monitor finish_reason
|
||||
if (
|
||||
FINISH_REASON in generation.generation_info
|
||||
and FINISH_REASON in self.attr_names
|
||||
):
|
||||
inference[FINISH_REASON] = generation.generation_info[
|
||||
FINISH_REASON
|
||||
]
|
||||
|
||||
# add token likelihoods data to the inference if the ArthurModel
|
||||
# was registered to monitor token likelihoods
|
||||
logprobs_data = generation.generation_info["logprobs"]
|
||||
if (
|
||||
logprobs_data is not None
|
||||
and self.token_likelihood_attr is not None
|
||||
):
|
||||
logprobs = logprobs_data["top_logprobs"]
|
||||
likelihoods = [
|
||||
{k: np.exp(v) for k, v in logprobs[i].items()}
|
||||
for i in range(len(logprobs))
|
||||
]
|
||||
inference[self.token_likelihood_attr] = likelihoods
|
||||
|
||||
# add token usage counts to the inference if the
|
||||
# ArthurModel was registered to monitor token usage
|
||||
if (
|
||||
isinstance(response.llm_output, dict)
|
||||
and TOKEN_USAGE in response.llm_output
|
||||
):
|
||||
token_usage = response.llm_output[TOKEN_USAGE]
|
||||
if (
|
||||
PROMPT_TOKENS in token_usage
|
||||
and PROMPT_TOKENS in self.attr_names
|
||||
):
|
||||
inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
|
||||
if (
|
||||
COMPLETION_TOKENS in token_usage
|
||||
and COMPLETION_TOKENS in self.attr_names
|
||||
):
|
||||
inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]
|
||||
|
||||
# add inference duration to the inference if the ArthurModel
|
||||
# was registered to monitor inference duration
|
||||
if DURATION in self.attr_names:
|
||||
inference[DURATION] = time_from_start_to_end
|
||||
|
||||
inferences.append(inference)
|
||||
|
||||
# send inferences to arthur
|
||||
self.arthur_model.send_inferences(inferences)
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""On chain start, do nothing."""
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""On chain end, do nothing."""
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""On new token, pass."""
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool starts."""
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing when agent takes a specific action."""
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
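A short usage sketch of the Arthur handler, built with the `from_credentials` constructor defined above (the model ID, login, password, and chat model are placeholders or illustrative choices):

```python
from langchain_community.callbacks import ArthurCallbackHandler
from langchain_community.chat_models import ChatOpenAI  # illustrative model choice

arthur_callback = ArthurCallbackHandler.from_credentials(
    model_id="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",  # placeholder Arthur model ID
    arthur_login="your-arthur-login",
    arthur_password="your-arthur-password",
)
chat = ChatOpenAI(temperature=0, callbacks=[arthur_callback])
chat.invoke("Write a short poem about observability.")
```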
libs/community/langchain_community/callbacks/clearml_callback.py (new file, 527 lines)
@@ -0,0 +1,527 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
load_json,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def import_clearml() -> Any:
|
||||
"""Import the clearml python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import clearml # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the clearml callback manager you need to have the `clearml` python "
|
||||
"package installed. Please install it with `pip install clearml`"
|
||||
)
|
||||
return clearml
|
||||
|
||||
|
||||
class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to ClearML.
|
||||
|
||||
Parameters:
|
||||
job_type (str): The type of clearml task such as "inference", "testing" or "qc"
|
||||
project_name (str): The clearml project name
|
||||
tags (list): Tags to add to the task
|
||||
task_name (str): Name of the clearml task
|
||||
visualize (bool): Whether to visualize the run.
|
||||
complexity_metrics (bool): Whether to log complexity metrics
|
||||
stream_logs (bool): Whether to stream callback actions to ClearML
|
||||
|
||||
This handler will utilize the associated callback method and formats
|
||||
the input of each callback function with metadata regarding the state of LLM run,
|
||||
and adds the response to the list of records for both the {method}_records and
|
||||
action. It then logs the response to the ClearML console.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_type: Optional[str] = "inference",
|
||||
project_name: Optional[str] = "langchain_callback_demo",
|
||||
tags: Optional[Sequence] = None,
|
||||
task_name: Optional[str] = None,
|
||||
visualize: bool = False,
|
||||
complexity_metrics: bool = False,
|
||||
stream_logs: bool = False,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
clearml = import_clearml()
|
||||
spacy = import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.task_type = task_type
|
||||
self.project_name = project_name
|
||||
self.tags = tags
|
||||
self.task_name = task_name
|
||||
self.visualize = visualize
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.stream_logs = stream_logs
|
||||
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
# Check if ClearML task already exists (e.g. in pipeline)
|
||||
if clearml.Task.current_task():
|
||||
self.task = clearml.Task.current_task()
|
||||
else:
|
||||
self.task = clearml.Task.init( # type: ignore
|
||||
task_type=self.task_type,
|
||||
project_name=self.project_name,
|
||||
tags=self.tags,
|
||||
task_name=self.task_name,
|
||||
output_uri=True,
|
||||
)
|
||||
self.logger = self.task.get_logger()
|
||||
warning = (
|
||||
"The clearml callback is currently in beta and is subject to change "
|
||||
"based on updates to `langchain`. Please report any issues to "
|
||||
"https://github.com/allegroai/clearml/issues with the tag `langchain`."
|
||||
)
|
||||
self.logger.report_text(warning, level=30, print_console=True)
|
||||
self.callback_columns: list = []
|
||||
self.action_records: list = []
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.visualize = visualize
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
def _init_resp(self) -> Dict:
|
||||
return {k: None for k in self.callback_columns}
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.step += 1
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for prompt in prompts:
|
||||
prompt_resp = deepcopy(resp)
|
||||
prompt_resp["prompts"] = prompt
|
||||
self.on_llm_start_records.append(prompt_resp)
|
||||
self.action_records.append(prompt_resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(prompt_resp)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.step += 1
|
||||
self.llm_streams += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_llm_token_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.step += 1
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_end"})
|
||||
resp.update(flatten_dict(response.llm_output or {}))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
generation_resp.update(self.analyze_text(generation.text))
|
||||
self.on_llm_end_records.append(generation_resp)
|
||||
self.action_records.append(generation_resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(generation_resp)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.step += 1
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
chain_input = inputs.get("input", inputs.get("human_input"))
|
||||
|
||||
if isinstance(chain_input, str):
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["input"] = chain_input
|
||||
self.on_chain_start_records.append(input_resp)
|
||||
self.action_records.append(input_resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(input_resp)
|
||||
elif isinstance(chain_input, list):
|
||||
for inp in chain_input:
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp.update(inp)
|
||||
self.on_chain_start_records.append(input_resp)
|
||||
self.action_records.append(input_resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(input_resp)
|
||||
else:
|
||||
raise ValueError("Unexpected data format provided!")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.step += 1
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_chain_end",
|
||||
"outputs": outputs.get("output", outputs.get("text")),
|
||||
}
|
||||
)
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_chain_end_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_tool_start", "input_str": input_str})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_tool_start_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_tool_end", "output": output})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_tool_end_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.step += 1
|
||||
self.text_ctr += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_text", "text": text})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_text_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.step += 1
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_finish",
|
||||
"output": finish.return_values["output"],
|
||||
"log": finish.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_agent_finish_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_action",
|
||||
"tool": action.tool,
|
||||
"tool_input": action.tool_input,
|
||||
"log": action.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
self.on_agent_action_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(resp)
|
||||
|
||||
def analyze_text(self, text: str) -> dict:
|
||||
"""Analyze text using textstat and spacy.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics.
|
||||
"""
|
||||
resp = {}
|
||||
textstat = import_textstat()
|
||||
spacy = import_spacy()
|
||||
if self.complexity_metrics:
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(
|
||||
text
|
||||
),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(
|
||||
text
|
||||
),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
"text_standard": textstat.text_standard(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
}
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if self.visualize and self.nlp and self.temp_dir.name is not None:
|
||||
doc = self.nlp(text)
|
||||
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="dep", jupyter=False, page=True
|
||||
)
|
||||
dep_output_path = Path(
|
||||
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
|
||||
)
|
||||
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
||||
|
||||
ent_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="ent", jupyter=False, page=True
|
||||
)
|
||||
ent_output_path = Path(
|
||||
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
|
||||
)
|
||||
ent_output_path.open("w", encoding="utf-8").write(ent_out)
|
||||
|
||||
self.logger.report_media(
|
||||
"Dependencies Plot", text, local_path=dep_output_path
|
||||
)
|
||||
self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
|
||||
|
||||
return resp
|
||||
|
||||
@staticmethod
|
||||
def _build_llm_df(
|
||||
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
|
||||
) -> pd.DataFrame:
|
||||
base_df_fields = [field for field in base_df_fields if field in base_df]
|
||||
rename_map = {
|
||||
map_entry_k: map_entry_v
|
||||
for map_entry_k, map_entry_v in rename_map.items()
|
||||
if map_entry_k in base_df_fields
|
||||
}
|
||||
llm_df = base_df[base_df_fields].dropna(axis=1)
|
||||
if rename_map:
|
||||
llm_df = llm_df.rename(rename_map, axis=1)
|
||||
return llm_df
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
||||
|
||||
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
|
||||
base_df=on_llm_end_records_df,
|
||||
base_df_fields=["step", "prompts"]
|
||||
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
|
||||
rename_map={"step": "prompt_step"},
|
||||
)
|
||||
complexity_metrics_columns = []
|
||||
visualizations_columns: List = []
|
||||
|
||||
if self.complexity_metrics:
|
||||
complexity_metrics_columns = [
|
||||
"flesch_reading_ease",
|
||||
"flesch_kincaid_grade",
|
||||
"smog_index",
|
||||
"coleman_liau_index",
|
||||
"automated_readability_index",
|
||||
"dale_chall_readability_score",
|
||||
"difficult_words",
|
||||
"linsear_write_formula",
|
||||
"gunning_fog",
|
||||
"text_standard",
|
||||
"fernandez_huerta",
|
||||
"szigriszt_pazos",
|
||||
"gutierrez_polini",
|
||||
"crawford",
|
||||
"gulpease_index",
|
||||
"osman",
|
||||
]
|
||||
|
||||
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
|
||||
on_llm_end_records_df,
|
||||
[
|
||||
"step",
|
||||
"text",
|
||||
"token_usage_total_tokens",
|
||||
"token_usage_prompt_tokens",
|
||||
"token_usage_completion_tokens",
|
||||
]
|
||||
+ complexity_metrics_columns
|
||||
+ visualizations_columns,
|
||||
{"step": "output_step", "text": "output"},
|
||||
)
|
||||
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
|
||||
return session_analysis_df
|
||||
|
||||
def flush_tracker(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
langchain_asset: Any = None,
|
||||
finish: bool = False,
|
||||
) -> None:
|
||||
"""Flush the tracker and setup the session.
|
||||
|
||||
Everything after this will be a new table.
|
||||
|
||||
Args:
|
||||
            name: Name of the session performed so far, so it is identifiable
|
||||
langchain_asset: The langchain asset to save.
|
||||
finish: Whether to finish the run.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
pd = import_pandas()
|
||||
clearml = import_clearml()
|
||||
|
||||
# Log the action records
|
||||
self.logger.report_table(
|
||||
"Action Records", name, table_plot=pd.DataFrame(self.action_records)
|
||||
)
|
||||
|
||||
# Session analysis
|
||||
session_analysis_df = self._create_session_analysis_df()
|
||||
self.logger.report_table(
|
||||
"Session Analysis", name, table_plot=session_analysis_df
|
||||
)
|
||||
|
||||
if self.stream_logs:
|
||||
self.logger.report_text(
|
||||
{
|
||||
"action_records": pd.DataFrame(self.action_records),
|
||||
"session_analysis": session_analysis_df,
|
||||
}
|
||||
)
|
||||
|
||||
if langchain_asset:
|
||||
langchain_asset_path = Path(self.temp_dir.name, "model.json")
|
||||
try:
|
||||
langchain_asset.save(langchain_asset_path)
|
||||
# Create output model and connect it to the task
|
||||
output_model = clearml.OutputModel(
|
||||
task=self.task, config_text=load_json(langchain_asset_path)
|
||||
)
|
||||
output_model.update_weights(
|
||||
weights_filename=str(langchain_asset_path),
|
||||
auto_delete_file=False,
|
||||
target_filename=name,
|
||||
)
|
||||
except ValueError:
|
||||
langchain_asset.save_agent(langchain_asset_path)
|
||||
output_model = clearml.OutputModel(
|
||||
task=self.task, config_text=load_json(langchain_asset_path)
|
||||
)
|
||||
output_model.update_weights(
|
||||
weights_filename=str(langchain_asset_path),
|
||||
auto_delete_file=False,
|
||||
target_filename=name,
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
print("Could not save model.")
|
||||
print(repr(e))
|
||||
pass
|
||||
|
||||
# Cleanup after adding everything to ClearML
|
||||
self.task.flush(wait_for_uploads=True)
|
||||
self.temp_dir.cleanup()
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.reset_callback_meta()
|
||||
|
||||
if finish:
|
||||
self.task.close()
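# A sketch of flushing between sessions (hedged: `llm` and `agent` are assumed to be
# LangChain objects created elsewhere with this callback attached; names illustrative):
#
#     clearml_callback.flush_tracker(langchain_asset=llm, name="simple_llm")
#     # ... run more chains or agents with the same callback ...
#     clearml_callback.flush_tracker(langchain_asset=agent, name="agent", finish=True)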
|
@@ -0,0 +1,645 @@
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
|
||||
import langchain_community
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
LANGCHAIN_MODEL_NAME = "langchain-model"
|
||||
|
||||
|
||||
def import_comet_ml() -> Any:
|
||||
"""Import comet_ml and raise an error if it is not installed."""
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the comet_ml callback manager you need to have the "
|
||||
"`comet_ml` python package installed. Please install it with"
|
||||
" `pip install comet_ml`"
|
||||
)
|
||||
return comet_ml
|
||||
|
||||
|
||||
def _get_experiment(
|
||||
workspace: Optional[str] = None, project_name: Optional[str] = None
|
||||
) -> Any:
|
||||
comet_ml = import_comet_ml()
|
||||
|
||||
experiment = comet_ml.Experiment( # type: ignore
|
||||
workspace=workspace,
|
||||
project_name=project_name,
|
||||
)
|
||||
|
||||
return experiment
|
||||
|
||||
|
||||
def _fetch_text_complexity_metrics(text: str) -> dict:
|
||||
textstat = import_textstat()
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(text),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
"text_standard": textstat.text_standard(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
}
|
||||
return text_complexity_metrics
|
||||
|
||||
|
||||
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
|
||||
pd = import_pandas()
|
||||
metrics_df = pd.DataFrame(metrics)
|
||||
metrics_summary = metrics_df.describe()
|
||||
|
||||
return metrics_summary.to_dict()
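# For illustration (a sketch, values are made up): given
#     metrics = [{"flesch_reading_ease": 80.3}, {"flesch_reading_ease": 65.7}]
# the summary is a nested dict keyed by metric name, e.g.
#     {"flesch_reading_ease": {"count": 2.0, "mean": 73.0, "std": ..., "min": 65.7, ...}}
# which `_log_text_metrics` below then logs to Comet using the metric name as a prefix.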
|
||||
|
||||
|
||||
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Comet.
|
||||
|
||||
    Parameters:
        task_type (str): The type of comet_ml task such as "inference",
            "testing" or "qc"
        workspace (str): The comet_ml workspace
        project_name (str): The comet_ml project name
        tags (list): Tags to add to the experiment
        name (str): Name of the comet_ml experiment
        visualizations (list): The spaCy visualization styles (e.g. "dep", "ent")
            to render and log for each generation
        complexity_metrics (bool): Whether to log complexity metrics
        custom_metrics (callable): A function that computes custom metrics for
            each LLM generation
        stream_logs (bool): Whether to stream callback actions to Comet
|
||||
|
||||
    This handler utilizes the associated callback method, formats the input of
    each callback function with metadata regarding the state of the LLM run,
    and adds the response to the list of records for both the {method}_records
    and action_records attributes. It then logs the response to Comet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_type: Optional[str] = "inference",
|
||||
workspace: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
tags: Optional[Sequence] = None,
|
||||
name: Optional[str] = None,
|
||||
visualizations: Optional[List[str]] = None,
|
||||
complexity_metrics: bool = False,
|
||||
custom_metrics: Optional[Callable] = None,
|
||||
stream_logs: bool = True,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
self.comet_ml = import_comet_ml()
|
||||
super().__init__()
|
||||
|
||||
self.task_type = task_type
|
||||
self.workspace = workspace
|
||||
self.project_name = project_name
|
||||
self.tags = tags
|
||||
self.visualizations = visualizations
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.custom_metrics = custom_metrics
|
||||
self.stream_logs = stream_logs
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
self.experiment = _get_experiment(workspace, project_name)
|
||||
self.experiment.log_other("Created from", "langchain")
|
||||
if tags:
|
||||
self.experiment.add_tags(tags)
|
||||
self.name = name
|
||||
if self.name:
|
||||
self.experiment.set_name(self.name)
|
||||
|
||||
warning = (
|
||||
"The comet_ml callback is currently in beta and is subject to change "
|
||||
"based on updates to `langchain`. Please report any issues to "
|
||||
"https://github.com/comet-ml/issue-tracking/issues with the tag "
|
||||
"`langchain`."
|
||||
)
|
||||
self.comet_ml.LOGGER.warning(warning)
|
||||
|
||||
self.callback_columns: list = []
|
||||
self.action_records: list = []
|
||||
self.complexity_metrics = complexity_metrics
|
||||
if self.visualizations:
|
||||
spacy = import_spacy()
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
else:
|
||||
self.nlp = None
|
||||
|
||||
def _init_resp(self) -> Dict:
|
||||
return {k: None for k in self.callback_columns}
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.step += 1
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
metadata = self._init_resp()
|
||||
metadata.update({"action": "on_llm_start"})
|
||||
metadata.update(flatten_dict(serialized))
|
||||
metadata.update(self.get_custom_callback_meta())
|
||||
|
||||
for prompt in prompts:
|
||||
prompt_resp = deepcopy(metadata)
|
||||
prompt_resp["prompts"] = prompt
|
||||
self.on_llm_start_records.append(prompt_resp)
|
||||
self.action_records.append(prompt_resp)
|
||||
|
||||
if self.stream_logs:
|
||||
self._log_stream(prompt, metadata, self.step)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.step += 1
|
||||
self.llm_streams += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.step += 1
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
metadata = self._init_resp()
|
||||
metadata.update({"action": "on_llm_end"})
|
||||
metadata.update(flatten_dict(response.llm_output or {}))
|
||||
metadata.update(self.get_custom_callback_meta())
|
||||
|
||||
output_complexity_metrics = []
|
||||
output_custom_metrics = []
|
||||
|
||||
for prompt_idx, generations in enumerate(response.generations):
|
||||
for gen_idx, generation in enumerate(generations):
|
||||
text = generation.text
|
||||
|
||||
generation_resp = deepcopy(metadata)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
|
||||
complexity_metrics = self._get_complexity_metrics(text)
|
||||
if complexity_metrics:
|
||||
output_complexity_metrics.append(complexity_metrics)
|
||||
generation_resp.update(complexity_metrics)
|
||||
|
||||
custom_metrics = self._get_custom_metrics(
|
||||
generation, prompt_idx, gen_idx
|
||||
)
|
||||
if custom_metrics:
|
||||
output_custom_metrics.append(custom_metrics)
|
||||
generation_resp.update(custom_metrics)
|
||||
|
||||
if self.stream_logs:
|
||||
self._log_stream(text, metadata, self.step)
|
||||
|
||||
self.action_records.append(generation_resp)
|
||||
self.on_llm_end_records.append(generation_resp)
|
||||
|
||||
self._log_text_metrics(output_complexity_metrics, step=self.step)
|
||||
self._log_text_metrics(output_custom_metrics, step=self.step)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.step += 1
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for chain_input_key, chain_input_val in inputs.items():
|
||||
if isinstance(chain_input_val, str):
|
||||
input_resp = deepcopy(resp)
|
||||
if self.stream_logs:
|
||||
self._log_stream(chain_input_val, resp, self.step)
|
||||
input_resp.update({chain_input_key: chain_input_val})
|
||||
self.action_records.append(input_resp)
|
||||
|
||||
else:
|
||||
self.comet_ml.LOGGER.warning(
|
||||
f"Unexpected data format provided! "
|
||||
f"Input Value for {chain_input_key} will not be logged"
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.step += 1
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_chain_end"})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for chain_output_key, chain_output_val in outputs.items():
|
||||
if isinstance(chain_output_val, str):
|
||||
output_resp = deepcopy(resp)
|
||||
if self.stream_logs:
|
||||
self._log_stream(chain_output_val, resp, self.step)
|
||||
output_resp.update({chain_output_key: chain_output_val})
|
||||
self.action_records.append(output_resp)
|
||||
else:
|
||||
self.comet_ml.LOGGER.warning(
|
||||
f"Unexpected data format provided! "
|
||||
f"Output Value for {chain_output_key} will not be logged"
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_tool_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
if self.stream_logs:
|
||||
self._log_stream(input_str, resp, self.step)
|
||||
|
||||
resp.update({"input_str": input_str})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.step += 1
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_tool_end"})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
if self.stream_logs:
|
||||
self._log_stream(output, resp, self.step)
|
||||
|
||||
resp.update({"output": output})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.step += 1
|
||||
self.text_ctr += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_text"})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
if self.stream_logs:
|
||||
self._log_stream(text, resp, self.step)
|
||||
|
||||
resp.update({"text": text})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.step += 1
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
output = finish.return_values["output"]
|
||||
log = finish.log
|
||||
|
||||
resp.update({"action": "on_agent_finish", "log": log})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
if self.stream_logs:
|
||||
self._log_stream(output, resp, self.step)
|
||||
|
||||
resp.update({"output": output})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.step += 1
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
tool = action.tool
|
||||
tool_input = str(action.tool_input)
|
||||
log = action.log
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_agent_action", "log": log, "tool": tool})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
if self.stream_logs:
|
||||
self._log_stream(tool_input, resp, self.step)
|
||||
|
||||
resp.update({"tool_input": tool_input})
|
||||
self.action_records.append(resp)
|
||||
|
||||
def _get_complexity_metrics(self, text: str) -> dict:
|
||||
"""Compute text complexity metrics using textstat.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics.
|
||||
"""
|
||||
resp = {}
|
||||
if self.complexity_metrics:
|
||||
text_complexity_metrics = _fetch_text_complexity_metrics(text)
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
return resp
|
||||
|
||||
def _get_custom_metrics(
|
||||
self, generation: Generation, prompt_idx: int, gen_idx: int
|
||||
) -> dict:
|
||||
"""Compute Custom Metrics for an LLM Generated Output
|
||||
|
||||
Args:
|
||||
generation (LLMResult): Output generation from an LLM
|
||||
prompt_idx (int): List index of the input prompt
|
||||
gen_idx (int): List index of the generated output
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the custom metrics.
|
||||
"""
|
||||
|
||||
resp = {}
|
||||
if self.custom_metrics:
|
||||
custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
|
||||
resp.update(custom_metrics)
|
||||
|
||||
return resp
|
||||
|
||||
def flush_tracker(
|
||||
self,
|
||||
langchain_asset: Any = None,
|
||||
task_type: Optional[str] = "inference",
|
||||
workspace: Optional[str] = None,
|
||||
project_name: Optional[str] = "comet-langchain-demo",
|
||||
tags: Optional[Sequence] = None,
|
||||
name: Optional[str] = None,
|
||||
visualizations: Optional[List[str]] = None,
|
||||
complexity_metrics: bool = False,
|
||||
custom_metrics: Optional[Callable] = None,
|
||||
finish: bool = False,
|
||||
reset: bool = False,
|
||||
) -> None:
|
||||
"""Flush the tracker and setup the session.
|
||||
|
||||
Everything after this will be a new table.
|
||||
|
||||
        Args:
            name: Name of the session performed so far, so it is identifiable
            langchain_asset: The langchain asset to save.
            finish: Whether to finish the run.
            reset: Whether to reset the session and start a new Comet experiment.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self._log_session(langchain_asset)
|
||||
|
||||
if langchain_asset:
|
||||
try:
|
||||
self._log_model(langchain_asset)
|
||||
except Exception:
|
||||
self.comet_ml.LOGGER.error(
|
||||
"Failed to export agent or LLM to Comet",
|
||||
exc_info=True,
|
||||
extra={"show_traceback": True},
|
||||
)
|
||||
|
||||
if finish:
|
||||
self.experiment.end()
|
||||
|
||||
if reset:
|
||||
self._reset(
|
||||
task_type,
|
||||
workspace,
|
||||
project_name,
|
||||
tags,
|
||||
name,
|
||||
visualizations,
|
||||
complexity_metrics,
|
||||
custom_metrics,
|
||||
)
|
||||
|
||||
def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
|
||||
self.experiment.log_text(prompt, metadata=metadata, step=step)
|
||||
|
||||
def _log_model(self, langchain_asset: Any) -> None:
|
||||
model_parameters = self._get_llm_parameters(langchain_asset)
|
||||
self.experiment.log_parameters(model_parameters, prefix="model")
|
||||
|
||||
langchain_asset_path = Path(self.temp_dir.name, "model.json")
|
||||
model_name = self.name if self.name else LANGCHAIN_MODEL_NAME
|
||||
|
||||
try:
|
||||
if hasattr(langchain_asset, "save"):
|
||||
langchain_asset.save(langchain_asset_path)
|
||||
self.experiment.log_model(model_name, str(langchain_asset_path))
|
||||
except (ValueError, AttributeError, NotImplementedError) as e:
|
||||
if hasattr(langchain_asset, "save_agent"):
|
||||
langchain_asset.save_agent(langchain_asset_path)
|
||||
self.experiment.log_model(model_name, str(langchain_asset_path))
|
||||
else:
|
||||
self.comet_ml.LOGGER.error(
|
||||
f"{e}"
|
||||
" Could not save Langchain Asset "
|
||||
f"for {langchain_asset.__class__.__name__}"
|
||||
)
|
||||
|
||||
def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
|
||||
try:
|
||||
llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
|
||||
# Log the cleaned dataframe as a table
|
||||
self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
|
||||
except Exception:
|
||||
self.comet_ml.LOGGER.warning(
|
||||
"Failed to log session data to Comet",
|
||||
exc_info=True,
|
||||
extra={"show_traceback": True},
|
||||
)
|
||||
|
||||
try:
|
||||
metadata = {"langchain_version": str(langchain_community.__version__)}
|
||||
# Log the langchain low-level records as a JSON file directly
|
||||
self.experiment.log_asset_data(
|
||||
self.action_records, "langchain-action_records.json", metadata=metadata
|
||||
)
|
||||
except Exception:
|
||||
self.comet_ml.LOGGER.warning(
|
||||
"Failed to log session data to Comet",
|
||||
exc_info=True,
|
||||
extra={"show_traceback": True},
|
||||
)
|
||||
|
||||
try:
|
||||
self._log_visualizations(llm_session_df)
|
||||
except Exception:
|
||||
self.comet_ml.LOGGER.warning(
|
||||
"Failed to log visualizations to Comet",
|
||||
exc_info=True,
|
||||
extra={"show_traceback": True},
|
||||
)
|
||||
|
||||
def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
|
||||
if not metrics:
|
||||
return
|
||||
|
||||
metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
|
||||
for key, value in metrics_summary.items():
|
||||
self.experiment.log_metrics(value, prefix=key, step=step)
|
||||
|
||||
def _log_visualizations(self, session_df: Any) -> None:
|
||||
if not (self.visualizations and self.nlp):
|
||||
return
|
||||
|
||||
spacy = import_spacy()
|
||||
|
||||
prompts = session_df["prompts"].tolist()
|
||||
outputs = session_df["text"].tolist()
|
||||
|
||||
for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
|
||||
doc = self.nlp(output)
|
||||
sentence_spans = list(doc.sents)
|
||||
|
||||
for visualization in self.visualizations:
|
||||
try:
|
||||
html = spacy.displacy.render(
|
||||
sentence_spans,
|
||||
style=visualization,
|
||||
options={"compact": True},
|
||||
jupyter=False,
|
||||
page=True,
|
||||
)
|
||||
self.experiment.log_asset_data(
|
||||
html,
|
||||
name=f"langchain-viz-{visualization}-{idx}.html",
|
||||
metadata={"prompt": prompt},
|
||||
step=idx,
|
||||
)
|
||||
except Exception as e:
|
||||
self.comet_ml.LOGGER.warning(
|
||||
e, exc_info=True, extra={"show_traceback": True}
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def _reset(
|
||||
self,
|
||||
task_type: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
tags: Optional[Sequence] = None,
|
||||
name: Optional[str] = None,
|
||||
visualizations: Optional[List[str]] = None,
|
||||
complexity_metrics: bool = False,
|
||||
custom_metrics: Optional[Callable] = None,
|
||||
) -> None:
|
||||
_task_type = task_type if task_type else self.task_type
|
||||
_workspace = workspace if workspace else self.workspace
|
||||
_project_name = project_name if project_name else self.project_name
|
||||
_tags = tags if tags else self.tags
|
||||
_name = name if name else self.name
|
||||
_visualizations = visualizations if visualizations else self.visualizations
|
||||
_complexity_metrics = (
|
||||
complexity_metrics if complexity_metrics else self.complexity_metrics
|
||||
)
|
||||
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
|
||||
|
||||
self.__init__( # type: ignore
|
||||
task_type=_task_type,
|
||||
workspace=_workspace,
|
||||
project_name=_project_name,
|
||||
tags=_tags,
|
||||
name=_name,
|
||||
visualizations=_visualizations,
|
||||
complexity_metrics=_complexity_metrics,
|
||||
custom_metrics=_custom_metrics,
|
||||
)
|
||||
|
||||
self.reset_callback_meta()
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
    def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> Any:
|
||||
pd = import_pandas()
|
||||
|
||||
llm_parameters = self._get_llm_parameters(langchain_asset)
|
||||
num_generations_per_prompt = llm_parameters.get("n", 1)
|
||||
|
||||
llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
|
||||
# Repeat each input row based on the number of outputs generated per prompt
|
||||
llm_start_records_df = llm_start_records_df.loc[
|
||||
llm_start_records_df.index.repeat(num_generations_per_prompt)
|
||||
].reset_index(drop=True)
|
||||
llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
|
||||
|
||||
llm_session_df = pd.merge(
|
||||
llm_start_records_df,
|
||||
llm_end_records_df,
|
||||
left_index=True,
|
||||
right_index=True,
|
||||
suffixes=["_llm_start", "_llm_end"],
|
||||
)
|
||||
|
||||
return llm_session_df
|
||||
|
||||
def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
|
||||
if not langchain_asset:
|
||||
return {}
|
||||
try:
|
||||
if hasattr(langchain_asset, "agent"):
|
||||
llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
|
||||
elif hasattr(langchain_asset, "llm_chain"):
|
||||
llm_parameters = langchain_asset.llm_chain.llm.dict()
|
||||
elif hasattr(langchain_asset, "llm"):
|
||||
llm_parameters = langchain_asset.llm.dict()
|
||||
else:
|
||||
llm_parameters = langchain_asset.dict()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
return llm_parameters
|
@@ -0,0 +1,183 @@
|
||||
# flake8: noqa
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class DeepEvalCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs into deepeval.
|
||||
|
||||
Args:
|
||||
implementation_name: name of the `implementation` in deepeval
|
||||
metrics: A list of metrics
|
||||
|
||||
Raises:
|
||||
ImportError: if the `deepeval` package is not installed.
|
||||
|
||||
Examples:
|
||||
>>> from langchain_community.llms import OpenAI
|
||||
>>> from langchain_community.callbacks import DeepEvalCallbackHandler
|
||||
>>> from deepeval.metrics import AnswerRelevancy
|
||||
>>> metric = AnswerRelevancy(minimum_score=0.3)
|
||||
>>> deepeval_callback = DeepEvalCallbackHandler(
|
||||
... implementation_name="exampleImplementation",
|
||||
... metrics=[metric],
|
||||
... )
|
||||
>>> llm = OpenAI(
|
||||
... temperature=0,
|
||||
... callbacks=[deepeval_callback],
|
||||
... verbose=True,
|
||||
... openai_api_key="API_KEY_HERE",
|
||||
... )
|
||||
>>> llm.generate([
|
||||
... "What is the best evaluation tool out there? (no bias at all)",
|
||||
... ])
|
||||
"Deepeval, no doubt about it."
|
||||
"""
|
||||
|
||||
REPO_URL: str = "https://github.com/confident-ai/deepeval"
|
||||
ISSUES_URL: str = f"{REPO_URL}/issues"
|
||||
BLOG_URL: str = "https://docs.confident-ai.com" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metrics: List[Any],
|
||||
implementation_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initializes the `deepevalCallbackHandler`.
|
||||
|
||||
Args:
|
||||
            implementation_name: Name of the `implementation` in deepeval.
            metrics: The deepeval metrics to track.
|
||||
|
||||
Raises:
|
||||
ImportError: if the `deepeval` package is not installed.
|
||||
ConnectionError: if the connection to deepeval fails.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Import deepeval (not via `import_deepeval` to keep hints in IDEs)
|
||||
try:
|
||||
            import deepeval  # noqa: F401,I001
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"""To use the deepeval callback manager you need to have the
|
||||
`deepeval` Python package installed. Please install it with
|
||||
`pip install deepeval`"""
|
||||
)
|
||||
|
||||
if os.path.exists(".deepeval"):
|
||||
warnings.warn(
|
||||
"""You are currently not logging anything to the dashboard, we
|
||||
recommend using `deepeval login`."""
|
||||
)
|
||||
|
||||
# Set the deepeval variables
|
||||
self.implementation_name = implementation_name
|
||||
self.metrics = metrics
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"The `DeepEvalCallbackHandler` is currently in beta and is subject to"
|
||||
" change based on updates to `langchain`. Please report any issues to"
|
||||
f" {self.ISSUES_URL} as an `integration` issue."
|
||||
),
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Store the prompts"""
|
||||
self.prompts = prompts
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing when a new token is generated."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Log records to deepeval when an LLM ends."""
|
||||
from deepeval.metrics.answer_relevancy import AnswerRelevancy
|
||||
from deepeval.metrics.bias_classifier import UnBiasedMetric
|
||||
from deepeval.metrics.metric import Metric
|
||||
from deepeval.metrics.toxic_classifier import NonToxicMetric
|
||||
|
||||
for metric in self.metrics:
|
||||
for i, generation in enumerate(response.generations):
|
||||
# Here, we only measure the first generation's output
|
||||
output = generation[0].text
|
||||
query = self.prompts[i]
|
||||
if isinstance(metric, AnswerRelevancy):
|
||||
result = metric.measure(
|
||||
output=output,
|
||||
query=query,
|
||||
)
|
||||
print(f"Answer Relevancy: {result}")
|
||||
elif isinstance(metric, UnBiasedMetric):
|
||||
score = metric.measure(output)
|
||||
print(f"Bias Score: {score}")
|
||||
elif isinstance(metric, NonToxicMetric):
|
||||
score = metric.measure(output)
|
||||
print(f"Toxic Score: {score}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"""Metric {metric.__name__} is not supported by deepeval
|
||||
callbacks."""
|
||||
)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when chain starts"""
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Do nothing when chain ends."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool starts."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing when agent takes a specific action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
194
libs/community/langchain_community/callbacks/context_callback.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Callback handler for Context AI"""
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
def import_context() -> Any:
|
||||
"""Import the `getcontext` package."""
|
||||
try:
|
||||
import getcontext # noqa: F401
|
||||
from getcontext.generated.models import (
|
||||
Conversation,
|
||||
Message,
|
||||
MessageRole,
|
||||
Rating,
|
||||
)
|
||||
from getcontext.token import Credential # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the context callback manager you need to have the "
|
||||
"`getcontext` python package installed (version >=0.3.0). "
|
||||
"Please install it with `pip install --upgrade python-context`"
|
||||
)
|
||||
return getcontext, Credential, Conversation, Message, MessageRole, Rating
|
||||
|
||||
|
||||
class ContextCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that records transcripts to the Context service.
|
||||
|
||||
(https://context.ai).
|
||||
|
||||
Keyword Args:
|
||||
token (optional): The token with which to authenticate requests to Context.
|
||||
Visit https://with.context.ai/settings to generate a token.
|
||||
If not provided, the value of the `CONTEXT_TOKEN` environment
|
||||
variable will be used.
|
||||
|
||||
Raises:
|
||||
ImportError: if the `context-python` package is not installed.
|
||||
|
||||
Chat Example:
|
||||
        >>> from langchain_community.chat_models import ChatOpenAI
|
||||
>>> from langchain_community.callbacks import ContextCallbackHandler
|
||||
>>> context_callback = ContextCallbackHandler(
|
||||
... token="<CONTEXT_TOKEN_HERE>",
|
||||
... )
|
||||
>>> chat = ChatOpenAI(
|
||||
... temperature=0,
|
||||
... headers={"user_id": "123"},
|
||||
... callbacks=[context_callback],
|
||||
... openai_api_key="API_KEY_HERE",
|
||||
... )
|
||||
>>> messages = [
|
||||
... SystemMessage(content="You translate English to French."),
|
||||
... HumanMessage(content="I love programming with LangChain."),
|
||||
... ]
|
||||
>>> chat(messages)
|
||||
|
||||
Chain Example:
|
||||
>>> from langchain.chains import LLMChain
|
||||
>>> from langchain_community.chat_models import ChatOpenAI
|
||||
>>> from langchain_community.callbacks import ContextCallbackHandler
|
||||
>>> context_callback = ContextCallbackHandler(
|
||||
... token="<CONTEXT_TOKEN_HERE>",
|
||||
... )
|
||||
>>> human_message_prompt = HumanMessagePromptTemplate(
|
||||
... prompt=PromptTemplate(
|
||||
... template="What is a good name for a company that makes {product}?",
|
||||
... input_variables=["product"],
|
||||
... ),
|
||||
... )
|
||||
>>> chat_prompt_template = ChatPromptTemplate.from_messages(
|
||||
... [human_message_prompt]
|
||||
... )
|
||||
>>> callback = ContextCallbackHandler(token)
|
||||
        >>> # Note: the same callback object must be shared between the
        >>> # LLM and the chain.
|
||||
>>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback])
|
||||
>>> chain = LLMChain(
|
||||
... llm=chat,
|
||||
... prompt=chat_prompt_template,
|
||||
... callbacks=[callback]
|
||||
... )
|
||||
>>> chain.run("colorful socks")
|
||||
"""
|
||||
|
||||
def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None:
|
||||
(
|
||||
self.context,
|
||||
self.credential,
|
||||
self.conversation_model,
|
||||
self.message_model,
|
||||
self.message_role_model,
|
||||
self.rating_model,
|
||||
) = import_context()
|
||||
|
||||
token = token or os.environ.get("CONTEXT_TOKEN") or ""
|
||||
|
||||
self.client = self.context.ContextAPI(credential=self.credential(token))
|
||||
|
||||
self.chain_run_id = None
|
||||
|
||||
self.llm_model = None
|
||||
|
||||
self.messages: List[Any] = []
|
||||
self.metadata: Dict[str, str] = {}
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when the chat model is started."""
|
||||
llm_model = kwargs.get("invocation_params", {}).get("model", None)
|
||||
if llm_model is not None:
|
||||
self.metadata["model"] = llm_model
|
||||
|
||||
if len(messages) == 0:
|
||||
return
|
||||
|
||||
for message in messages[0]:
|
||||
role = self.message_role_model.SYSTEM
|
||||
if message.type == "human":
|
||||
role = self.message_role_model.USER
|
||||
elif message.type == "system":
|
||||
role = self.message_role_model.SYSTEM
|
||||
elif message.type == "ai":
|
||||
role = self.message_role_model.ASSISTANT
|
||||
|
||||
self.messages.append(
|
||||
self.message_model(
|
||||
message=message.content,
|
||||
role=role,
|
||||
)
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends."""
|
||||
if len(response.generations) == 0 or len(response.generations[0]) == 0:
|
||||
return
|
||||
|
||||
if not self.chain_run_id:
|
||||
generation = response.generations[0][0]
|
||||
self.messages.append(
|
||||
self.message_model(
|
||||
message=generation.text,
|
||||
role=self.message_role_model.ASSISTANT,
|
||||
)
|
||||
)
|
||||
|
||||
self._log_conversation()
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts."""
|
||||
self.chain_run_id = kwargs.get("run_id", None)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends."""
|
||||
self.messages.append(
|
||||
self.message_model(
|
||||
message=outputs["text"],
|
||||
role=self.message_role_model.ASSISTANT,
|
||||
)
|
||||
)
|
||||
|
||||
self._log_conversation()
|
||||
|
||||
self.chain_run_id = None
|
||||
|
||||
def _log_conversation(self) -> None:
|
||||
"""Log the conversation to the context API."""
|
||||
if len(self.messages) == 0:
|
||||
return
|
||||
|
||||
self.client.log.conversation_upsert(
|
||||
body={
|
||||
"conversation": self.conversation_model(
|
||||
messages=self.messages,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
self.messages = []
|
||||
self.metadata = {}
|
371
libs/community/langchain_community/callbacks/flyte_callback.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""FlyteKit callback handler."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import flytekit
|
||||
from flytekitplugins.deck import renderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_flytekit() -> Tuple[flytekit, renderer]:
|
||||
"""Import flytekit and flytekitplugins-deck-standard."""
|
||||
try:
|
||||
import flytekit # noqa: F401
|
||||
from flytekitplugins.deck import renderer # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the flyte callback manager you need"
|
||||
"to have the `flytekit` and `flytekitplugins-deck-standard`"
|
||||
"packages installed. Please install them with `pip install flytekit`"
|
||||
"and `pip install flytekitplugins-deck-standard`."
|
||||
)
|
||||
return flytekit, renderer
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: str,
|
||||
nlp: Any = None,
|
||||
textstat: Any = None,
|
||||
) -> dict:
|
||||
"""Analyze text using textstat and spacy.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
nlp (spacy.lang): The spacy language model to use for visualization.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics and visualization
|
||||
files serialized to HTML string.
|
||||
"""
|
||||
resp: Dict[str, Any] = {}
|
||||
if textstat is not None:
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(text),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
}
|
||||
resp.update({"text_complexity_metrics": text_complexity_metrics})
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if nlp is not None:
|
||||
spacy = import_spacy()
|
||||
doc = nlp(text)
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="dep", jupyter=False, page=True
|
||||
)
|
||||
ent_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="ent", jupyter=False, page=True
|
||||
)
|
||||
text_visualizations = {
|
||||
"dependency_tree": dep_out,
|
||||
"entities": ent_out,
|
||||
}
|
||||
resp.update(text_visualizations)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""This callback handler that is used within a Flyte task."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize callback handler."""
|
||||
flytekit, renderer = import_flytekit()
|
||||
self.pandas = import_pandas()
|
||||
|
||||
self.textstat = None
|
||||
try:
|
||||
self.textstat = import_textstat()
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Textstat library is not installed. \
|
||||
It may result in the inability to log \
|
||||
certain metrics that can be captured with Textstat."
|
||||
)
|
||||
|
||||
spacy = None
|
||||
try:
|
||||
spacy = import_spacy()
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Spacy library is not installed. \
|
||||
It may result in the inability to log \
|
||||
certain metrics that can be captured with Spacy."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.nlp = None
|
||||
if spacy:
|
||||
try:
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
except OSError:
|
||||
logger.warning(
|
||||
"FlyteCallbackHandler uses spacy's en_core_web_sm model"
|
||||
" for certain metrics. To download,"
|
||||
" run the following command in your terminal:"
|
||||
" `python -m spacy download en_core_web_sm`"
|
||||
)
|
||||
|
||||
self.table_renderer = renderer.TableRenderer
|
||||
        self.markdown_renderer = renderer.MarkdownRenderer

        self.deck = flytekit.Deck(
            "LangChain Metrics",
            self.markdown_renderer().to_html("## LangChain Metrics"),
        )

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        prompt_responses = []
        for prompt in prompts:
            prompt_responses.append(prompt)

        resp.update({"prompts": prompt_responses})

        self.deck.append(self.markdown_renderer().to_html("### LLM Start"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### LLM End"))
        self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp])))

        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                if self.nlp or self.textstat:
                    generation_resp.update(
                        analyze_text(
                            generation.text, nlp=self.nlp, textstat=self.textstat
                        )
                    )

                    complexity_metrics: Dict[str, float] = generation_resp.pop(
                        "text_complexity_metrics"
                    )  # type: ignore  # noqa: E501
                    self.deck.append(
                        self.markdown_renderer().to_html("#### Text Complexity Metrics")
                    )
                    self.deck.append(
                        self.table_renderer().to_html(
                            self.pandas.DataFrame([complexity_metrics])
                        )
                        + "\n"
                    )

                    dependency_tree = generation_resp["dependency_tree"]
                    self.deck.append(
                        self.markdown_renderer().to_html("#### Dependency Tree")
                    )
                    self.deck.append(dependency_tree)

                    entities = generation_resp["entities"]
                    self.deck.append(self.markdown_renderer().to_html("#### Entities"))
                    self.deck.append(entities)
                else:
                    self.deck.append(
                        self.markdown_renderer().to_html("#### Generated Response")
                    )
                    self.deck.append(self.markdown_renderer().to_html(generation.text))

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
        input_resp = deepcopy(resp)
        input_resp["inputs"] = chain_input

        self.deck.append(self.markdown_renderer().to_html("### Chain Start"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([input_resp])) + "\n"
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {}
        chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
        resp.update({"action": "on_chain_end", "outputs": chain_output})
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### Chain End"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### Tool Start"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### Tool End"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """
        Run when agent is ending.
        """
        self.step += 1
        self.text_ctr += 1

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### On Text"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### Agent Finish"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.deck.append(self.markdown_renderer().to_html("### Agent Action"))
        self.deck.append(
            self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
        )
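The handler above renders its metrics and tables into a Flyte Deck. A minimal usage sketch, assuming the `flyte_callback` module path introduced in this PR and that your `flytekit` version supports the `enable_deck` task flag (both are assumptions, not part of the diff):

```python
# Hypothetical sketch: attach FlyteCallbackHandler to an LLM inside a Flyte task
# so that LLM/chain/tool events show up in the task's Deck.
from flytekit import task  # assumes a recent flytekit with Deck support

from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
from langchain_community.llms import OpenAI


@task(enable_deck=True)  # `enable_deck` flag is an assumption about your flytekit version
def summarize(text: str) -> str:
    llm = OpenAI(temperature=0, callbacks=[FlyteCallbackHandler()])
    return llm.predict(f"Summarize this text:\n\n{text}")
```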
88
libs/community/langchain_community/callbacks/human.py
Normal file
@@ -0,0 +1,88 @@
from typing import Any, Awaitable, Callable, Dict, Optional
from uuid import UUID

from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler


def _default_approve(_input: str) -> bool:
    msg = (
        "Do you approve of the following input? "
        "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
    )
    msg += "\n\n" + _input + "\n"
    resp = input(msg)
    return resp.lower() in ("yes", "y")


async def _adefault_approve(_input: str) -> bool:
    msg = (
        "Do you approve of the following input? "
        "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
    )
    msg += "\n\n" + _input + "\n"
    resp = input(msg)
    return resp.lower() in ("yes", "y")


def _default_true(_: Dict[str, Any]) -> bool:
    return True


class HumanRejectedException(Exception):
    """Exception to raise when a person manually reviews and rejects a value."""


class HumanApprovalCallbackHandler(BaseCallbackHandler):
    """Callback for manually validating values."""

    raise_error: bool = True

    def __init__(
        self,
        approve: Callable[[Any], bool] = _default_approve,
        should_check: Callable[[Dict[str, Any]], bool] = _default_true,
    ):
        self._approve = approve
        self._should_check = should_check

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if self._should_check(serialized) and not self._approve(input_str):
            raise HumanRejectedException(
                f"Inputs {input_str} to tool {serialized} were rejected."
            )


class AsyncHumanApprovalCallbackHandler(AsyncCallbackHandler):
    """Asynchronous callback for manually validating values."""

    raise_error: bool = True

    def __init__(
        self,
        approve: Callable[[Any], Awaitable[bool]] = _adefault_approve,
        should_check: Callable[[Dict[str, Any]], bool] = _default_true,
    ):
        self._approve = approve
        self._should_check = should_check

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if self._should_check(serialized) and not await self._approve(input_str):
            raise HumanRejectedException(
                f"Inputs {input_str} to tool {serialized} were rejected."
            )
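A hedged usage sketch of the approval handler above: it gates tool execution on a human "yes". The `ShellTool` import and the `"terminal"` tool name are assumptions for illustration; the handler itself only relies on the `approve`/`should_check` callables shown in the file.

```python
# Sketch: require manual approval before a potentially destructive tool runs.
from langchain_community.callbacks.human import HumanApprovalCallbackHandler
from langchain_community.tools import ShellTool  # assumed import for the example


def _should_check(serialized: dict) -> bool:
    # Only ask for approval on the shell tool; let other tools run unprompted.
    return serialized.get("name") == "terminal"


tool = ShellTool(
    callbacks=[HumanApprovalCallbackHandler(should_check=_should_check)]
)
# tool.run("ls /tmp")  # prompts for approval; a non-"yes" answer raises
#                      # HumanRejectedException before the command executes.
```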
266
libs/community/langchain_community/callbacks/infino_callback.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
def import_infino() -> Any:
|
||||
"""Import the infino client."""
|
||||
try:
|
||||
from infinopy import InfinoClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the Infino callbacks manager you need to have the"
|
||||
" `infinopy` python package installed."
|
||||
"Please install it with `pip install infinopy`"
|
||||
)
|
||||
return InfinoClient()
|
||||
|
||||
|
||||
def import_tiktoken() -> Any:
|
||||
"""Import tiktoken for counting tokens for OpenAI models."""
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the ChatOpenAI model with Infino callback manager, you need to "
|
||||
"have the `tiktoken` python package installed."
|
||||
"Please install it with `pip install tiktoken`"
|
||||
)
|
||||
return tiktoken
|
||||
|
||||
|
||||
def get_num_tokens(string: str, openai_model_name: str) -> int:
|
||||
"""Calculate num tokens for OpenAI with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/main
|
||||
/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
"""
|
||||
tiktoken = import_tiktoken()
|
||||
|
||||
encoding = tiktoken.encoding_for_model(openai_model_name)
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
|
||||
|
||||
class InfinoCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Infino."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: Optional[str] = None,
|
||||
model_version: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
# Set Infino client
|
||||
self.client = import_infino()
|
||||
self.model_id = model_id
|
||||
self.model_version = model_version
|
||||
self.verbose = verbose
|
||||
self.is_chat_openai_model = False
|
||||
self.chat_openai_model_name = "gpt-3.5-turbo"
|
||||
|
||||
def _send_to_infino(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
is_ts: bool = True,
|
||||
) -> None:
|
||||
"""Send the key-value to Infino.
|
||||
|
||||
Parameters:
|
||||
key (str): the key to send to Infino.
|
||||
value (Any): the value to send to Infino.
|
||||
is_ts (bool): if True, the value is part of a time series, else it
|
||||
is sent as a log message.
|
||||
"""
|
||||
payload = {
|
||||
"date": int(time.time()),
|
||||
key: value,
|
||||
"labels": {
|
||||
"model_id": self.model_id,
|
||||
"model_version": self.model_version,
|
||||
},
|
||||
}
|
||||
if self.verbose:
|
||||
print(f"Tracking {key} with Infino: {payload}")
|
||||
|
||||
# Append to Infino time series only if is_ts is True, otherwise
|
||||
# append to Infino log.
|
||||
if is_ts:
|
||||
self.client.append_ts(payload)
|
||||
else:
|
||||
self.client.append_log(payload)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Log the prompts to Infino, and set start time and error flag."""
|
||||
for prompt in prompts:
|
||||
self._send_to_infino("prompt", prompt, is_ts=False)
|
||||
|
||||
# Set the error flag to indicate no error (this will get overridden
|
||||
# in on_llm_error if an error occurs).
|
||||
self.error = 0
|
||||
|
||||
# Set the start time (so that we can calculate the request
|
||||
# duration in on_llm_end).
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing when a new token is generated."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Log the latency, error, token usage, and response to Infino."""
|
||||
# Calculate and track the request latency.
|
||||
self.end_time = time.time()
|
||||
duration = self.end_time - self.start_time
|
||||
self._send_to_infino("latency", duration)
|
||||
|
||||
# Track success or error flag.
|
||||
self._send_to_infino("error", self.error)
|
||||
|
||||
# Track prompt response.
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
self._send_to_infino("prompt_response", generation.text, is_ts=False)
|
||||
|
||||
# Track token usage (for non-chat models).
|
||||
if (response.llm_output is not None) and isinstance(response.llm_output, Dict):
|
||||
token_usage = response.llm_output["token_usage"]
|
||||
if token_usage is not None:
|
||||
prompt_tokens = token_usage["prompt_tokens"]
|
||||
total_tokens = token_usage["total_tokens"]
|
||||
completion_tokens = token_usage["completion_tokens"]
|
||||
self._send_to_infino("prompt_tokens", prompt_tokens)
|
||||
self._send_to_infino("total_tokens", total_tokens)
|
||||
self._send_to_infino("completion_tokens", completion_tokens)
|
||||
|
||||
# Track completion token usage (for openai chat models).
|
||||
if self.is_chat_openai_model:
|
||||
messages = " ".join(
|
||||
generation.message.content # type: ignore[attr-defined]
|
||||
for generation in generations
|
||||
)
|
||||
completion_tokens = get_num_tokens(
|
||||
messages, openai_model_name=self.chat_openai_model_name
|
||||
)
|
||||
self._send_to_infino("completion_tokens", completion_tokens)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Set the error flag."""
|
||||
self.error = 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain starts."""
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM chain ends."""
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Need to log the error."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool starts."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing when agent takes a specific action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
# Currently, for chat models, we only support input prompts for ChatOpenAI.
|
||||
# Check if this model is a ChatOpenAI model.
|
||||
values = serialized.get("id")
|
||||
if values:
|
||||
for value in values:
|
||||
if value == "ChatOpenAI":
|
||||
self.is_chat_openai_model = True
|
||||
break
|
||||
|
||||
# Track prompt tokens for ChatOpenAI model.
|
||||
if self.is_chat_openai_model:
|
||||
invocation_params = kwargs.get("invocation_params")
|
||||
if invocation_params:
|
||||
model_name = invocation_params.get("model_name")
|
||||
if model_name:
|
||||
self.chat_openai_model_name = model_name
|
||||
prompt_tokens = 0
|
||||
for message_list in messages:
|
||||
message_string = " ".join(
|
||||
cast(str, msg.content) for msg in message_list
|
||||
)
|
||||
num_tokens = get_num_tokens(
|
||||
message_string,
|
||||
openai_model_name=self.chat_openai_model_name,
|
||||
)
|
||||
prompt_tokens += num_tokens
|
||||
|
||||
self._send_to_infino("prompt_tokens", prompt_tokens)
|
||||
|
||||
if self.verbose:
|
||||
print(
|
||||
f"on_chat_model_start: is_chat_openai_model= \
|
||||
{self.is_chat_openai_model}, \
|
||||
chat_openai_model_name={self.chat_openai_model_name}"
|
||||
)
|
||||
|
||||
# Send the prompt to infino
|
||||
prompt = " ".join(
|
||||
cast(str, msg.content) for sublist in messages for msg in sublist
|
||||
)
|
||||
self._send_to_infino("prompt", prompt, is_ts=False)
|
||||
|
||||
# Set the error flag to indicate no error (this will get overridden
|
||||
# in on_llm_error if an error occurs).
|
||||
self.error = 0
|
||||
|
||||
# Set the start time (so that we can calculate the request
|
||||
# duration in on_llm_end).
|
||||
self.start_time = time.time()
|
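A minimal sketch of the Infino handler above, assuming an Infino server is reachable locally and the `infinopy` (and, for chat models, `tiktoken`) packages are installed:

```python
# Sketch: track latency, error flag, and token counts for an OpenAI call in Infino.
from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
from langchain_community.llms import OpenAI

handler = InfinoCallbackHandler(model_id="demo-model", model_version="0.1", verbose=True)
llm = OpenAI(temperature=0.1, callbacks=[handler])
llm.predict("Tell me a joke about observability.")
# "latency", "error", and token counts go to Infino time series; prompts and
# responses are appended as log entries (is_ts=False), per _send_to_infino above.
```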
@@ -0,0 +1,390 @@
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage, ChatMessage
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
|
||||
|
||||
class LabelStudioMode(Enum):
|
||||
"""Label Studio mode enumerator."""
|
||||
|
||||
PROMPT = "prompt"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
def get_default_label_configs(
|
||||
mode: Union[str, LabelStudioMode],
|
||||
) -> Tuple[str, LabelStudioMode]:
|
||||
"""Get default Label Studio configs for the given mode.
|
||||
|
||||
Parameters:
|
||||
mode: Label Studio mode ("prompt" or "chat")
|
||||
|
||||
Returns: Tuple of Label Studio config and mode
|
||||
"""
|
||||
_default_label_configs = {
|
||||
LabelStudioMode.PROMPT.value: """
|
||||
<View>
|
||||
<Style>
|
||||
.prompt-box {
|
||||
background-color: white;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
|
||||
padding: 20px;
|
||||
}
|
||||
</Style>
|
||||
<View className="root">
|
||||
<View className="prompt-box">
|
||||
<Text name="prompt" value="$prompt"/>
|
||||
</View>
|
||||
<TextArea name="response" toName="prompt"
|
||||
maxSubmissions="1" editable="true"
|
||||
required="true"/>
|
||||
</View>
|
||||
<Header value="Rate the response:"/>
|
||||
<Rating name="rating" toName="prompt"/>
|
||||
</View>""",
|
||||
LabelStudioMode.CHAT.value: """
|
||||
<View>
|
||||
<View className="root">
|
||||
<Paragraphs name="dialogue"
|
||||
value="$prompt"
|
||||
layout="dialogue"
|
||||
textKey="content"
|
||||
nameKey="role"
|
||||
granularity="sentence"/>
|
||||
<Header value="Final response:"/>
|
||||
<TextArea name="response" toName="dialogue"
|
||||
maxSubmissions="1" editable="true"
|
||||
required="true"/>
|
||||
</View>
|
||||
<Header value="Rate the response:"/>
|
||||
<Rating name="rating" toName="dialogue"/>
|
||||
</View>""",
|
||||
}
|
||||
|
||||
if isinstance(mode, str):
|
||||
mode = LabelStudioMode(mode)
|
||||
|
||||
return _default_label_configs[mode.value], mode
|
||||
|
||||
|
||||
class LabelStudioCallbackHandler(BaseCallbackHandler):
|
||||
"""Label Studio callback handler.
|
||||
Provides the ability to send predictions to Label Studio
|
||||
for human evaluation, feedback and annotation.
|
||||
|
||||
Parameters:
|
||||
api_key: Label Studio API key
|
||||
url: Label Studio URL
|
||||
project_id: Label Studio project ID
|
||||
project_name: Label Studio project name
|
||||
project_config: Label Studio project config (XML)
|
||||
mode: Label Studio mode ("prompt" or "chat")
|
||||
|
||||
Examples:
|
||||
>>> from langchain_community.llms import OpenAI
|
||||
>>> from langchain_community.callbacks import LabelStudioCallbackHandler
|
||||
>>> handler = LabelStudioCallbackHandler(
|
||||
... api_key='<your_key_here>',
|
||||
... url='http://localhost:8080',
|
||||
... project_name='LangChain-%Y-%m-%d',
|
||||
... mode='prompt'
|
||||
... )
|
||||
>>> llm = OpenAI(callbacks=[handler])
|
||||
>>> llm.predict('Tell me a story about a dog.')
|
||||
"""
|
||||
|
||||
DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
project_id: Optional[int] = None,
|
||||
project_name: str = DEFAULT_PROJECT_NAME,
|
||||
project_config: Optional[str] = None,
|
||||
mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Import LabelStudio SDK
|
||||
try:
|
||||
import label_studio_sdk as ls
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f"You're using {self.__class__.__name__} in your code,"
|
||||
f" but you don't have the LabelStudio SDK "
|
||||
f"Python package installed or upgraded to the latest version. "
|
||||
f"Please run `pip install -U label-studio-sdk`"
|
||||
f" before using this callback."
|
||||
)
|
||||
|
||||
# Check if Label Studio API key is provided
|
||||
if not api_key:
|
||||
if os.getenv("LABEL_STUDIO_API_KEY"):
|
||||
api_key = str(os.getenv("LABEL_STUDIO_API_KEY"))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"You're using {self.__class__.__name__} in your code,"
|
||||
f" Label Studio API key is not provided. "
|
||||
f"Please provide Label Studio API key: "
|
||||
f"go to the Label Studio instance, navigate to "
|
||||
f"Account & Settings -> Access Token and copy the key. "
|
||||
f"Use the key as a parameter for the callback: "
|
||||
f"{self.__class__.__name__}"
|
||||
f"(label_studio_api_key='<your_key_here>', ...) or "
|
||||
f"set the environment variable LABEL_STUDIO_API_KEY=<your_key_here>"
|
||||
)
|
||||
self.api_key = api_key
|
||||
|
||||
if not url:
|
||||
if os.getenv("LABEL_STUDIO_URL"):
|
||||
url = os.getenv("LABEL_STUDIO_URL")
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Label Studio URL is not provided, "
|
||||
f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}"
|
||||
f"If you want to provide your own URL, use the parameter: "
|
||||
f"{self.__class__.__name__}"
|
||||
f"(label_studio_url='<your_url_here>', ...) "
|
||||
f"or set the environment variable LABEL_STUDIO_URL=<your_url_here>"
|
||||
)
|
||||
url = ls.LABEL_STUDIO_DEFAULT_URL
|
||||
self.url = url
|
||||
|
||||
# Maps run_id to prompts
|
||||
self.payload: Dict[str, Dict] = {}
|
||||
|
||||
self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
|
||||
self.project_name = project_name
|
||||
if project_config:
|
||||
self.project_config = project_config
|
||||
self.mode = None
|
||||
else:
|
||||
self.project_config, self.mode = get_default_label_configs(mode)
|
||||
|
||||
self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
|
||||
if self.project_id is not None:
|
||||
self.ls_project = self.ls_client.get_project(int(self.project_id))
|
||||
else:
|
||||
project_title = datetime.today().strftime(self.project_name)
|
||||
existing_projects = self.ls_client.get_projects(title=project_title)
|
||||
if existing_projects:
|
||||
self.ls_project = existing_projects[0]
|
||||
self.project_id = self.ls_project.id
|
||||
else:
|
||||
self.ls_project = self.ls_client.create_project(
|
||||
title=project_title, label_config=self.project_config
|
||||
)
|
||||
self.project_id = self.ls_project.id
|
||||
self.parsed_label_config = self.ls_project.parsed_label_config
|
||||
|
||||
# Find the first TextArea tag
|
||||
# "from_name", "to_name", "value" will be used to create predictions
|
||||
self.from_name, self.to_name, self.value, self.input_type = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
for tag_name, tag_info in self.parsed_label_config.items():
|
||||
if tag_info["type"] == "TextArea":
|
||||
self.from_name = tag_name
|
||||
self.to_name = tag_info["to_name"][0]
|
||||
self.value = tag_info["inputs"][0]["value"]
|
||||
self.input_type = tag_info["inputs"][0]["type"]
|
||||
break
|
||||
if not self.from_name:
|
||||
error_message = (
|
||||
f'Label Studio project "{self.project_name}" '
|
||||
f"does not have a TextArea tag. "
|
||||
f"Please add a TextArea tag to the project."
|
||||
)
|
||||
if self.mode == LabelStudioMode.PROMPT:
|
||||
error_message += (
|
||||
"\nHINT: go to project Settings -> "
|
||||
"Labeling Interface -> Browse Templates"
|
||||
' and select "Generative AI -> '
|
||||
'Supervised Language Model Fine-tuning" template.'
|
||||
)
|
||||
else:
|
||||
error_message += (
|
||||
"\nHINT: go to project Settings -> "
|
||||
"Labeling Interface -> Browse Templates"
|
||||
" and check available templates under "
|
||||
'"Generative AI" section.'
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
|
||||
def add_prompts_generations(
|
||||
self, run_id: str, generations: List[List[Generation]]
|
||||
) -> None:
|
||||
# Create tasks in Label Studio
|
||||
tasks = []
|
||||
prompts = self.payload[run_id]["prompts"]
|
||||
model_version = (
|
||||
self.payload[run_id]["kwargs"]
|
||||
.get("invocation_params", {})
|
||||
.get("model_name")
|
||||
)
|
||||
for prompt, generation in zip(prompts, generations):
|
||||
tasks.append(
|
||||
{
|
||||
"data": {
|
||||
self.value: prompt,
|
||||
"run_id": run_id,
|
||||
},
|
||||
"predictions": [
|
||||
{
|
||||
"result": [
|
||||
{
|
||||
"from_name": self.from_name,
|
||||
"to_name": self.to_name,
|
||||
"type": "textarea",
|
||||
"value": {"text": [g.text for g in generation]},
|
||||
}
|
||||
],
|
||||
"model_version": model_version,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self.ls_project.import_tasks(tasks)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Save the prompts in memory when an LLM starts."""
|
||||
if self.input_type != "Text":
|
||||
raise ValueError(
|
||||
f'\nLabel Studio project "{self.project_name}" '
|
||||
f"has an input type <{self.input_type}>. "
|
||||
f'To make it work with the mode="chat", '
|
||||
f"the input type should be <Text>.\n"
|
||||
f"Read more here https://labelstud.io/tags/text"
|
||||
)
|
||||
run_id = str(kwargs["run_id"])
|
||||
self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}
|
||||
|
||||
def _get_message_role(self, message: BaseMessage) -> str:
|
||||
"""Get the role of the message."""
|
||||
if isinstance(message, ChatMessage):
|
||||
return message.role
|
||||
else:
|
||||
return message.__class__.__name__
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Save the prompts in memory when an LLM starts."""
|
||||
if self.input_type != "Paragraphs":
|
||||
raise ValueError(
|
||||
f'\nLabel Studio project "{self.project_name}" '
|
||||
f"has an input type <{self.input_type}>. "
|
||||
f'To make it work with the mode="chat", '
|
||||
f"the input type should be <Paragraphs>.\n"
|
||||
f"Read more here https://labelstud.io/tags/paragraphs"
|
||||
)
|
||||
|
||||
prompts = []
|
||||
for message_list in messages:
|
||||
dialog = []
|
||||
for message in message_list:
|
||||
dialog.append(
|
||||
{
|
||||
"role": self._get_message_role(message),
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
prompts.append(dialog)
|
||||
self.payload[str(run_id)] = {
|
||||
"prompts": prompts,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_id": run_id,
|
||||
"parent_run_id": parent_run_id,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing when a new token is generated."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Create a new Label Studio task for each prompt and generation."""
|
||||
run_id = str(kwargs["run_id"])
|
||||
|
||||
# Submit results to Label Studio
|
||||
self.add_prompts_generations(run_id, response.generations)
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
self.payload.pop(run_id)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool starts."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing when agent takes a specific action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
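The class docstring above shows prompt mode; the sketch below assumes chat mode instead, which per `on_chat_model_start` requires a Label Studio project with a `<Paragraphs>` input. The running Label Studio instance and the `ChatOpenAI` model are assumptions for illustration.

```python
# Sketch: send chat-model generations to Label Studio for human review and rating.
from langchain_community.callbacks import LabelStudioCallbackHandler
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage

handler = LabelStudioCallbackHandler(
    api_key="<your_key_here>",
    url="http://localhost:8080",
    project_name="LangChain-Chat-%Y-%m-%d",
    mode="chat",
)
chat = ChatOpenAI(callbacks=[handler])
chat.invoke([HumanMessage(content="Write a haiku about data labeling.")])
# Each run becomes a Label Studio task whose prediction holds the model response.
```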
@@ -0,0 +1,680 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, List, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from packaging.version import parse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_URL = "https://app.llmonitor.com"
|
||||
|
||||
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
|
||||
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
|
||||
|
||||
PARAMS_TO_CAPTURE = [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"stop",
|
||||
"presence_penalty",
|
||||
"frequence_penalty",
|
||||
"seed",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
]
|
||||
|
||||
|
||||
class UserContextManager:
|
||||
"""Context manager for LLMonitor user context."""
|
||||
|
||||
def __init__(self, user_id: str, user_props: Any = None) -> None:
|
||||
user_ctx.set(user_id)
|
||||
user_props_ctx.set(user_props)
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any:
|
||||
user_ctx.set(None)
|
||||
user_props_ctx.set(None)
|
||||
|
||||
|
||||
def identify(user_id: str, user_props: Any = None) -> UserContextManager:
|
||||
"""Builds an LLMonitor UserContextManager
|
||||
|
||||
Parameters:
|
||||
- `user_id`: The user id.
|
||||
- `user_props`: The user properties.
|
||||
|
||||
Returns:
|
||||
A context manager that sets the user context.
|
||||
"""
|
||||
return UserContextManager(user_id, user_props)
|
||||
|
||||
|
||||
def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]:
|
||||
if hasattr(obj, "to_json"):
|
||||
return obj.to_json()
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {key: _serialize(value) for key, value in obj.items()}
|
||||
|
||||
if isinstance(obj, list):
|
||||
return [_serialize(element) for element in obj]
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def _parse_input(raw_input: Any) -> Any:
|
||||
if not raw_input:
|
||||
return None
|
||||
|
||||
# if it's an array of 1, just parse the first element
|
||||
if isinstance(raw_input, list) and len(raw_input) == 1:
|
||||
return _parse_input(raw_input[0])
|
||||
|
||||
if not isinstance(raw_input, dict):
|
||||
return _serialize(raw_input)
|
||||
|
||||
input_value = raw_input.get("input")
|
||||
inputs_value = raw_input.get("inputs")
|
||||
question_value = raw_input.get("question")
|
||||
query_value = raw_input.get("query")
|
||||
|
||||
if input_value:
|
||||
return input_value
|
||||
if inputs_value:
|
||||
return inputs_value
|
||||
if question_value:
|
||||
return question_value
|
||||
if query_value:
|
||||
return query_value
|
||||
|
||||
return _serialize(raw_input)
|
||||
|
||||
|
||||
def _parse_output(raw_output: dict) -> Any:
|
||||
if not raw_output:
|
||||
return None
|
||||
|
||||
if not isinstance(raw_output, dict):
|
||||
return _serialize(raw_output)
|
||||
|
||||
text_value = raw_output.get("text")
|
||||
output_value = raw_output.get("output")
|
||||
output_text_value = raw_output.get("output_text")
|
||||
answer_value = raw_output.get("answer")
|
||||
result_value = raw_output.get("result")
|
||||
|
||||
if text_value:
|
||||
return text_value
|
||||
if answer_value:
|
||||
return answer_value
|
||||
if output_value:
|
||||
return output_value
|
||||
if output_text_value:
|
||||
return output_text_value
|
||||
if result_value:
|
||||
return result_value
|
||||
|
||||
return _serialize(raw_output)
|
||||
|
||||
|
||||
def _parse_lc_role(
|
||||
role: str,
|
||||
) -> str:
|
||||
if role == "human":
|
||||
return "user"
|
||||
else:
|
||||
return role
|
||||
|
||||
|
||||
def _get_user_id(metadata: Any) -> Any:
|
||||
if user_ctx.get() is not None:
|
||||
return user_ctx.get()
|
||||
|
||||
metadata = metadata or {}
|
||||
user_id = metadata.get("user_id")
|
||||
if user_id is None:
|
||||
user_id = metadata.get("userId") # legacy, to delete in the future
|
||||
return user_id
|
||||
|
||||
|
||||
def _get_user_props(metadata: Any) -> Any:
|
||||
if user_props_ctx.get() is not None:
|
||||
return user_props_ctx.get()
|
||||
|
||||
metadata = metadata or {}
|
||||
return metadata.get("user_props", None)
|
||||
|
||||
|
||||
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
|
||||
keys = ["function_call", "tool_calls", "tool_call_id", "name"]
|
||||
parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
|
||||
parsed.update(
|
||||
{
|
||||
key: cast(Any, message.additional_kwargs.get(key))
|
||||
for key in keys
|
||||
if message.additional_kwargs.get(key) is not None
|
||||
}
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def _parse_lc_messages(messages: Union[List[BaseMessage], Any]) -> List[Dict[str, Any]]:
|
||||
return [_parse_lc_message(message) for message in messages]
|
||||
|
||||
|
||||
class LLMonitorCallbackHandler(BaseCallbackHandler):
"""Callback Handler for LLMonitor.
|
||||
|
||||
#### Parameters:
|
||||
- `app_id`: The app id of the app you want to report to. Defaults to
|
||||
`None`, which means that `LLMONITOR_APP_ID` will be used.
|
||||
- `api_url`: The url of the LLMonitor API. Defaults to `None`,
|
||||
which means that either `LLMONITOR_API_URL` environment variable
|
||||
or `https://app.llmonitor.com` will be used.
|
||||
|
||||
#### Raises:
|
||||
- `ValueError`: if `app_id` is not provided either as an
|
||||
argument or as an environment variable.
|
||||
- `ConnectionError`: if the connection to the API fails.
|
||||
|
||||
|
||||
#### Example:
|
||||
```python
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_community.callbacks import LLMonitorCallbackHandler
|
||||
|
||||
llmonitor_callback = LLMonitorCallbackHandler()
|
||||
llm = OpenAI(callbacks=[llmonitor_callback],
|
||||
metadata={"userId": "user-123"})
|
||||
llm.predict("Hello, how are you?")
|
||||
```
|
||||
"""
|
||||
|
||||
__api_url: str
|
||||
__app_id: str
|
||||
__verbose: bool
|
||||
__llmonitor_version: str
|
||||
__has_valid_config: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: Union[str, None] = None,
|
||||
api_url: Union[str, None] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.__has_valid_config = True
|
||||
|
||||
try:
|
||||
import llmonitor
|
||||
|
||||
self.__llmonitor_version = importlib.metadata.version("llmonitor")
|
||||
self.__track_event = llmonitor.track_event
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"""[LLMonitor] To use the LLMonitor callback handler you need to
|
||||
have the `llmonitor` Python package installed. Please install it
|
||||
with `pip install llmonitor`"""
|
||||
)
|
||||
self.__has_valid_config = False
|
||||
return
|
||||
|
||||
if parse(self.__llmonitor_version) < parse("0.0.32"):
|
||||
logger.warning(
|
||||
f"""[LLMonitor] The installed `llmonitor` version is
|
||||
{self.__llmonitor_version}
|
||||
but `LLMonitorCallbackHandler` requires at least version 0.0.32
|
||||
upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
|
||||
)
|
||||
self.__has_valid_config = False
|
||||
|
||||
self.__has_valid_config = True
|
||||
|
||||
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
|
||||
self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))
|
||||
|
||||
_app_id = app_id or os.getenv("LLMONITOR_APP_ID")
|
||||
if _app_id is None:
|
||||
logger.warning(
|
||||
"""[LLMonitor] app_id must be provided either as an argument or
|
||||
as an environment variable"""
|
||||
)
|
||||
self.__has_valid_config = False
|
||||
else:
|
||||
self.__app_id = _app_id
|
||||
|
||||
if self.__has_valid_config is False:
|
||||
return None
|
||||
|
||||
try:
|
||||
res = requests.get(f"{self.__api_url}/api/app/{self.__app_id}")
|
||||
if not res.ok:
|
||||
raise ConnectionError()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"""[LLMonitor] Could not connect to the LLMonitor API at
|
||||
{self.__api_url}"""
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
tags: Union[List[str], None] = None,
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
|
||||
params = kwargs.get("invocation_params", {})
|
||||
params.update(
|
||||
serialized.get("kwargs", {})
|
||||
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
|
||||
|
||||
name = (
|
||||
params.get("model")
|
||||
or params.get("model_name")
|
||||
or params.get("model_id")
|
||||
)
|
||||
|
||||
if not name and "anthropic" in params.get("_type"):
|
||||
name = "claude-2"
|
||||
|
||||
extra = {
|
||||
param: params.get(param)
|
||||
for param in PARAMS_TO_CAPTURE
|
||||
if params.get(param) is not None
|
||||
}
|
||||
|
||||
input = _parse_input(prompts)
|
||||
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
tags=tags,
|
||||
extra=extra,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
tags: Union[List[str], None] = None,
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
|
||||
try:
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
|
||||
params = kwargs.get("invocation_params", {})
|
||||
params.update(
|
||||
serialized.get("kwargs", {})
|
||||
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
|
||||
|
||||
name = (
|
||||
params.get("model")
|
||||
or params.get("model_name")
|
||||
or params.get("model_id")
|
||||
)
|
||||
|
||||
if not name and "anthropic" in params.get("_type"):
|
||||
name = "claude-2"
|
||||
|
||||
extra = {
|
||||
param: params.get(param)
|
||||
for param in PARAMS_TO_CAPTURE
|
||||
if params.get(param) is not None
|
||||
}
|
||||
|
||||
input = _parse_lc_messages(messages[0])
|
||||
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
tags=tags,
|
||||
extra=extra,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
|
||||
try:
|
||||
token_usage = (response.llm_output or {}).get("token_usage", {})
|
||||
|
||||
parsed_output: Any = [
|
||||
_parse_lc_message(generation.message)
|
||||
if hasattr(generation, "message")
|
||||
else generation.text
|
||||
for generation in response.generations[0]
|
||||
]
|
||||
|
||||
# if it's an array of 1, just parse the first element
|
||||
if len(parsed_output) == 1:
|
||||
parsed_output = parsed_output[0]
|
||||
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=parsed_output,
|
||||
token_usage={
|
||||
"prompt": token_usage.get("prompt_tokens"),
|
||||
"completion": token_usage.get("completion_tokens"),
|
||||
},
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
tags: Union[List[str], None] = None,
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
name = serialized.get("name")
|
||||
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input_str,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
tags: Union[List[str], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=output,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
tags: Union[List[str], None] = None,
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
name = serialized.get("id", [None, None, None, None])[3]
|
||||
type = "chain"
|
||||
metadata = metadata or {}
|
||||
|
||||
agentName = metadata.get("agent_name")
|
||||
if agentName is None:
|
||||
agentName = metadata.get("agentName")
|
||||
|
||||
if name == "AgentExecutor" or name == "PlanAndExecute":
|
||||
type = "agent"
|
||||
if agentName is not None:
|
||||
type = "agent"
|
||||
name = agentName
|
||||
if parent_run_id is not None:
|
||||
type = "chain"
|
||||
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
input = _parse_input(inputs)
|
||||
|
||||
self.__track_event(
|
||||
type,
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
output = _parse_output(outputs)
|
||||
|
||||
self.__track_event(
|
||||
"chain",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=output,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
name = action.tool
|
||||
input = _parse_input(action.tool_input)
|
||||
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"start",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
output = _parse_output(finish.return_values)
|
||||
|
||||
self.__track_event(
|
||||
"agent",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=output,
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
self.__track_event(
|
||||
"chain",
|
||||
"error",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
error={"message": str(error), "stack": traceback.format_exc()},
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"error",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
error={"message": str(error), "stack": traceback.format_exc()},
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"error",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
error={"message": str(error), "stack": traceback.format_exc()},
|
||||
app_id=self.__app_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
|
||||
|
||||
|
||||
__all__ = ["LLMonitorCallbackHandler", "identify"]
|
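A short sketch of the `identify` helper exported above: it sets the user context variables so that runs reported by `LLMonitorCallbackHandler` are attributed to a user. The OpenAI model and the literal app id are assumptions for illustration.

```python
# Sketch: attribute LLM runs to a user via the identify() context manager.
from langchain_community.callbacks.llmonitor_callback import (
    LLMonitorCallbackHandler,
    identify,
)
from langchain_community.llms import OpenAI

llm = OpenAI(callbacks=[LLMonitorCallbackHandler(app_id="<llmonitor-app-id>")])

with identify("user-123", user_props={"email": "user@example.com"}):
    llm.predict("Hello, how are you?")
# On exit the context vars are reset to None, so later runs are no longer
# attributed to this user.
```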
69
libs/community/langchain_community/callbacks/manager.py
Normal file
@@ -0,0 +1,69 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
    Generator,
    Optional,
)

from langchain_core.tracers.context import register_configure_hook

from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.tracers.wandb import WandbTracer

logger = logging.getLogger(__name__)

openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback", default=None
)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(  # noqa: E501
    "tracing_wandb_callback", default=None
)

register_configure_hook(openai_callback_var, True)
register_configure_hook(
    wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)


@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Get the OpenAI callback handler in a context manager,
    which conveniently exposes token and cost information.

    Returns:
        OpenAICallbackHandler: The OpenAI callback handler.

    Example:
        >>> with get_openai_callback() as cb:
        ...     # Use the OpenAI callback handler
    """
    cb = OpenAICallbackHandler()
    openai_callback_var.set(cb)
    yield cb
    openai_callback_var.set(None)


@contextmanager
def wandb_tracing_enabled(
    session_name: str = "default",
) -> Generator[None, None, None]:
    """Get the WandbTracer in a context manager.

    Args:
        session_name (str, optional): The name of the session.
            Defaults to "default".

    Returns:
        None

    Example:
        >>> with wandb_tracing_enabled() as session:
        ...     # Use the WandbTracer session
    """
    cb = WandbTracer()
    wandb_tracing_callback_var.set(cb)
    yield None
    wandb_tracing_callback_var.set(None)
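A hedged usage sketch of the two context managers above. The `total_tokens` and `total_cost` attributes belong to `OpenAICallbackHandler` (defined in `openai_info.py`, not shown in this hunk), so treat them as assumptions here.

```python
# Sketch: count tokens/cost for a block of calls, and enable W&B tracing locally.
from langchain_community.callbacks.manager import (
    get_openai_callback,
    wandb_tracing_enabled,
)
from langchain_community.llms import OpenAI

llm = OpenAI(temperature=0)

with get_openai_callback() as cb:
    llm.predict("What is the capital of France?")
    print(cb.total_tokens, cb.total_cost)  # assumed attributes of the handler

# Scoped alternative to setting LANGCHAIN_WANDB_TRACING=true globally.
with wandb_tracing_enabled():
    llm.predict("And of Germany?")
```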
660
libs/community/langchain_community/callbacks/mlflow_callback.py
Normal file
@@ -0,0 +1,660 @@
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
|
||||
def import_mlflow() -> Any:
|
||||
"""Import the mlflow python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import mlflow
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the mlflow callback manager you need to have the `mlflow` python "
|
||||
"package installed. Please install it with `pip install mlflow>=2.3.0`"
|
||||
)
|
||||
return mlflow
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: str,
|
||||
nlp: Any = None,
|
||||
) -> dict:
|
||||
"""Analyze text using textstat and spacy.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
nlp (spacy.lang): The spacy language model to use for visualization.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics and visualization
|
||||
files serialized to HTML string.
|
||||
"""
|
||||
resp: Dict[str, Any] = {}
|
||||
textstat = import_textstat()
|
||||
spacy = import_spacy()
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(text),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
# "text_standard": textstat.text_standard(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
}
|
||||
resp.update({"text_complexity_metrics": text_complexity_metrics})
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if nlp is not None:
|
||||
doc = nlp(text)
|
||||
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="dep", jupyter=False, page=True
|
||||
)
|
||||
|
||||
ent_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="ent", jupyter=False, page=True
|
||||
)
|
||||
|
||||
text_visualizations = {
|
||||
"dependency_tree": dep_out,
|
||||
"entities": ent_out,
|
||||
}
|
||||
|
||||
resp.update(text_visualizations)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
|
||||
"""Construct an html element from a prompt and a generation.
|
||||
|
||||
Parameters:
|
||||
prompt (str): The prompt.
|
||||
generation (str): The generation.
|
||||
|
||||
Returns:
|
||||
(str): The html string."""
|
||||
formatted_prompt = prompt.replace("\n", "<br>")
|
||||
formatted_generation = generation.replace("\n", "<br>")
|
||||
|
||||
return f"""
|
||||
<p style="color:black;">{formatted_prompt}:</p>
|
||||
<blockquote>
|
||||
<p style="color:green;">
|
||||
{formatted_generation}
|
||||
</p>
|
||||
</blockquote>
|
||||
"""
|
||||
|
||||
|
||||
class MlflowLogger:
|
||||
"""Callback Handler that logs metrics and artifacts to mlflow server.
|
||||
|
||||
Parameters:
|
||||
name (str): Name of the run.
|
||||
experiment (str): Name of the experiment.
|
||||
tags (dict): Tags to be attached for the run.
|
||||
tracking_uri (str): MLflow tracking server uri.
|
||||
|
||||
This handler implements the helper functions to initialize,
|
||||
log metrics and artifacts to the mlflow server.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
self.mlflow = import_mlflow()
|
||||
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
|
||||
self.mlflow.set_tracking_uri("databricks")
|
||||
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
|
||||
self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
|
||||
else:
|
||||
tracking_uri = get_from_dict_or_env(
|
||||
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
|
||||
)
|
||||
self.mlflow.set_tracking_uri(tracking_uri)
|
||||
|
||||
# User can set other env variables described here
|
||||
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
|
||||
|
||||
experiment_name = get_from_dict_or_env(
|
||||
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
|
||||
)
|
||||
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
|
||||
if self.mlf_exp is not None:
|
||||
self.mlf_expid = self.mlf_exp.experiment_id
|
||||
else:
|
||||
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
|
||||
|
||||
self.start_run(kwargs["run_name"], kwargs["run_tags"])
|
||||
|
||||
def start_run(self, name: str, tags: Dict[str, str]) -> None:
|
||||
"""To start a new run, auto generates the random suffix for name"""
|
||||
if name.endswith("-%"):
|
||||
rname = "".join(random.choices(string.ascii_uppercase + string.digits, k=7))
|
||||
name = name.replace("%", rname)
|
||||
self.run = self.mlflow.MlflowClient().create_run(
|
||||
self.mlf_expid, run_name=name, tags=tags
|
||||
)
|
||||
|
||||
def finish_run(self) -> None:
|
||||
"""To finish the run."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.end_run()
|
||||
|
||||
def metric(self, key: str, value: float) -> None:
|
||||
"""To log metric to mlflow server."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_metric(key, value)
|
||||
|
||||
def metrics(
|
||||
self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
|
||||
) -> None:
|
||||
"""To log all metrics in the input dict."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_metrics(data)
|
||||
|
||||
def jsonf(self, data: Dict[str, Any], filename: str) -> None:
|
||||
"""To log the input data as json file artifact."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_dict(data, f"{filename}.json")
|
||||
|
||||
def table(self, name: str, dataframe) -> None: # type: ignore
|
||||
"""To log the input pandas dataframe as a html table"""
|
||||
self.html(dataframe.to_html(), f"table_{name}")
|
||||
|
||||
def html(self, html: str, filename: str) -> None:
|
||||
"""To log the input html string as html file artifact."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_text(html, f"{filename}.html")
|
||||
|
||||
def text(self, text: str, filename: str) -> None:
|
||||
"""To log the input text as text file artifact."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_text(text, f"{filename}.txt")
|
||||
|
||||
def artifact(self, path: str) -> None:
|
||||
"""To upload the file from given path as artifact."""
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.log_artifact(path)
|
||||
|
||||
def langchain_artifact(self, chain: Any) -> None:
|
||||
with self.mlflow.start_run(
|
||||
run_id=self.run.info.run_id, experiment_id=self.mlf_expid
|
||||
):
|
||||
self.mlflow.langchain.log_model(chain, "langchain-model")
|
||||
|
||||
|
||||
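# Illustrative sketch (not part of the original module): using MlflowLogger's
# helpers on their own. Assumes a reachable MLflow tracking server; the URI,
# experiment name, and run name below are placeholders.
def _example_mlflow_logger() -> None:
    logger = MlflowLogger(
        tracking_uri="http://localhost:5000",
        experiment_name="langchain",
        run_name="langchainrun-%",
        run_tags={},
    )
    logger.metric("llm_starts", 1.0)
    logger.jsonf({"action": "on_llm_start"}, "llm_start_1_prompt_0")
    logger.html("<p>hello</p>", "greeting")
    logger.finish_run()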
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs metrics and artifacts to mlflow server.
|
||||
|
||||
Parameters:
|
||||
name (str): Name of the run.
|
||||
experiment (str): Name of the experiment.
|
||||
tags (dict): Tags to be attached for the run.
|
||||
tracking_uri (str): MLflow tracking server uri.
|
||||
|
||||
This handler utilizes the associated callback method, formats the input of
each callback function with metadata regarding the state of the LLM run,
appends the response to the list of records for both the {method}_records
and action records, and then logs the response to the mlflow server.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = "langchainrun-%",
|
||||
experiment: Optional[str] = "langchain",
|
||||
tags: Optional[Dict] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
import_pandas()
|
||||
import_textstat()
|
||||
import_mlflow()
|
||||
spacy = import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.name = name
|
||||
self.experiment = experiment
|
||||
self.tags = tags or {}
|
||||
self.tracking_uri = tracking_uri
|
||||
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
self.mlflg = MlflowLogger(
|
||||
tracking_uri=self.tracking_uri,
|
||||
experiment_name=self.experiment,
|
||||
run_name=self.name,
|
||||
run_tags=self.tags,
|
||||
)
|
||||
|
||||
self.action_records: list = []
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
self.metrics = {
|
||||
"step": 0,
|
||||
"starts": 0,
|
||||
"ends": 0,
|
||||
"errors": 0,
|
||||
"text_ctr": 0,
|
||||
"chain_starts": 0,
|
||||
"chain_ends": 0,
|
||||
"llm_starts": 0,
|
||||
"llm_ends": 0,
|
||||
"llm_streams": 0,
|
||||
"tool_starts": 0,
|
||||
"tool_ends": 0,
|
||||
"agent_ends": 0,
|
||||
}
|
||||
|
||||
self.records: Dict[str, Any] = {
|
||||
"on_llm_start_records": [],
|
||||
"on_llm_token_records": [],
|
||||
"on_llm_end_records": [],
|
||||
"on_chain_start_records": [],
|
||||
"on_chain_end_records": [],
|
||||
"on_tool_start_records": [],
|
||||
"on_tool_end_records": [],
|
||||
"on_text_records": [],
|
||||
"on_agent_finish_records": [],
|
||||
"on_agent_action_records": [],
|
||||
"action_records": [],
|
||||
}
|
||||
|
||||
def _reset(self) -> None:
|
||||
for k, v in self.metrics.items():
|
||||
self.metrics[k] = 0
|
||||
for k, v in self.records.items():
|
||||
self.records[k] = []
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
llm_starts = self.metrics["llm_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
for idx, prompt in enumerate(prompts):
|
||||
prompt_resp = deepcopy(resp)
|
||||
prompt_resp["prompt"] = prompt
|
||||
self.records["on_llm_start_records"].append(prompt_resp)
|
||||
self.records["action_records"].append(prompt_resp)
|
||||
self.mlflg.jsonf(prompt_resp, f"llm_start_{llm_starts}_prompt_{idx}")
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_streams"] += 1
|
||||
|
||||
llm_streams = self.metrics["llm_streams"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
self.records["on_llm_token_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"llm_new_tokens_{llm_streams}")
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
llm_ends = self.metrics["llm_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_end"})
|
||||
resp.update(flatten_dict(response.llm_output or {}))
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
for generations in response.generations:
|
||||
for idx, generation in enumerate(generations):
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
generation_resp.update(
|
||||
analyze_text(
|
||||
generation.text,
|
||||
nlp=self.nlp,
|
||||
)
|
||||
)
|
||||
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
||||
"text_complexity_metrics"
|
||||
) # type: ignore # noqa: E501
|
||||
self.mlflg.metrics(
|
||||
complexity_metrics,
|
||||
step=self.metrics["step"],
|
||||
)
|
||||
self.records["on_llm_end_records"].append(generation_resp)
|
||||
self.records["action_records"].append(generation_resp)
|
||||
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
|
||||
dependency_tree = generation_resp["dependency_tree"]
|
||||
entities = generation_resp["entities"]
|
||||
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
|
||||
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
chain_starts = self.metrics["chain_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["inputs"] = chain_input
|
||||
self.records["on_chain_start_records"].append(input_resp)
|
||||
self.records["action_records"].append(input_resp)
|
||||
self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
chain_ends = self.metrics["chain_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
|
||||
resp.update({"action": "on_chain_end", "outputs": chain_output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
self.records["on_chain_end_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
tool_starts = self.metrics["tool_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_tool_start", "input_str": input_str})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
self.records["on_tool_start_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
tool_ends = self.metrics["tool_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_tool_end", "output": output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
self.records["on_tool_end_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["text_ctr"] += 1
|
||||
|
||||
text_ctr = self.metrics["text_ctr"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_text", "text": text})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
self.records["on_text_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"on_text_{text_ctr}")
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["agent_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
agent_ends = self.metrics["agent_ends"]
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_finish",
|
||||
"output": finish.return_values["output"],
|
||||
"log": finish.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
|
||||
self.records["on_agent_finish_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"agent_finish_{agent_ends}")
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
tool_starts = self.metrics["tool_starts"]
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_action",
|
||||
"tool": action.tool,
|
||||
"tool_input": action.tool_input,
|
||||
"log": action.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.metrics)
|
||||
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
|
||||
self.records["on_agent_action_records"].append(resp)
|
||||
self.records["action_records"].append(resp)
|
||||
self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
|
||||
"""Create a dataframe with all the information from the session."""
|
||||
pd = import_pandas()
|
||||
on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
|
||||
on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
|
||||
|
||||
llm_input_columns = ["step", "prompt"]
|
||||
if "name" in on_llm_start_records_df.columns:
|
||||
llm_input_columns.append("name")
|
||||
elif "id" in on_llm_start_records_df.columns:
|
||||
# id is llm class's full import path. For example:
|
||||
# ["langchain", "llms", "openai", "AzureOpenAI"]
|
||||
on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply(
|
||||
lambda id_: id_[-1]
|
||||
)
|
||||
llm_input_columns.append("name")
|
||||
llm_input_prompts_df = (
|
||||
on_llm_start_records_df[llm_input_columns]
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "prompt_step"}, axis=1)
|
||||
)
|
||||
complexity_metrics_columns = []
|
||||
visualizations_columns = []
|
||||
|
||||
complexity_metrics_columns = [
|
||||
"flesch_reading_ease",
|
||||
"flesch_kincaid_grade",
|
||||
"smog_index",
|
||||
"coleman_liau_index",
|
||||
"automated_readability_index",
|
||||
"dale_chall_readability_score",
|
||||
"difficult_words",
|
||||
"linsear_write_formula",
|
||||
"gunning_fog",
|
||||
# "text_standard",
|
||||
"fernandez_huerta",
|
||||
"szigriszt_pazos",
|
||||
"gutierrez_polini",
|
||||
"crawford",
|
||||
"gulpease_index",
|
||||
"osman",
|
||||
]
|
||||
|
||||
visualizations_columns = ["dependency_tree", "entities"]
|
||||
|
||||
llm_outputs_df = (
|
||||
on_llm_end_records_df[
|
||||
[
|
||||
"step",
|
||||
"text",
|
||||
"token_usage_total_tokens",
|
||||
"token_usage_prompt_tokens",
|
||||
"token_usage_completion_tokens",
|
||||
]
|
||||
+ complexity_metrics_columns
|
||||
+ visualizations_columns
|
||||
]
|
||||
.dropna(axis=1)
|
||||
.rename({"step": "output_step", "text": "output"}, axis=1)
|
||||
)
|
||||
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
|
||||
session_analysis_df["chat_html"] = session_analysis_df[
|
||||
["prompt", "output"]
|
||||
].apply(
|
||||
lambda row: construct_html_from_prompt_and_generation(
|
||||
row["prompt"], row["output"]
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
return session_analysis_df
|
||||
|
||||
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
|
||||
pd = import_pandas()
|
||||
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
|
||||
session_analysis_df = self._create_session_analysis_df()
|
||||
chat_html = session_analysis_df.pop("chat_html")
|
||||
chat_html = chat_html.replace("\n", "", regex=True)
|
||||
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
|
||||
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
|
||||
|
||||
if langchain_asset:
|
||||
# To avoid circular import error
|
||||
# mlflow only supports LLMChain asset
|
||||
if "langchain.chains.llm.LLMChain" in str(type(langchain_asset)):
|
||||
self.mlflg.langchain_artifact(langchain_asset)
|
||||
else:
|
||||
langchain_asset_path = str(Path(self.temp_dir.name, "model.json"))
|
||||
try:
|
||||
langchain_asset.save(langchain_asset_path)
|
||||
self.mlflg.artifact(langchain_asset_path)
|
||||
except ValueError:
|
||||
try:
|
||||
langchain_asset.save_agent(langchain_asset_path)
|
||||
self.mlflg.artifact(langchain_asset_path)
|
||||
except AttributeError:
|
||||
print("Could not save model.")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
except NotImplementedError:
|
||||
print("Could not save model.")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
except NotImplementedError:
|
||||
print("Could not save model.")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
if finish:
|
||||
self.mlflg.finish_run()
|
||||
self._reset()
|
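# Illustrative sketch (not part of the original module): wiring the handler
# into an LLM. Assumes the mlflow, textstat, spacy, and openai extras are
# installed and that OPENAI_API_KEY / MLFLOW_TRACKING_URI are configured;
# the prompt is a placeholder.
def _example_mlflow_callback() -> None:
    from langchain_community.llms import OpenAI

    handler = MlflowCallbackHandler(name="demo-run-%", experiment="langchain")
    llm = OpenAI(temperature=0, callbacks=[handler])
    llm.invoke("Tell me a short joke about callbacks.")
    handler.flush_tracker(llm, finish=True)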
208
libs/community/langchain_community/callbacks/openai_info.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS = {
|
||||
# GPT-4 input
|
||||
"gpt-4": 0.03,
|
||||
"gpt-4-0314": 0.03,
|
||||
"gpt-4-0613": 0.03,
|
||||
"gpt-4-32k": 0.06,
|
||||
"gpt-4-32k-0314": 0.06,
|
||||
"gpt-4-32k-0613": 0.06,
|
||||
"gpt-4-vision-preview": 0.01,
|
||||
"gpt-4-1106-preview": 0.01,
|
||||
# GPT-4 output
|
||||
"gpt-4-completion": 0.06,
|
||||
"gpt-4-0314-completion": 0.06,
|
||||
"gpt-4-0613-completion": 0.06,
|
||||
"gpt-4-32k-completion": 0.12,
|
||||
"gpt-4-32k-0314-completion": 0.12,
|
||||
"gpt-4-32k-0613-completion": 0.12,
|
||||
"gpt-4-vision-preview-completion": 0.03,
|
||||
"gpt-4-1106-preview-completion": 0.03,
|
||||
# GPT-3.5 input
|
||||
"gpt-3.5-turbo": 0.0015,
|
||||
"gpt-3.5-turbo-0301": 0.0015,
|
||||
"gpt-3.5-turbo-0613": 0.0015,
|
||||
"gpt-3.5-turbo-1106": 0.001,
|
||||
"gpt-3.5-turbo-instruct": 0.0015,
|
||||
"gpt-3.5-turbo-16k": 0.003,
|
||||
"gpt-3.5-turbo-16k-0613": 0.003,
|
||||
# GPT-3.5 output
|
||||
"gpt-3.5-turbo-completion": 0.002,
|
||||
"gpt-3.5-turbo-0301-completion": 0.002,
|
||||
"gpt-3.5-turbo-0613-completion": 0.002,
|
||||
"gpt-3.5-turbo-1106-completion": 0.002,
|
||||
"gpt-3.5-turbo-instruct-completion": 0.002,
|
||||
"gpt-3.5-turbo-16k-completion": 0.004,
|
||||
"gpt-3.5-turbo-16k-0613-completion": 0.004,
|
||||
# Azure GPT-35 input
|
||||
"gpt-35-turbo": 0.0015, # Azure OpenAI version of ChatGPT
|
||||
"gpt-35-turbo-0301": 0.0015, # Azure OpenAI version of ChatGPT
|
||||
"gpt-35-turbo-0613": 0.0015,
|
||||
"gpt-35-turbo-instruct": 0.0015,
|
||||
"gpt-35-turbo-16k": 0.003,
|
||||
"gpt-35-turbo-16k-0613": 0.003,
|
||||
# Azure GPT-35 output
|
||||
"gpt-35-turbo-completion": 0.002, # Azure OpenAI version of ChatGPT
|
||||
"gpt-35-turbo-0301-completion": 0.002, # Azure OpenAI version of ChatGPT
|
||||
"gpt-35-turbo-0613-completion": 0.002,
|
||||
"gpt-35-turbo-instruct-completion": 0.002,
|
||||
"gpt-35-turbo-16k-completion": 0.004,
|
||||
"gpt-35-turbo-16k-0613-completion": 0.004,
|
||||
# Others
|
||||
"text-ada-001": 0.0004,
|
||||
"ada": 0.0004,
|
||||
"text-babbage-001": 0.0005,
|
||||
"babbage": 0.0005,
|
||||
"text-curie-001": 0.002,
|
||||
"curie": 0.002,
|
||||
"text-davinci-003": 0.02,
|
||||
"text-davinci-002": 0.02,
|
||||
"code-davinci-002": 0.02,
|
||||
# Fine Tuned input
|
||||
"babbage-002-finetuned": 0.0016,
|
||||
"davinci-002-finetuned": 0.012,
|
||||
"gpt-3.5-turbo-0613-finetuned": 0.012,
|
||||
# Fine Tuned output
|
||||
"babbage-002-finetuned-completion": 0.0016,
|
||||
"davinci-002-finetuned-completion": 0.012,
|
||||
"gpt-3.5-turbo-0613-finetuned-completion": 0.016,
|
||||
# Azure Fine Tuned input
|
||||
"babbage-002-azure-finetuned": 0.0004,
|
||||
"davinci-002-azure-finetuned": 0.002,
|
||||
"gpt-35-turbo-0613-azure-finetuned": 0.0015,
|
||||
# Azure Fine Tuned output
|
||||
"babbage-002-azure-finetuned-completion": 0.0004,
|
||||
"davinci-002-azure-finetuned-completion": 0.002,
|
||||
"gpt-35-turbo-0613-azure-finetuned-completion": 0.002,
|
||||
# Legacy fine-tuned models
|
||||
"ada-finetuned-legacy": 0.0016,
|
||||
"babbage-finetuned-legacy": 0.0024,
|
||||
"curie-finetuned-legacy": 0.012,
|
||||
"davinci-finetuned-legacy": 0.12,
|
||||
}
|
||||
|
||||
|
||||
def standardize_model_name(
|
||||
model_name: str,
|
||||
is_completion: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Standardize the model name to a format that can be used in the OpenAI API.
|
||||
|
||||
Args:
|
||||
model_name: Model name to standardize.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Standardized model name.
|
||||
|
||||
"""
|
||||
model_name = model_name.lower()
|
||||
if ".ft-" in model_name:
|
||||
model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
|
||||
if ":ft-" in model_name:
|
||||
model_name = model_name.split(":")[0] + "-finetuned-legacy"
|
||||
if "ft:" in model_name:
|
||||
model_name = model_name.split(":")[1] + "-finetuned"
|
||||
if is_completion and (
|
||||
model_name.startswith("gpt-4")
|
||||
or model_name.startswith("gpt-3.5")
|
||||
or model_name.startswith("gpt-35")
|
||||
or ("finetuned" in model_name and "legacy" not in model_name)
|
||||
):
|
||||
return model_name + "-completion"
|
||||
else:
|
||||
return model_name
|
||||
|
||||
|
||||
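# Illustrative sketch (not part of the original module): how the rules above
# normalize a few names. The fine-tune id is a made-up placeholder.
def _example_standardize() -> None:
    assert standardize_model_name("GPT-4") == "gpt-4"
    assert standardize_model_name("gpt-4", is_completion=True) == "gpt-4-completion"
    assert (
        standardize_model_name("ft:gpt-3.5-turbo-0613:my-org::abc123")
        == "gpt-3.5-turbo-0613-finetuned"
    )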
def get_openai_token_cost_for_model(
|
||||
model_name: str, num_tokens: int, is_completion: bool = False
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost in USD for a given model and number of tokens.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
num_tokens: Number of tokens.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Cost in USD.
|
||||
"""
|
||||
model_name = standardize_model_name(model_name, is_completion=is_completion)
|
||||
if model_name not in MODEL_COST_PER_1K_TOKENS:
|
||||
raise ValueError(
|
||||
f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
|
||||
"Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
|
||||
)
|
||||
return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
|
||||
|
||||
|
||||
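# Illustrative worked example (not part of the original module), following
# directly from MODEL_COST_PER_1K_TOKENS: 1,000 prompt tokens plus 500
# completion tokens on gpt-3.5-turbo cost 0.0015 + 0.001 = 0.0025 USD.
def _example_cost() -> float:
    prompt_cost = get_openai_token_cost_for_model("gpt-3.5-turbo", 1000)
    completion_cost = get_openai_token_cost_for_model(
        "gpt-3.5-turbo", 500, is_completion=True
    )
    return prompt_cost + completion_cost  # 0.0025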
class OpenAICallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that tracks OpenAI info."""
|
||||
|
||||
total_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
successful_requests: int = 0
|
||||
total_cost: float = 0.0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Tokens Used: {self.total_tokens}\n"
|
||||
f"\tPrompt Tokens: {self.prompt_tokens}\n"
|
||||
f"\tCompletion Tokens: {self.completion_tokens}\n"
|
||||
f"Successful Requests: {self.successful_requests}\n"
|
||||
f"Total Cost (USD): ${self.total_cost}"
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Print out the token."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Collect token usage."""
|
||||
if response.llm_output is None:
|
||||
return None
|
||||
self.successful_requests += 1
|
||||
if "token_usage" not in response.llm_output:
|
||||
return None
|
||||
token_usage = response.llm_output["token_usage"]
|
||||
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||
model_name = standardize_model_name(response.llm_output.get("model_name", ""))
|
||||
if model_name in MODEL_COST_PER_1K_TOKENS:
|
||||
completion_cost = get_openai_token_cost_for_model(
|
||||
model_name, completion_tokens, is_completion=True
|
||||
)
|
||||
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
|
||||
self.total_cost += prompt_cost + completion_cost
|
||||
self.total_tokens += token_usage.get("total_tokens", 0)
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
|
||||
def __copy__(self) -> "OpenAICallbackHandler":
|
||||
"""Return a copy of the callback handler."""
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler":
|
||||
"""Return a deep copy of the callback handler."""
|
||||
return self
|
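# Illustrative sketch (not part of the original module): the handler is usually
# attached via callbacks and inspected after the calls. Assumes an OpenAI key
# is configured; the prompt is a placeholder.
def _example_openai_callback() -> None:
    from langchain_community.llms import OpenAI

    handler = OpenAICallbackHandler()
    llm = OpenAI(callbacks=[handler])
    llm.invoke("Hello there")
    print(handler)  # Tokens Used / Successful Requests / Total Cost (USD)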
@@ -0,0 +1,162 @@
|
||||
"""Callback handler for promptlayer."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import promptlayer
|
||||
|
||||
|
||||
def _lazy_import_promptlayer() -> promptlayer:
|
||||
"""Lazy import promptlayer to avoid circular imports."""
|
||||
try:
|
||||
import promptlayer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The PromptLayerCallbackHandler requires the promptlayer package. "
|
||||
" Please install it with `pip install promptlayer`."
|
||||
)
|
||||
return promptlayer
|
||||
|
||||
|
||||
class PromptLayerCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for promptlayer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pl_id_callback: Optional[Callable[..., Any]] = None,
|
||||
pl_tags: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialize the PromptLayerCallbackHandler."""
|
||||
_lazy_import_promptlayer()
|
||||
self.pl_id_callback = pl_id_callback
|
||||
self.pl_tags = pl_tags or []
|
||||
self.runs: Dict[UUID, Dict[str, Any]] = {}
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.runs[run_id] = {
|
||||
"messages": [self._create_message_dicts(m)[0] for m in messages],
|
||||
"invocation_params": kwargs.get("invocation_params", {}),
|
||||
"name": ".".join(serialized["id"]),
|
||||
"request_start_time": datetime.datetime.now().timestamp(),
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.runs[run_id] = {
|
||||
"prompts": prompts,
|
||||
"invocation_params": kwargs.get("invocation_params", {}),
|
||||
"name": ".".join(serialized["id"]),
|
||||
"request_start_time": datetime.datetime.now().timestamp(),
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
from promptlayer.utils import get_api_key, promptlayer_api_request
|
||||
|
||||
run_info = self.runs.get(run_id, {})
|
||||
if not run_info:
|
||||
return
|
||||
run_info["request_end_time"] = datetime.datetime.now().timestamp()
|
||||
for i in range(len(response.generations)):
|
||||
generation = response.generations[i][0]
|
||||
|
||||
resp = {
|
||||
"text": generation.text,
|
||||
"llm_output": response.llm_output,
|
||||
}
|
||||
model_params = run_info.get("invocation_params", {})
|
||||
is_chat_model = run_info.get("messages", None) is not None
|
||||
model_input = (
|
||||
run_info.get("messages", [])[i]
|
||||
if is_chat_model
|
||||
else [run_info.get("prompts", [])[i]]
|
||||
)
|
||||
model_response = (
|
||||
[self._convert_message_to_dict(generation.message)]
|
||||
if is_chat_model and isinstance(generation, ChatGeneration)
|
||||
else resp
|
||||
)
|
||||
|
||||
pl_request_id = promptlayer_api_request(
|
||||
run_info.get("name"),
|
||||
"langchain",
|
||||
model_input,
|
||||
model_params,
|
||||
self.pl_tags,
|
||||
model_response,
|
||||
run_info.get("request_start_time"),
|
||||
run_info.get("request_end_time"),
|
||||
get_api_key(),
|
||||
return_pl_id=bool(self.pl_id_callback is not None),
|
||||
metadata={
|
||||
"_langchain_run_id": str(run_id),
|
||||
"_langchain_parent_run_id": str(parent_run_id),
|
||||
"_langchain_tags": str(run_info.get("tags", [])),
|
||||
},
|
||||
)
|
||||
|
||||
if self.pl_id_callback:
|
||||
self.pl_id_callback(pl_request_id)
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
|
||||
if isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params: Dict[str, Any] = {}
|
||||
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
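# Illustrative sketch (not part of the original module): attaching the handler
# to a chat model. Assumes promptlayer and an OpenAI key are configured; the
# tags and prompt are placeholders.
def _example_promptlayer() -> None:
    from langchain_community.chat_models import ChatOpenAI

    handler = PromptLayerCallbackHandler(pl_tags=["langchain", "demo"])
    chat = ChatOpenAI(callbacks=[handler])
    chat.invoke("Where is the Eiffel Tower?")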
@@ -0,0 +1,276 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
flatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def save_json(data: dict, file_path: str) -> None:
|
||||
"""Save dict to local file path.
|
||||
|
||||
Parameters:
|
||||
data (dict): The dictionary to be saved.
|
||||
file_path (str): Local file path.
|
||||
"""
|
||||
with open(file_path, "w") as outfile:
|
||||
json.dump(data, outfile)
|
||||
|
||||
|
||||
class SageMakerCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.
|
||||
|
||||
Parameters:
|
||||
run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
|
||||
"""
|
||||
|
||||
def __init__(self, run: Any) -> None:
|
||||
"""Initialize callback handler."""
|
||||
super().__init__()
|
||||
|
||||
self.run = run
|
||||
|
||||
self.metrics = {
|
||||
"step": 0,
|
||||
"starts": 0,
|
||||
"ends": 0,
|
||||
"errors": 0,
|
||||
"text_ctr": 0,
|
||||
"chain_starts": 0,
|
||||
"chain_ends": 0,
|
||||
"llm_starts": 0,
|
||||
"llm_ends": 0,
|
||||
"llm_streams": 0,
|
||||
"tool_starts": 0,
|
||||
"tool_ends": 0,
|
||||
"agent_ends": 0,
|
||||
}
|
||||
|
||||
# Create a temporary directory
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
def _reset(self) -> None:
|
||||
for k, v in self.metrics.items():
|
||||
self.metrics[k] = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
llm_starts = self.metrics["llm_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
for idx, prompt in enumerate(prompts):
|
||||
prompt_resp = deepcopy(resp)
|
||||
prompt_resp["prompt"] = prompt
|
||||
self.jsonf(
|
||||
prompt_resp,
|
||||
self.temp_dir,
|
||||
f"llm_start_{llm_starts}_prompt_{idx}",
|
||||
)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_streams"] += 1
|
||||
|
||||
llm_streams = self.metrics["llm_streams"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["llm_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
llm_ends = self.metrics["llm_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_llm_end"})
|
||||
resp.update(flatten_dict(response.llm_output or {}))
|
||||
|
||||
resp.update(self.metrics)
|
||||
|
||||
for generations in response.generations:
|
||||
for idx, generation in enumerate(generations):
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
|
||||
self.jsonf(
|
||||
resp,
|
||||
self.temp_dir,
|
||||
f"llm_end_{llm_ends}_generation_{idx}",
|
||||
)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
chain_starts = self.metrics["chain_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["inputs"] = chain_input
|
||||
|
||||
self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["chain_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
chain_ends = self.metrics["chain_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
|
||||
resp.update({"action": "on_chain_end", "outputs": chain_output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
tool_starts = self.metrics["tool_starts"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_tool_start", "input_str": input_str})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
tool_ends = self.metrics["tool_ends"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_tool_end", "output": output})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Run when agent is ending.
|
||||
"""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["text_ctr"] += 1
|
||||
|
||||
text_ctr = self.metrics["text_ctr"]
|
||||
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update({"action": "on_text", "text": text})
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["agent_ends"] += 1
|
||||
self.metrics["ends"] += 1
|
||||
|
||||
agent_ends = self.metrics["agent_ends"]
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_finish",
|
||||
"output": finish.return_values["output"],
|
||||
"log": finish.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.metrics)
|
||||
|
||||
self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.metrics["step"] += 1
|
||||
self.metrics["tool_starts"] += 1
|
||||
self.metrics["starts"] += 1
|
||||
|
||||
tool_starts = self.metrics["tool_starts"]
|
||||
resp: Dict[str, Any] = {}
|
||||
resp.update(
|
||||
{
|
||||
"action": "on_agent_action",
|
||||
"tool": action.tool,
|
||||
"tool_input": action.tool_input,
|
||||
"log": action.log,
|
||||
}
|
||||
)
|
||||
resp.update(self.metrics)
|
||||
self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")
|
||||
|
||||
def jsonf(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
data_dir: str,
|
||||
filename: str,
|
||||
is_output: Optional[bool] = True,
|
||||
) -> None:
|
||||
"""To log the input data as json file artifact."""
|
||||
file_path = os.path.join(data_dir, f"{filename}.json")
|
||||
save_json(data, file_path)
|
||||
self.run.log_file(file_path, name=filename, is_output=is_output)
|
||||
|
||||
def flush_tracker(self) -> None:
|
||||
"""Reset the steps and delete the temporary local directory."""
|
||||
self._reset()
|
||||
shutil.rmtree(self.temp_dir)
|
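# Illustrative sketch (not part of the original module): logging into an active
# SageMaker Experiments run. Assumes the sagemaker SDK is installed and AWS
# credentials are configured; experiment and run names are placeholders.
def _example_sagemaker() -> None:
    from sagemaker.experiments.run import Run
    from sagemaker.session import Session

    with Run(
        experiment_name="langchain-demo",
        run_name="run-1",
        sagemaker_session=Session(),
    ) as run:
        handler = SageMakerCallbackHandler(run)
        # ... pass callbacks=[handler] to an LLM, chain, or agent here ...
        handler.flush_tracker()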
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
|
||||
LLMThoughtLabeler as LLMThoughtLabeler,
|
||||
)
|
||||
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
|
||||
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
|
||||
def StreamlitCallbackHandler(
|
||||
parent_container: DeltaGenerator,
|
||||
*,
|
||||
max_thought_containers: int = 4,
|
||||
expand_new_thoughts: bool = True,
|
||||
collapse_completed_thoughts: bool = True,
|
||||
thought_labeler: Optional[LLMThoughtLabeler] = None,
|
||||
) -> BaseCallbackHandler:
|
||||
"""Callback Handler that writes to a Streamlit app.
|
||||
|
||||
This CallbackHandler is geared towards
|
||||
use with a LangChain Agent; it displays the Agent's LLM and tool-usage "thoughts"
|
||||
inside a series of Streamlit expanders.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent_container
|
||||
The `st.container` that will contain all the Streamlit elements that the
|
||||
Handler creates.
|
||||
max_thought_containers
|
||||
The max number of completed LLM thought containers to show at once. When this
|
||||
threshold is reached, a new thought will cause the oldest thoughts to be
|
||||
collapsed into a "History" expander. Defaults to 4.
|
||||
expand_new_thoughts
|
||||
Each LLM "thought" gets its own `st.expander`. This param controls whether that
|
||||
expander is expanded by default. Defaults to True.
|
||||
collapse_completed_thoughts
|
||||
If True, LLM thought expanders will be collapsed when completed.
|
||||
Defaults to True.
|
||||
thought_labeler
|
||||
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
|
||||
will use the default thought labeling logic. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A new StreamlitCallbackHandler instance.
|
||||
|
||||
Note that this is an "auto-updating" API: if the installed version of Streamlit
|
||||
has a more recent StreamlitCallbackHandler implementation, an instance of that class
|
||||
will be used.
|
||||
|
||||
"""
|
||||
# If we're using a version of Streamlit that implements StreamlitCallbackHandler,
|
||||
# delegate to it instead of using our built-in handler. The official handler is
|
||||
# guaranteed to support the same set of kwargs.
|
||||
try:
|
||||
from streamlit.external.langchain import (
|
||||
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501
|
||||
)
|
||||
|
||||
return OfficialStreamlitCallbackHandler(
|
||||
parent_container,
|
||||
max_thought_containers=max_thought_containers,
|
||||
expand_new_thoughts=expand_new_thoughts,
|
||||
collapse_completed_thoughts=collapse_completed_thoughts,
|
||||
thought_labeler=thought_labeler,
|
||||
)
|
||||
except ImportError:
|
||||
return _InternalStreamlitCallbackHandler(
|
||||
parent_container,
|
||||
max_thought_containers=max_thought_containers,
|
||||
expand_new_thoughts=expand_new_thoughts,
|
||||
collapse_completed_thoughts=collapse_completed_thoughts,
|
||||
thought_labeler=thought_labeler,
|
||||
)
|
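# Illustrative sketch (not part of the original module): typical use inside a
# Streamlit script. The agent is assumed to be built elsewhere in the app.
def _example_streamlit() -> None:
    import streamlit as st

    handler = StreamlitCallbackHandler(st.container())
    # An agent or chain created elsewhere is then invoked with
    # callbacks=[handler] so its thoughts render in the app, for example:
    #     agent.run(prompt, callbacks=[handler])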
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
from streamlit.type_util import SupportsStr
|
||||
|
||||
|
||||
class ChildType(Enum):
|
||||
"""The enumerator of the child type."""
|
||||
|
||||
MARKDOWN = "MARKDOWN"
|
||||
EXCEPTION = "EXCEPTION"
|
||||
|
||||
|
||||
class ChildRecord(NamedTuple):
|
||||
"""The child record as a NamedTuple."""
|
||||
|
||||
type: ChildType
|
||||
kwargs: Dict[str, Any]
|
||||
dg: DeltaGenerator
|
||||
|
||||
|
||||
class MutableExpander:
|
||||
"""A Streamlit expander that can be renamed and dynamically expanded/collapsed."""
|
||||
|
||||
def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
|
||||
"""Create a new MutableExpander.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent_container
|
||||
The `st.container` that the expander will be created inside.
|
||||
|
||||
The expander transparently deletes and recreates its underlying
|
||||
`st.expander` instance when its label changes, and it uses
|
||||
`parent_container` to ensure it recreates this underlying expander in the
|
||||
same location onscreen.
|
||||
label
|
||||
The expander's initial label.
|
||||
expanded
|
||||
The expander's initial `expanded` value.
|
||||
"""
|
||||
self._label = label
|
||||
self._expanded = expanded
|
||||
self._parent_cursor = parent_container.empty()
|
||||
self._container = self._parent_cursor.expander(label, expanded)
|
||||
self._child_records: List[ChildRecord] = []
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
"""The expander's label string."""
|
||||
return self._label
|
||||
|
||||
@property
|
||||
def expanded(self) -> bool:
|
||||
"""True if the expander was created with `expanded=True`."""
|
||||
return self._expanded
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove the container and its contents entirely. A cleared container can't
|
||||
be reused.
|
||||
"""
|
||||
self._container = self._parent_cursor.empty()
|
||||
self._child_records.clear()
|
||||
|
||||
def append_copy(self, other: MutableExpander) -> None:
|
||||
"""Append a copy of another MutableExpander's children to this
|
||||
MutableExpander.
|
||||
"""
|
||||
other_records = other._child_records.copy()
|
||||
for record in other_records:
|
||||
self._create_child(record.type, record.kwargs)
|
||||
|
||||
def update(
|
||||
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
|
||||
) -> None:
|
||||
"""Change the expander's label and expanded state"""
|
||||
if new_label is None:
|
||||
new_label = self._label
|
||||
if new_expanded is None:
|
||||
new_expanded = self._expanded
|
||||
|
||||
if self._label == new_label and self._expanded == new_expanded:
|
||||
# No change!
|
||||
return
|
||||
|
||||
self._label = new_label
|
||||
self._expanded = new_expanded
|
||||
self._container = self._parent_cursor.expander(new_label, new_expanded)
|
||||
|
||||
prev_records = self._child_records
|
||||
self._child_records = []
|
||||
|
||||
# Replay all children into the new container
|
||||
for record in prev_records:
|
||||
self._create_child(record.type, record.kwargs)
|
||||
|
||||
def markdown(
|
||||
self,
|
||||
body: SupportsStr,
|
||||
unsafe_allow_html: bool = False,
|
||||
*,
|
||||
help: Optional[str] = None,
|
||||
index: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Add a Markdown element to the container and return its index."""
|
||||
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
|
||||
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
|
||||
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
|
||||
return self._add_record(record, index)
|
||||
|
||||
def exception(
|
||||
self, exception: BaseException, *, index: Optional[int] = None
|
||||
) -> int:
|
||||
"""Add an Exception element to the container and return its index."""
|
||||
kwargs = {"exception": exception}
|
||||
new_dg = self._get_dg(index).exception(**kwargs)
|
||||
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
|
||||
return self._add_record(record, index)
|
||||
|
||||
def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
|
||||
"""Create a new child with the given params"""
|
||||
if type == ChildType.MARKDOWN:
|
||||
self.markdown(**kwargs)
|
||||
elif type == ChildType.EXCEPTION:
|
||||
self.exception(**kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected child type {type}")
|
||||
|
||||
def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
|
||||
"""Add a ChildRecord to self._children. If `index` is specified, replace
|
||||
the existing record at that index. Otherwise, append the record to the
|
||||
end of the list.
|
||||
|
||||
Return the index of the added record.
|
||||
"""
|
||||
if index is not None:
|
||||
# Replace existing child
|
||||
self._child_records[index] = record
|
||||
return index
|
||||
|
||||
# Append new child
|
||||
self._child_records.append(record)
|
||||
return len(self._child_records) - 1
|
||||
|
||||
def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
|
||||
if index is not None:
|
||||
# Existing index: reuse child's DeltaGenerator
|
||||
assert 0 <= index < len(self._child_records), f"Bad index: {index}"
|
||||
return self._child_records[index].dg
|
||||
|
||||
# No index: use container's DeltaGenerator
|
||||
return self._container
|
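# Illustrative sketch (not part of the original module): how the expander is
# driven. Assumes a running Streamlit script context; labels are placeholders.
def _example_mutable_expander() -> None:
    import streamlit as st

    expander = MutableExpander(st.container(), label="Thinking...", expanded=True)
    idx = expander.markdown("first draft")
    expander.markdown("final answer", index=idx)  # overwrite the same element
    expander.update(new_label="Done", new_expanded=False)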
@@ -0,0 +1,414 @@
|
||||
"""Callback Handler that prints to streamlit."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
|
||||
def _convert_newlines(text: str) -> str:
|
||||
"""Convert newline characters to markdown newline sequences
|
||||
(space, space, newline).
|
||||
"""
|
||||
return text.replace("\n", " \n")
|
||||
|
||||
|
||||
CHECKMARK_EMOJI = "✅"
|
||||
THINKING_EMOJI = ":thinking_face:"
|
||||
HISTORY_EMOJI = ":books:"
|
||||
EXCEPTION_EMOJI = "⚠️"
|
||||
|
||||
|
||||
class LLMThoughtState(Enum):
|
||||
"""Enumerator of the LLMThought state."""
|
||||
|
||||
# The LLM is thinking about what to do next. We don't know which tool we'll run.
|
||||
THINKING = "THINKING"
|
||||
# The LLM has decided to run a tool. We don't have results from the tool yet.
|
||||
RUNNING_TOOL = "RUNNING_TOOL"
|
||||
# We have results from the tool.
|
||||
COMPLETE = "COMPLETE"
|
||||
|
||||
|
||||
class ToolRecord(NamedTuple):
|
||||
"""The tool record as a NamedTuple."""
|
||||
|
||||
name: str
|
||||
input_str: str
|
||||
|
||||
|
||||
class LLMThoughtLabeler:
|
||||
"""
|
||||
Generates markdown labels for LLMThought containers. Pass a custom
|
||||
subclass of this to StreamlitCallbackHandler to override its default
|
||||
labeling logic.
|
||||
"""
|
||||
|
||||
def get_initial_label(self) -> str:
|
||||
"""Return the markdown label for a new LLMThought that doesn't have
|
||||
an associated tool yet.
|
||||
"""
|
||||
return f"{THINKING_EMOJI} **Thinking...**"
|
||||
|
||||
def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
|
||||
"""Return the label for an LLMThought that has an associated
|
||||
tool.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tool
|
||||
The tool's ToolRecord
|
||||
|
||||
is_complete
|
||||
True if the thought is complete; False if the thought
|
||||
is still receiving input.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The markdown label for the thought's container.
|
||||
|
||||
"""
|
||||
input = tool.input_str
|
||||
name = tool.name
|
||||
emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
|
||||
if name == "_Exception":
|
||||
emoji = EXCEPTION_EMOJI
|
||||
name = "Parsing error"
|
||||
idx = min([60, len(input)])
|
||||
input = input[0:idx]
|
||||
if len(tool.input_str) > idx:
|
||||
input = input + "..."
|
||||
input = input.replace("\n", " ")
|
||||
label = f"{emoji} **{name}:** {input}"
|
||||
return label
|
||||
|
||||
def get_history_label(self) -> str:
|
||||
"""Return a markdown label for the special 'history' container
|
||||
that contains overflow thoughts.
|
||||
"""
|
||||
return f"{HISTORY_EMOJI} **History**"
|
||||
|
||||
def get_final_agent_thought_label(self) -> str:
|
||||
"""Return the markdown label for the agent's final thought -
|
||||
the "Now I have the answer" thought, that doesn't involve
|
||||
a tool.
|
||||
"""
|
||||
return f"{CHECKMARK_EMOJI} **Complete!**"
|
||||
|
||||
|
||||
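# Illustrative sketch (not part of the original module): a custom labeler that
# can be passed to StreamlitCallbackHandler via `thought_labeler`. The wording
# is an arbitrary placeholder.
class _QuietThoughtLabeler(LLMThoughtLabeler):
    def get_initial_label(self) -> str:
        return "Working..."

    def get_final_agent_thought_label(self) -> str:
        return "Done."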
class LLMThought:
|
||||
"""A thought in the LLM's thought stream."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_container: DeltaGenerator,
|
||||
labeler: LLMThoughtLabeler,
|
||||
expanded: bool,
|
||||
collapse_on_complete: bool,
|
||||
):
|
||||
"""Initialize the LLMThought.
|
||||
|
||||
Args:
|
||||
parent_container: The container we're writing into.
|
||||
labeler: The labeler to use for this thought.
|
||||
expanded: Whether the thought should be expanded by default.
|
||||
collapse_on_complete: Whether the thought should be collapsed.
|
||||
"""
|
||||
self._container = MutableExpander(
|
||||
parent_container=parent_container,
|
||||
label=labeler.get_initial_label(),
|
||||
expanded=expanded,
|
||||
)
|
||||
self._state = LLMThoughtState.THINKING
|
||||
self._llm_token_stream = ""
|
||||
self._llm_token_writer_idx: Optional[int] = None
|
||||
self._last_tool: Optional[ToolRecord] = None
|
||||
self._collapse_on_complete = collapse_on_complete
|
||||
self._labeler = labeler
|
||||
|
||||
@property
|
||||
def container(self) -> MutableExpander:
|
||||
"""The container we're writing into."""
|
||||
return self._container
|
||||
|
||||
@property
|
||||
def last_tool(self) -> Optional[ToolRecord]:
|
||||
"""The last tool executed by this thought"""
|
||||
return self._last_tool
|
||||
|
||||
def _reset_llm_token_stream(self) -> None:
|
||||
self._llm_token_stream = ""
|
||||
self._llm_token_writer_idx = None
|
||||
|
||||
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
|
||||
self._reset_llm_token_stream()
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
# This is only called when the LLM is initialized with `streaming=True`
|
||||
self._llm_token_stream += _convert_newlines(token)
|
||||
self._llm_token_writer_idx = self._container.markdown(
|
||||
self._llm_token_stream, index=self._llm_token_writer_idx
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
# `response` is the concatenation of all the tokens received by the LLM.
|
||||
# If we're receiving streaming tokens from `on_llm_new_token`, this response
|
||||
# data is redundant
|
||||
self._reset_llm_token_stream()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
self._container.markdown("**LLM encountered an error...**")
|
||||
self._container.exception(error)
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
# Called with the name of the tool we're about to run (in `serialized[name]`),
|
||||
# and its input. We change our container's label to be the tool name.
|
||||
self._state = LLMThoughtState.RUNNING_TOOL
|
||||
tool_name = serialized["name"]
|
||||
self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
|
||||
self._container.update(
|
||||
new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
|
||||
)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._container.markdown(f"**{output}**")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
self._container.markdown("**Tool encountered an error...**")
|
||||
self._container.exception(error)
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
# Called when we're about to kick off a new tool. The `action` data
|
||||
# tells us the tool we're about to use, and the input we'll give it.
|
||||
# We don't output anything here, because we'll receive this same data
|
||||
# when `on_tool_start` is called immediately after.
|
||||
pass
|
||||
|
||||
def complete(self, final_label: Optional[str] = None) -> None:
|
||||
"""Finish the thought."""
|
||||
if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
|
||||
assert (
|
||||
self._last_tool is not None
|
||||
), "_last_tool should never be null when _state == RUNNING_TOOL"
|
||||
final_label = self._labeler.get_tool_label(
|
||||
self._last_tool, is_complete=True
|
||||
)
|
||||
self._state = LLMThoughtState.COMPLETE
|
||||
if self._collapse_on_complete:
|
||||
self._container.update(new_label=final_label, new_expanded=False)
|
||||
else:
|
||||
self._container.update(new_label=final_label)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove the thought from the screen. A cleared thought can't be reused."""
|
||||
self._container.clear()
|
||||
|
||||
|
||||
class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
"""A callback handler that writes to a Streamlit app."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_container: DeltaGenerator,
|
||||
*,
|
||||
max_thought_containers: int = 4,
|
||||
expand_new_thoughts: bool = True,
|
||||
collapse_completed_thoughts: bool = True,
|
||||
thought_labeler: Optional[LLMThoughtLabeler] = None,
|
||||
):
|
||||
"""Create a StreamlitCallbackHandler instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent_container
|
||||
The `st.container` that will contain all the Streamlit elements that the
|
||||
Handler creates.
|
||||
max_thought_containers
|
||||
The max number of completed LLM thought containers to show at once. When
|
||||
this threshold is reached, a new thought will cause the oldest thoughts to
|
||||
be collapsed into a "History" expander. Defaults to 4.
|
||||
expand_new_thoughts
|
||||
Each LLM "thought" gets its own `st.expander`. This param controls whether
|
||||
that expander is expanded by default. Defaults to True.
|
||||
collapse_completed_thoughts
|
||||
If True, LLM thought expanders will be collapsed when completed.
|
||||
Defaults to True.
|
||||
thought_labeler
|
||||
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
|
||||
will use the default thought labeling logic. Defaults to None.
|
||||
"""
|
||||
self._parent_container = parent_container
|
||||
self._history_parent = parent_container.container()
|
||||
self._history_container: Optional[MutableExpander] = None
|
||||
self._current_thought: Optional[LLMThought] = None
|
||||
self._completed_thoughts: List[LLMThought] = []
|
||||
self._max_thought_containers = max(max_thought_containers, 1)
|
||||
self._expand_new_thoughts = expand_new_thoughts
|
||||
self._collapse_completed_thoughts = collapse_completed_thoughts
|
||||
self._thought_labeler = thought_labeler or LLMThoughtLabeler()
|
||||
|
||||
def _require_current_thought(self) -> LLMThought:
|
||||
"""Return our current LLMThought. Raise an error if we have no current
|
||||
thought.
|
||||
"""
|
||||
if self._current_thought is None:
|
||||
raise RuntimeError("Current LLMThought is unexpectedly None!")
|
||||
return self._current_thought
|
||||
|
||||
def _get_last_completed_thought(self) -> Optional[LLMThought]:
|
||||
"""Return our most recent completed LLMThought, or None if we don't have one."""
|
||||
if len(self._completed_thoughts) > 0:
|
||||
return self._completed_thoughts[-1]
|
||||
return None
|
||||
|
||||
@property
|
||||
def _num_thought_containers(self) -> int:
|
||||
"""The number of 'thought containers' we're currently showing: the
|
||||
number of completed thought containers, the history container (if it exists),
|
||||
and the current thought container (if it exists).
|
||||
"""
|
||||
count = len(self._completed_thoughts)
|
||||
if self._history_container is not None:
|
||||
count += 1
|
||||
if self._current_thought is not None:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
|
||||
"""Complete the current thought, optionally assigning it a new label.
|
||||
Add it to our _completed_thoughts list.
|
||||
"""
|
||||
thought = self._require_current_thought()
|
||||
thought.complete(final_label)
|
||||
self._completed_thoughts.append(thought)
|
||||
self._current_thought = None
|
||||
|
||||
def _prune_old_thought_containers(self) -> None:
|
||||
"""If we have too many thoughts onscreen, move older thoughts to the
|
||||
'history container.'
|
||||
"""
|
||||
while (
|
||||
self._num_thought_containers > self._max_thought_containers
|
||||
and len(self._completed_thoughts) > 0
|
||||
):
|
||||
# Create our history container if it doesn't exist, and if
|
||||
# max_thought_containers is > 1. (if max_thought_containers is 1, we don't
|
||||
# have room to show history.)
|
||||
if self._history_container is None and self._max_thought_containers > 1:
|
||||
self._history_container = MutableExpander(
|
||||
self._history_parent,
|
||||
label=self._thought_labeler.get_history_label(),
|
||||
expanded=False,
|
||||
)
|
||||
|
||||
oldest_thought = self._completed_thoughts.pop(0)
|
||||
if self._history_container is not None:
|
||||
self._history_container.markdown(oldest_thought.container.label)
|
||||
self._history_container.append_copy(oldest_thought.container)
|
||||
oldest_thought.clear()
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
if self._current_thought is None:
|
||||
self._current_thought = LLMThought(
|
||||
parent_container=self._parent_container,
|
||||
expanded=self._expand_new_thoughts,
|
||||
collapse_on_complete=self._collapse_completed_thoughts,
|
||||
labeler=self._thought_labeler,
|
||||
)
|
||||
|
||||
self._current_thought.on_llm_start(serialized, prompts)
|
||||
|
||||
# We don't prune_old_thought_containers here, because our container won't
|
||||
# be visible until it has a child.
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
self._require_current_thought().on_llm_new_token(token, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self._require_current_thought().on_llm_end(response, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
self._require_current_thought().on_llm_error(error, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
color: Optional[str] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._require_current_thought().on_tool_end(
|
||||
output, color, observation_prefix, llm_prefix, **kwargs
|
||||
)
|
||||
self._complete_current_thought()
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
self._require_current_thought().on_tool_error(error, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
text: str,
|
||||
color: Optional[str] = None,
|
||||
end: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_agent_action(
|
||||
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
self._require_current_thought().on_agent_action(action, color, **kwargs)
|
||||
self._prune_old_thought_containers()
|
||||
|
||||
def on_agent_finish(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
if self._current_thought is not None:
|
||||
self._current_thought.complete(
|
||||
self._thought_labeler.get_final_agent_thought_label()
|
||||
)
|
||||
self._current_thought = None
|
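The handler above is meant to be passed to a chain or agent as an ordinary LangChain callback, pointed at a Streamlit container. A minimal usage sketch follows; the Streamlit widgets, the `agent` object, and the top-level import path are illustrative assumptions, not part of this diff:

# Illustrative sketch only: assumes a running Streamlit app, an existing
# LangChain agent/chain named `agent`, and that StreamlitCallbackHandler is
# re-exported from langchain_community.callbacks (import path assumed).
import streamlit as st
from langchain_community.callbacks import StreamlitCallbackHandler

prompt = st.text_input("Ask a question")
if prompt:
    # Each LLM "thought" is rendered as an expander inside this container.
    st_callback = StreamlitCallbackHandler(
        st.container(),
        max_thought_containers=4,          # older thoughts collapse into "History"
        expand_new_thoughts=True,
        collapse_completed_thoughts=True,
    )
    response = agent.run(prompt, callbacks=[st_callback])
    st.write(response)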
@@ -0,0 +1,18 @@
"""Tracers that record execution of LangChain runs."""

from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import (
    ConsoleCallbackHandler,
    FunctionCallbackHandler,
)

from langchain_community.callbacks.tracers.wandb import WandbTracer

__all__ = [
    "ConsoleCallbackHandler",
    "FunctionCallbackHandler",
    "LangChainTracer",
    "LangChainTracerV1",
    "WandbTracer",
]
|
138
libs/community/langchain_community/callbacks/tracers/comet.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||
|
||||
from langchain_core.tracers import BaseTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from comet_llm import Span
|
||||
from comet_llm.chains.chain import Chain
|
||||
|
||||
from langchain_community.callbacks.tracers.schemas import Run
|
||||
|
||||
|
||||
def _get_run_type(run: "Run") -> str:
|
||||
if isinstance(run.run_type, str):
|
||||
return run.run_type
|
||||
elif hasattr(run.run_type, "value"):
|
||||
return run.run_type.value
|
||||
else:
|
||||
return str(run.run_type)
|
||||
|
||||
|
||||
def import_comet_llm_api() -> SimpleNamespace:
|
||||
"""Import comet_llm api and raise an error if it is not installed."""
|
||||
try:
|
||||
from comet_llm import (
|
||||
experiment_info, # noqa: F401
|
||||
flush, # noqa: F401
|
||||
)
|
||||
from comet_llm.chains import api as chain_api # noqa: F401
|
||||
from comet_llm.chains import (
|
||||
chain, # noqa: F401
|
||||
span, # noqa: F401
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the CometTracer you need to have the "
|
||||
"`comet_llm>=2.0.0` python package installed. Please install it with"
|
||||
" `pip install -U comet_llm`"
|
||||
)
|
||||
return SimpleNamespace(
|
||||
chain=chain,
|
||||
span=span,
|
||||
chain_api=chain_api,
|
||||
experiment_info=experiment_info,
|
||||
flush=flush,
|
||||
)
|
||||
|
||||
|
||||
class CometTracer(BaseTracer):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._span_map: Dict["UUID", "Span"] = {}
|
||||
self._chains_map: Dict["UUID", "Chain"] = {}
|
||||
self._initialize_comet_modules()
|
||||
|
||||
def _initialize_comet_modules(self) -> None:
|
||||
comet_llm_api = import_comet_llm_api()
|
||||
self._chain: ModuleType = comet_llm_api.chain
|
||||
self._span: ModuleType = comet_llm_api.span
|
||||
self._chain_api: ModuleType = comet_llm_api.chain_api
|
||||
self._experiment_info: ModuleType = comet_llm_api.experiment_info
|
||||
self._flush: Callable[[], None] = comet_llm_api.flush
|
||||
|
||||
def _persist_run(self, run: "Run") -> None:
|
||||
chain_ = self._chains_map[run.id]
|
||||
chain_.set_outputs(outputs=run.outputs)
|
||||
self._chain_api.log_chain(chain_)
|
||||
|
||||
def _process_start_trace(self, run: "Run") -> None:
|
||||
if not run.parent_run_id:
|
||||
# This is the first run, which maps to a chain
|
||||
chain_: "Chain" = self._chain.Chain(
|
||||
inputs=run.inputs,
|
||||
metadata=None,
|
||||
experiment_info=self._experiment_info.get(),
|
||||
)
|
||||
self._chains_map[run.id] = chain_
|
||||
else:
|
||||
span: "Span" = self._span.Span(
|
||||
inputs=run.inputs,
|
||||
category=_get_run_type(run),
|
||||
metadata=run.extra,
|
||||
name=run.name,
|
||||
)
|
||||
span.__api__start__(self._chains_map[run.parent_run_id])
|
||||
self._chains_map[run.id] = self._chains_map[run.parent_run_id]
|
||||
self._span_map[run.id] = span
|
||||
|
||||
def _process_end_trace(self, run: "Run") -> None:
|
||||
if not run.parent_run_id:
|
||||
pass
|
||||
# Langchain will call _persist_run for us
|
||||
else:
|
||||
span = self._span_map[run.id]
|
||||
span.set_outputs(outputs=run.outputs)
|
||||
span.__api__end__()
|
||||
|
||||
def flush(self) -> None:
|
||||
self._flush()
|
||||
|
||||
def _on_llm_start(self, run: "Run") -> None:
|
||||
"""Process the LLM Run upon start."""
|
||||
self._process_start_trace(run)
|
||||
|
||||
def _on_llm_end(self, run: "Run") -> None:
|
||||
"""Process the LLM Run."""
|
||||
self._process_end_trace(run)
|
||||
|
||||
def _on_llm_error(self, run: "Run") -> None:
|
||||
"""Process the LLM Run upon error."""
|
||||
self._process_end_trace(run)
|
||||
|
||||
def _on_chain_start(self, run: "Run") -> None:
|
||||
"""Process the Chain Run upon start."""
|
||||
self._process_start_trace(run)
|
||||
|
||||
def _on_chain_end(self, run: "Run") -> None:
|
||||
"""Process the Chain Run."""
|
||||
self._process_end_trace(run)
|
||||
|
||||
def _on_chain_error(self, run: "Run") -> None:
|
||||
"""Process the Chain Run upon error."""
|
||||
self._process_end_trace(run)
|
||||
|
||||
def _on_tool_start(self, run: "Run") -> None:
|
||||
"""Process the Tool Run upon start."""
|
||||
self._process_start_trace(run)
|
||||
|
||||
def _on_tool_end(self, run: "Run") -> None:
|
||||
"""Process the Tool Run."""
|
||||
self._process_end_trace(run)
|
||||
|
||||
def _on_tool_error(self, run: "Run") -> None:
|
||||
"""Process the Tool Run upon error."""
|
||||
self._process_end_trace(run)
|
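Like the other tracers, CometTracer is attached as a callback: the top-level run is logged as a Comet chain and nested runs become spans. A minimal sketch, assuming `comet_llm>=2.0.0` is installed and configured and that `llm` is an existing LangChain model:

# Illustrative sketch: `llm` is assumed to be an existing LangChain LLM/chat model.
from langchain_community.callbacks.tracers.comet import CometTracer

tracer = CometTracer()
result = llm.invoke("Hello, world!", config={"callbacks": [tracer]})
tracer.flush()  # pushes any buffered chains/spans to Comet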
514
libs/community/langchain_community/callbacks/tracers/wandb.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""A Tracer Implementation that records activity to Weights & Biases."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from wandb import Settings as WBSettings
|
||||
from wandb.sdk.data_types.trace_tree import Span
|
||||
from wandb.sdk.lib.paths import StrPath
|
||||
from wandb.wandb_run import Run as WBRun
|
||||
|
||||
|
||||
PRINT_WARNINGS = True
|
||||
|
||||
|
||||
def _serialize_io(run_inputs: Optional[dict]) -> dict:
|
||||
if not run_inputs:
|
||||
return {}
|
||||
from google.protobuf.json_format import MessageToJson
|
||||
from google.protobuf.message import Message
|
||||
|
||||
serialized_inputs = {}
|
||||
for key, value in run_inputs.items():
|
||||
if isinstance(value, Message):
|
||||
serialized_inputs[key] = MessageToJson(value)
|
||||
elif key == "input_documents":
|
||||
serialized_inputs.update(
|
||||
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
|
||||
)
|
||||
else:
|
||||
serialized_inputs[key] = value
|
||||
return serialized_inputs
|
||||
|
||||
|
||||
class RunProcessor:
|
||||
"""Handles the conversion of a LangChain Runs into a WBTraceTree."""
|
||||
|
||||
def __init__(self, wandb_module: Any, trace_module: Any):
|
||||
self.wandb = wandb_module
|
||||
self.trace_tree = trace_module
|
||||
|
||||
def process_span(self, run: Run) -> Optional["Span"]:
|
||||
"""Converts a LangChain Run into a W&B Trace Span.
|
||||
:param run: The LangChain Run to convert.
|
||||
:return: The converted W&B Trace Span.
|
||||
"""
|
||||
try:
|
||||
span = self._convert_lc_run_to_wb_span(run)
|
||||
return span
|
||||
except Exception as e:
|
||||
if PRINT_WARNINGS:
|
||||
self.wandb.termwarn(
|
||||
f"Skipping trace saving - unable to safely convert LangChain Run "
|
||||
f"into W&B Trace due to: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _convert_run_to_wb_span(self, run: Run) -> "Span":
|
||||
"""Base utility to create a span from a run.
|
||||
:param run: The run to convert.
|
||||
:return: The converted Span.
|
||||
"""
|
||||
attributes = {**run.extra} if run.extra else {}
|
||||
attributes["execution_order"] = run.execution_order
|
||||
|
||||
return self.trace_tree.Span(
|
||||
span_id=str(run.id) if run.id is not None else None,
|
||||
name=run.name,
|
||||
start_time_ms=int(run.start_time.timestamp() * 1000),
|
||||
end_time_ms=int(run.end_time.timestamp() * 1000)
|
||||
if run.end_time is not None
|
||||
else None,
|
||||
status_code=self.trace_tree.StatusCode.SUCCESS
|
||||
if run.error is None
|
||||
else self.trace_tree.StatusCode.ERROR,
|
||||
status_message=run.error,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
|
||||
"""Converts a LangChain LLM Run into a W&B Trace Span.
|
||||
:param run: The LangChain LLM Run to convert.
|
||||
:return: The converted W&B Trace Span.
|
||||
"""
|
||||
base_span = self._convert_run_to_wb_span(run)
|
||||
if base_span.attributes is None:
|
||||
base_span.attributes = {}
|
||||
base_span.attributes["llm_output"] = (run.outputs or {}).get("llm_output", {})
|
||||
|
||||
base_span.results = [
|
||||
self.trace_tree.Result(
|
||||
inputs={"prompt": prompt},
|
||||
outputs={
|
||||
f"gen_{g_i}": gen["text"]
|
||||
for g_i, gen in enumerate(run.outputs["generations"][ndx])
|
||||
}
|
||||
if (
|
||||
run.outputs is not None
|
||||
and len(run.outputs["generations"]) > ndx
|
||||
and len(run.outputs["generations"][ndx]) > 0
|
||||
)
|
||||
else None,
|
||||
)
|
||||
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
|
||||
]
|
||||
base_span.span_kind = self.trace_tree.SpanKind.LLM
|
||||
|
||||
return base_span
|
||||
|
||||
def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
|
||||
"""Converts a LangChain Chain Run into a W&B Trace Span.
|
||||
:param run: The LangChain Chain Run to convert.
|
||||
:return: The converted W&B Trace Span.
|
||||
"""
|
||||
base_span = self._convert_run_to_wb_span(run)
|
||||
|
||||
base_span.results = [
|
||||
self.trace_tree.Result(
|
||||
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
|
||||
)
|
||||
]
|
||||
base_span.child_spans = [
|
||||
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
|
||||
]
|
||||
base_span.span_kind = (
|
||||
self.trace_tree.SpanKind.AGENT
|
||||
if "agent" in run.name.lower()
|
||||
else self.trace_tree.SpanKind.CHAIN
|
||||
)
|
||||
|
||||
return base_span
|
||||
|
||||
def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
|
||||
"""Converts a LangChain Tool Run into a W&B Trace Span.
|
||||
:param run: The LangChain Tool Run to convert.
|
||||
:return: The converted W&B Trace Span.
|
||||
"""
|
||||
base_span = self._convert_run_to_wb_span(run)
|
||||
base_span.results = [
|
||||
self.trace_tree.Result(
|
||||
inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
|
||||
)
|
||||
]
|
||||
base_span.child_spans = [
|
||||
self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
|
||||
]
|
||||
base_span.span_kind = self.trace_tree.SpanKind.TOOL
|
||||
|
||||
return base_span
|
||||
|
||||
def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
|
||||
"""Utility to convert any generic LangChain Run into a W&B Trace Span.
|
||||
:param run: The LangChain Run to convert.
|
||||
:return: The converted W&B Trace Span.
|
||||
"""
|
||||
if run.run_type == "llm":
|
||||
return self._convert_llm_run_to_wb_span(run)
|
||||
elif run.run_type == "chain":
|
||||
return self._convert_chain_run_to_wb_span(run)
|
||||
elif run.run_type == "tool":
|
||||
return self._convert_tool_run_to_wb_span(run)
|
||||
else:
|
||||
return self._convert_run_to_wb_span(run)
|
||||
|
||||
def process_model(self, run: Run) -> Optional[Dict[str, Any]]:
|
||||
"""Utility to process a run for wandb model_dict serialization.
|
||||
:param run: The run to process.
|
||||
:return: The converted model_dict to pass to WBTraceTree.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(run.json())
|
||||
processed = self.flatten_run(data)
|
||||
keep_keys = (
|
||||
"id",
|
||||
"name",
|
||||
"serialized",
|
||||
"inputs",
|
||||
"outputs",
|
||||
"parent_run_id",
|
||||
"execution_order",
|
||||
)
|
||||
processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)
|
||||
exact_keys, partial_keys = ("lc", "type"), ("api_key",)
|
||||
processed = self.modify_serialized_iterative(
|
||||
processed, exact_keys=exact_keys, partial_keys=partial_keys
|
||||
)
|
||||
output = self.build_tree(processed)
|
||||
return output
|
||||
except Exception as e:
|
||||
if PRINT_WARNINGS:
|
||||
self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
|
||||
return None
|
||||
|
||||
def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Utility to flatten a nest run object into a list of runs.
|
||||
:param run: The base run to flatten.
|
||||
:return: The flattened list of runs.
|
||||
"""
|
||||
|
||||
def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Utility to recursively flatten a list of child runs in a run.
|
||||
:param child_runs: The list of child runs to flatten.
|
||||
:return: The flattened list of runs.
|
||||
"""
|
||||
if child_runs is None:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for item in child_runs:
|
||||
child_runs = item.pop("child_runs", [])
|
||||
result.append(item)
|
||||
result.extend(flatten(child_runs))
|
||||
|
||||
return result
|
||||
|
||||
return flatten([run])
|
||||
|
||||
def truncate_run_iterative(
|
||||
self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Utility to truncate a list of runs dictionaries to only keep the specified
|
||||
keys in each run.
|
||||
:param runs: The list of runs to truncate.
|
||||
:param keep_keys: The keys to keep in each run.
|
||||
:return: The truncated list of runs.
|
||||
"""
|
||||
|
||||
def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Utility to truncate a single run dictionary to only keep the specified
|
||||
keys.
|
||||
:param run: The run dictionary to truncate.
|
||||
:return: The truncated run dictionary
|
||||
"""
|
||||
new_dict = {}
|
||||
for key in run:
|
||||
if key in keep_keys:
|
||||
new_dict[key] = run.get(key)
|
||||
return new_dict
|
||||
|
||||
return list(map(truncate_single, runs))
|
||||
|
||||
def modify_serialized_iterative(
|
||||
self,
|
||||
runs: List[Dict[str, Any]],
|
||||
exact_keys: Tuple[str, ...] = (),
|
||||
partial_keys: Tuple[str, ...] = (),
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Utility to modify the serialized field of a list of runs dictionaries.
|
||||
removes any keys that match the exact_keys and any keys that contain any of the
|
||||
partial_keys.
|
||||
recursively moves the dictionaries under the kwargs key to the top level.
|
||||
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
|
||||
visualize the run. promotes the "serialized" field to the top level.
|
||||
|
||||
:param runs: The list of runs to modify.
|
||||
:param exact_keys: A tuple of keys to remove from the serialized field.
|
||||
:param partial_keys: A tuple of partial keys to remove from the serialized
|
||||
field.
|
||||
:return: The modified list of runs.
|
||||
"""
|
||||
|
||||
def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Recursively removes exact and partial keys from a dictionary.
|
||||
:param obj: The dictionary to remove keys from.
|
||||
:return: The modified dictionary.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
obj = {
|
||||
k: v
|
||||
for k, v in obj.items()
|
||||
if k not in exact_keys
|
||||
and not any(partial in k for partial in partial_keys)
|
||||
}
|
||||
for k, v in obj.items():
|
||||
obj[k] = remove_exact_and_partial_keys(v)
|
||||
elif isinstance(obj, list):
|
||||
obj = [remove_exact_and_partial_keys(x) for x in obj]
|
||||
return obj
|
||||
|
||||
def handle_id_and_kwargs(
|
||||
obj: Dict[str, Any], root: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Recursively handles the id and kwargs fields of a dictionary.
|
||||
Changes the id field to a string "_kind" field that tells WBTraceTree how
to visualize the run. Recursively moves the dictionaries under the kwargs
key to the top level.
|
||||
:param obj: a run dictionary with id and kwargs fields.
|
||||
:param root: whether this is the root dictionary or the serialized
|
||||
dictionary.
|
||||
:return: The modified dictionary.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
if ("id" in obj or "name" in obj) and not root:
|
||||
_kind = obj.get("id")
|
||||
if not _kind:
|
||||
_kind = [obj.get("name")]
|
||||
obj["_kind"] = _kind[-1]
|
||||
obj.pop("id", None)
|
||||
obj.pop("name", None)
|
||||
if "kwargs" in obj:
|
||||
kwargs = obj.pop("kwargs")
|
||||
for k, v in kwargs.items():
|
||||
obj[k] = v
|
||||
for k, v in obj.items():
|
||||
obj[k] = handle_id_and_kwargs(v)
|
||||
elif isinstance(obj, list):
|
||||
obj = [handle_id_and_kwargs(x) for x in obj]
|
||||
return obj
|
||||
|
||||
def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transforms the serialized field of a run dictionary to be compatible
|
||||
with WBTraceTree.
|
||||
:param serialized: The serialized field of a run dictionary.
|
||||
:return: The transformed serialized field.
|
||||
"""
|
||||
serialized = handle_id_and_kwargs(serialized, root=True)
|
||||
serialized = remove_exact_and_partial_keys(serialized)
|
||||
return serialized
|
||||
|
||||
def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transforms a run dictionary to be compatible with WBTraceTree.
|
||||
:param run: The run dictionary to transform.
|
||||
:return: The transformed run dictionary.
|
||||
"""
|
||||
transformed_dict = transform_serialized(run)
|
||||
|
||||
serialized = transformed_dict.pop("serialized")
|
||||
for k, v in serialized.items():
|
||||
transformed_dict[k] = v
|
||||
|
||||
_kind = transformed_dict.get("_kind", None)
|
||||
name = transformed_dict.pop("name", None)
|
||||
exec_ord = transformed_dict.pop("execution_order", None)
|
||||
|
||||
if not name:
|
||||
name = _kind
|
||||
|
||||
output_dict = {
|
||||
f"{exec_ord}_{name}": transformed_dict,
|
||||
}
|
||||
return output_dict
|
||||
|
||||
return list(map(transform_run, runs))
|
||||
|
||||
def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Builds a nested dictionary from a list of runs.
|
||||
:param runs: The list of runs to build the tree from.
|
||||
:return: The nested dictionary representing the langchain Run in a tree
|
||||
structure compatible with WBTraceTree.
|
||||
"""
|
||||
id_to_data = {}
|
||||
child_to_parent = {}
|
||||
|
||||
for entity in runs:
|
||||
for key, data in entity.items():
|
||||
id_val = data.pop("id", None)
|
||||
parent_run_id = data.pop("parent_run_id", None)
|
||||
id_to_data[id_val] = {key: data}
|
||||
if parent_run_id:
|
||||
child_to_parent[id_val] = parent_run_id
|
||||
|
||||
for child_id, parent_id in child_to_parent.items():
|
||||
parent_dict = id_to_data[parent_id]
|
||||
parent_dict[next(iter(parent_dict))][
|
||||
next(iter(id_to_data[child_id]))
|
||||
] = id_to_data[child_id][next(iter(id_to_data[child_id]))]
|
||||
|
||||
root_dict = next(
|
||||
data for id_val, data in id_to_data.items() if id_val not in child_to_parent
|
||||
)
|
||||
|
||||
return root_dict
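Concretely, each entry produced by the preceding steps is a single-key dict keyed by `{execution_order}_{name}` and still carrying `id`/`parent_run_id`; build_tree nests children under their parents and returns the parentless root. A small hedged example of the resulting shape (ids and names are invented):

# Illustrative shape only (ids/names invented):
#   runs = [
#       {"1_AgentExecutor": {"id": "a", "parent_run_id": None, "inputs": {...}}},
#       {"2_LLMChain":      {"id": "b", "parent_run_id": "a",  "inputs": {...}}},
#   ]
#   build_tree(runs)
#   -> {"1_AgentExecutor": {"inputs": {...}, "2_LLMChain": {"inputs": {...}}}}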
|
||||
|
||||
|
||||
class WandbRunArgs(TypedDict):
|
||||
"""Arguments for the WandbTracer."""
|
||||
|
||||
job_type: Optional[str]
|
||||
dir: Optional[StrPath]
|
||||
config: Union[Dict, str, None]
|
||||
project: Optional[str]
|
||||
entity: Optional[str]
|
||||
reinit: Optional[bool]
|
||||
tags: Optional[Sequence]
|
||||
group: Optional[str]
|
||||
name: Optional[str]
|
||||
notes: Optional[str]
|
||||
magic: Optional[Union[dict, str, bool]]
|
||||
config_exclude_keys: Optional[List[str]]
|
||||
config_include_keys: Optional[List[str]]
|
||||
anonymous: Optional[str]
|
||||
mode: Optional[str]
|
||||
allow_val_change: Optional[bool]
|
||||
resume: Optional[Union[bool, str]]
|
||||
force: Optional[bool]
|
||||
tensorboard: Optional[bool]
|
||||
sync_tensorboard: Optional[bool]
|
||||
monitor_gym: Optional[bool]
|
||||
save_code: Optional[bool]
|
||||
id: Optional[str]
|
||||
settings: Union[WBSettings, Dict[str, Any], None]
|
||||
|
||||
|
||||
class WandbTracer(BaseTracer):
|
||||
"""Callback Handler that logs to Weights and Biases.
|
||||
|
||||
This handler will log the model architecture and run traces to Weights and Biases.
|
||||
This will ensure that all LangChain activity is logged to W&B.
|
||||
"""
|
||||
|
||||
_run: Optional[WBRun] = None
|
||||
_run_args: Optional[WandbRunArgs] = None
|
||||
|
||||
def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
|
||||
"""Initializes the WandbTracer.
|
||||
|
||||
Parameters:
|
||||
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
|
||||
provided, `wandb.init()` will be called with no arguments. Please
|
||||
refer to the `wandb.init` for more details.
|
||||
|
||||
To use W&B to monitor all LangChain activity, add this tracer like any other
|
||||
LangChain callback:
|
||||
```
|
||||
from wandb.integration.langchain import WandbTracer
|
||||
|
||||
tracer = WandbTracer()
|
||||
chain = LLMChain(llm, callbacks=[tracer])
|
||||
# ...end of notebook / script:
|
||||
tracer.finish()
|
||||
```
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import wandb
|
||||
from wandb.sdk.data_types import trace_tree
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import wandb python package."
|
||||
"Please install it with `pip install -U wandb`."
|
||||
) from e
|
||||
self._wandb = wandb
|
||||
self._trace_tree = trace_tree
|
||||
self._run_args = run_args
|
||||
self._ensure_run(should_print_url=(wandb.run is None))
|
||||
self.run_processor = RunProcessor(self._wandb, self._trace_tree)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Waits for all asynchronous processes to finish and data to upload.
|
||||
|
||||
Proxy for `wandb.finish()`.
|
||||
"""
|
||||
self._wandb.finish()
|
||||
|
||||
def _log_trace_from_run(self, run: Run) -> None:
|
||||
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
||||
self._ensure_run()
|
||||
|
||||
root_span = self.run_processor.process_span(run)
|
||||
model_dict = self.run_processor.process_model(run)
|
||||
|
||||
if root_span is None:
|
||||
return
|
||||
|
||||
model_trace = self._trace_tree.WBTraceTree(
|
||||
root_span=root_span,
|
||||
model_dict=model_dict,
|
||||
)
|
||||
if self._wandb.run is not None:
|
||||
self._wandb.run.log({"langchain_trace": model_trace})
|
||||
|
||||
def _ensure_run(self, should_print_url: bool = False) -> None:
|
||||
"""Ensures an active W&B run exists.
|
||||
|
||||
If not, will start a new run with the provided run_args.
|
||||
"""
|
||||
if self._wandb.run is None:
|
||||
run_args = self._run_args or {} # type: ignore
|
||||
run_args: dict = {**run_args} # type: ignore
|
||||
|
||||
if "settings" not in run_args: # type: ignore
|
||||
run_args["settings"] = {"silent": True} # type: ignore
|
||||
|
||||
self._wandb.init(**run_args)
|
||||
if self._wandb.run is not None:
|
||||
if should_print_url:
|
||||
run_url = self._wandb.run.settings.run_url
|
||||
self._wandb.termlog(
|
||||
f"Streaming LangChain activity to W&B at {run_url}\n"
|
||||
"`WandbTracer` is currently in beta.\n"
|
||||
"Please report any issues to "
|
||||
"https://github.com/wandb/wandb/issues with the tag "
|
||||
"`langchain`."
|
||||
)
|
||||
|
||||
self._wandb.run._label(repo="langchain")
|
||||
|
||||
def _persist_run(self, run: "Run") -> None:
|
||||
"""Persist a run."""
|
||||
self._log_trace_from_run(run)
|
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
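For reference, the helper above maps LangChain message types onto OpenAI-style role dictionaries; two examples of the values it would return (the message classes are already imported at the top of this module):

# Illustrative results of _convert_message_to_dict (defined above):
_convert_message_to_dict(HumanMessage(content="Hi"))
# -> {"role": "user", "content": "Hi"}

_convert_message_to_dict(
    AIMessage(content="", additional_kwargs={"function_call": {"name": "f", "arguments": "{}"}})
)
# -> {"role": "assistant", "content": None, "function_call": {"name": "f", "arguments": "{}"}}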
|
||||
|
||||
|
||||
class TrubricsCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
Callback handler for Trubrics.
|
||||
|
||||
Args:
|
||||
project: a trubrics project, default project is "default"
|
||||
email: a trubrics account email; can also be set via environment variables
|
||||
password: a trubrics account password; can also be set via environment variables
|
||||
**kwargs: all other kwargs are parsed and set to trubrics prompt variables,
|
||||
or added to the `metadata` dict
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str = "default",
|
||||
email: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
try:
|
||||
from trubrics import Trubrics
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The TrubricsCallbackHandler requires installation of "
|
||||
"the trubrics package. "
|
||||
"Please install it with `pip install trubrics`."
|
||||
)
|
||||
|
||||
self.trubrics = Trubrics(
|
||||
project=project,
|
||||
email=email or os.environ["TRUBRICS_EMAIL"],
|
||||
password=password or os.environ["TRUBRICS_PASSWORD"],
|
||||
)
|
||||
self.config_model: dict = {}
|
||||
self.prompt: Optional[str] = None
|
||||
self.messages: Optional[list] = None
|
||||
self.trubrics_kwargs: Optional[dict] = kwargs if kwargs else None
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
self.prompt = prompts[0]
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.messages = [_convert_message_to_dict(message) for message in messages[0]]
|
||||
self.prompt = self.messages[-1]["content"]
|
||||
|
||||
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
|
||||
tags = ["langchain"]
|
||||
user_id = None
|
||||
session_id = None
|
||||
metadata: dict = {"langchain_run_id": run_id}
|
||||
if self.messages:
|
||||
metadata["messages"] = self.messages
|
||||
if self.trubrics_kwargs:
|
||||
if self.trubrics_kwargs.get("tags"):
|
||||
tags.append(*self.trubrics_kwargs.pop("tags"))
|
||||
user_id = self.trubrics_kwargs.pop("user_id", None)
|
||||
session_id = self.trubrics_kwargs.pop("session_id", None)
|
||||
metadata.update(self.trubrics_kwargs)
|
||||
|
||||
for generation in response.generations:
|
||||
self.trubrics.log_prompt(
|
||||
config_model={
|
||||
"model": response.llm_output.get("model_name")
|
||||
if response.llm_output
|
||||
else "NA"
|
||||
},
|
||||
prompt=self.prompt,
|
||||
generation=generation[0].text,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
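A minimal sketch of wiring the handler into a chat model; the import paths and the ChatOpenAI model are assumptions for illustration, and TRUBRICS_EMAIL / TRUBRICS_PASSWORD are expected in the environment:

# Illustrative only: import paths and model choice are assumptions, and
# TRUBRICS_EMAIL / TRUBRICS_PASSWORD must be set in the environment.
from langchain_community.callbacks import TrubricsCallbackHandler  # path assumed
from langchain_community.chat_models import ChatOpenAI

llm = ChatOpenAI(callbacks=[TrubricsCallbackHandler(project="default", tags=["demo"])])
llm.invoke("Write a haiku about callbacks.")
# Each generation is logged to the Trubrics project via trubrics.log_prompt().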
258
libs/community/langchain_community/callbacks/utils.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Tuple, Union
|
||||
|
||||
|
||||
def import_spacy() -> Any:
|
||||
"""Import the spacy python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import spacy
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"This callback manager requires the `spacy` python "
|
||||
"package installed. Please install it with `pip install spacy`"
|
||||
)
|
||||
return spacy
|
||||
|
||||
|
||||
def import_pandas() -> Any:
|
||||
"""Import the pandas python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import pandas
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"This callback manager requires the `pandas` python "
|
||||
"package installed. Please install it with `pip install pandas`"
|
||||
)
|
||||
return pandas
|
||||
|
||||
|
||||
def import_textstat() -> Any:
|
||||
"""Import the textstat python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import textstat
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"This callback manager requires the `textstat` python "
|
||||
"package installed. Please install it with `pip install textstat`"
|
||||
)
|
||||
return textstat
|
||||
|
||||
|
||||
def _flatten_dict(
|
||||
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
|
||||
) -> Iterable[Tuple[str, Any]]:
|
||||
"""
|
||||
Generator that yields flattened items from a nested dictionary.
|
||||
|
||||
Parameters:
|
||||
nested_dict (dict): The nested dictionary to flatten.
|
||||
parent_key (str): The prefix to prepend to the keys of the flattened dict.
|
||||
sep (str): The separator to use between the parent key and the key of the
|
||||
flattened dictionary.
|
||||
|
||||
Yields:
|
||||
(str, any): A key-value pair from the flattened dictionary.
|
||||
"""
|
||||
for key, value in nested_dict.items():
|
||||
new_key = parent_key + sep + key if parent_key else key
|
||||
if isinstance(value, dict):
|
||||
yield from _flatten_dict(value, new_key, sep)
|
||||
else:
|
||||
yield new_key, value
|
||||
|
||||
|
||||
def flatten_dict(
|
||||
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
|
||||
) -> Dict[str, Any]:
|
||||
"""Flattens a nested dictionary into a flat dictionary.
|
||||
|
||||
Parameters:
|
||||
nested_dict (dict): The nested dictionary to flatten.
|
||||
parent_key (str): The prefix to prepend to the keys of the flattened dict.
|
||||
sep (str): The separator to use between the parent key and the key of the
|
||||
flattened dictionary.
|
||||
|
||||
Returns:
|
||||
(dict): A flat dictionary.
|
||||
|
||||
"""
|
||||
flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
|
||||
return flat_dict
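For example, with the default "_" separator a two-level dictionary is flattened into underscore-joined keys:

# Example of flatten_dict (defined above) with the default "_" separator:
flatten_dict({"model": {"name": "gpt", "params": {"temperature": 0.2}}})
# -> {"model_name": "gpt", "model_params_temperature": 0.2}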
|
||||
|
||||
|
||||
def hash_string(s: str) -> str:
|
||||
"""Hash a string using sha1.
|
||||
|
||||
Parameters:
|
||||
s (str): The string to hash.
|
||||
|
||||
Returns:
|
||||
(str): The hashed string.
|
||||
"""
|
||||
return hashlib.sha1(s.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def load_json(json_path: Union[str, Path]) -> str:
|
||||
"""Load json file to a string.
|
||||
|
||||
Parameters:
|
||||
json_path (str): The path to the json file.
|
||||
|
||||
Returns:
|
||||
(str): The string representation of the json file.
|
||||
"""
|
||||
with open(json_path, "r") as f:
|
||||
data = f.read()
|
||||
return data
|
||||
|
||||
|
||||
class BaseMetadataCallbackHandler:
|
||||
"""This class handles the metadata and associated function states for callbacks.
|
||||
|
||||
Attributes:
|
||||
step (int): The current step.
|
||||
starts (int): The number of times the start method has been called.
|
||||
ends (int): The number of times the end method has been called.
|
||||
errors (int): The number of times the error method has been called.
|
||||
text_ctr (int): The number of times the text method has been called.
|
||||
ignore_llm_ (bool): Whether to ignore llm callbacks.
|
||||
ignore_chain_ (bool): Whether to ignore chain callbacks.
|
||||
ignore_agent_ (bool): Whether to ignore agent callbacks.
|
||||
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
|
||||
always_verbose_ (bool): Whether to always be verbose.
|
||||
chain_starts (int): The number of times the chain start method has been called.
|
||||
chain_ends (int): The number of times the chain end method has been called.
|
||||
llm_starts (int): The number of times the llm start method has been called.
|
||||
llm_ends (int): The number of times the llm end method has been called.
|
||||
llm_streams (int): The number of times the llm new token method has been called.
|
||||
tool_starts (int): The number of times the tool start method has been called.
|
||||
tool_ends (int): The number of times the tool end method has been called.
|
||||
agent_ends (int): The number of times the agent end method has been called.
|
||||
on_llm_start_records (list): A list of records of the on_llm_start method.
|
||||
on_llm_token_records (list): A list of records of the on_llm_new_token method.
|
||||
on_llm_end_records (list): A list of records of the on_llm_end method.
|
||||
on_chain_start_records (list): A list of records of the on_chain_start method.
|
||||
on_chain_end_records (list): A list of records of the on_chain_end method.
|
||||
on_tool_start_records (list): A list of records of the on_tool_start method.
|
||||
on_tool_end_records (list): A list of records of the on_tool_end method.
|
||||
on_agent_finish_records (list): A list of records of the on_agent_end method.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.step = 0
|
||||
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
self.text_ctr = 0
|
||||
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.ignore_retriever_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
self.chain_ends = 0
|
||||
|
||||
self.llm_starts = 0
|
||||
self.llm_ends = 0
|
||||
self.llm_streams = 0
|
||||
|
||||
self.tool_starts = 0
|
||||
self.tool_ends = 0
|
||||
|
||||
self.agent_ends = 0
|
||||
|
||||
self.on_llm_start_records: list = []
|
||||
self.on_llm_token_records: list = []
|
||||
self.on_llm_end_records: list = []
|
||||
|
||||
self.on_chain_start_records: list = []
|
||||
self.on_chain_end_records: list = []
|
||||
|
||||
self.on_tool_start_records: list = []
|
||||
self.on_tool_end_records: list = []
|
||||
|
||||
self.on_text_records: list = []
|
||||
self.on_agent_finish_records: list = []
|
||||
self.on_agent_action_records: list = []
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return self.always_verbose_
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
def get_custom_callback_meta(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"step": self.step,
|
||||
"starts": self.starts,
|
||||
"ends": self.ends,
|
||||
"errors": self.errors,
|
||||
"text_ctr": self.text_ctr,
|
||||
"chain_starts": self.chain_starts,
|
||||
"chain_ends": self.chain_ends,
|
||||
"llm_starts": self.llm_starts,
|
||||
"llm_ends": self.llm_ends,
|
||||
"llm_streams": self.llm_streams,
|
||||
"tool_starts": self.tool_starts,
|
||||
"tool_ends": self.tool_ends,
|
||||
"agent_ends": self.agent_ends,
|
||||
}
|
||||
|
||||
def reset_callback_meta(self) -> None:
|
||||
"""Reset the callback metadata."""
|
||||
self.step = 0
|
||||
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
self.text_ctr = 0
|
||||
|
||||
self.ignore_llm_ = False
|
||||
self.ignore_chain_ = False
|
||||
self.ignore_agent_ = False
|
||||
self.always_verbose_ = False
|
||||
|
||||
self.chain_starts = 0
|
||||
self.chain_ends = 0
|
||||
|
||||
self.llm_starts = 0
|
||||
self.llm_ends = 0
|
||||
self.llm_streams = 0
|
||||
|
||||
self.tool_starts = 0
|
||||
self.tool_ends = 0
|
||||
|
||||
self.agent_ends = 0
|
||||
|
||||
self.on_llm_start_records = []
|
||||
self.on_llm_token_records = []
|
||||
self.on_llm_end_records = []
|
||||
|
||||
self.on_chain_start_records = []
|
||||
self.on_chain_end_records = []
|
||||
|
||||
self.on_tool_start_records = []
|
||||
self.on_tool_end_records = []
|
||||
|
||||
self.on_text_records = []
|
||||
self.on_agent_finish_records = []
|
||||
self.on_agent_action_records = []
|
||||
return None
|
587
libs/community/langchain_community/callbacks/wandb_callback.py
Normal file
@@ -0,0 +1,587 @@
|
||||
import json
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
|
||||
def import_wandb() -> Any:
|
||||
"""Import the wandb python package and raise an error if it is not installed."""
|
||||
try:
|
||||
import wandb # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the wandb callback manager you need to have the `wandb` python "
|
||||
"package installed. Please install it with `pip install wandb`"
|
||||
)
|
||||
return wandb
|
||||
|
||||
|
||||
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
|
||||
"""Load json file to a dictionary.
|
||||
|
||||
Parameters:
|
||||
json_path (str): The path to the json file.
|
||||
|
||||
Returns:
|
||||
(dict): The dictionary representation of the json file.
|
||||
"""
|
||||
with open(json_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def analyze_text(
|
||||
text: str,
|
||||
complexity_metrics: bool = True,
|
||||
visualize: bool = True,
|
||||
nlp: Any = None,
|
||||
output_dir: Optional[Union[str, Path]] = None,
|
||||
) -> dict:
|
||||
"""Analyze text using textstat and spacy.
|
||||
|
||||
Parameters:
|
||||
text (str): The text to analyze.
|
||||
complexity_metrics (bool): Whether to compute complexity metrics.
|
||||
visualize (bool): Whether to visualize the text.
|
||||
nlp (spacy.lang): The spacy language model to use for visualization.
|
||||
output_dir (str): The directory to save the visualization files to.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the complexity metrics and visualization
|
||||
files serialized in a wandb.Html element.
|
||||
"""
|
||||
resp = {}
|
||||
textstat = import_textstat()
|
||||
wandb = import_wandb()
|
||||
spacy = import_spacy()
|
||||
if complexity_metrics:
|
||||
text_complexity_metrics = {
|
||||
"flesch_reading_ease": textstat.flesch_reading_ease(text),
|
||||
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
|
||||
"smog_index": textstat.smog_index(text),
|
||||
"coleman_liau_index": textstat.coleman_liau_index(text),
|
||||
"automated_readability_index": textstat.automated_readability_index(text),
|
||||
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
|
||||
"difficult_words": textstat.difficult_words(text),
|
||||
"linsear_write_formula": textstat.linsear_write_formula(text),
|
||||
"gunning_fog": textstat.gunning_fog(text),
|
||||
"text_standard": textstat.text_standard(text),
|
||||
"fernandez_huerta": textstat.fernandez_huerta(text),
|
||||
"szigriszt_pazos": textstat.szigriszt_pazos(text),
|
||||
"gutierrez_polini": textstat.gutierrez_polini(text),
|
||||
"crawford": textstat.crawford(text),
|
||||
"gulpease_index": textstat.gulpease_index(text),
|
||||
"osman": textstat.osman(text),
|
||||
}
|
||||
resp.update(text_complexity_metrics)
|
||||
|
||||
if visualize and nlp and output_dir is not None:
|
||||
doc = nlp(text)
|
||||
|
||||
dep_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="dep", jupyter=False, page=True
|
||||
)
|
||||
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
|
||||
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
||||
|
||||
ent_out = spacy.displacy.render( # type: ignore
|
||||
doc, style="ent", jupyter=False, page=True
|
||||
)
|
||||
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
|
||||
ent_output_path.open("w", encoding="utf-8").write(ent_out)
|
||||
|
||||
text_visualizations = {
|
||||
"dependency_tree": wandb.Html(str(dep_output_path)),
|
||||
"entities": wandb.Html(str(ent_output_path)),
|
||||
}
|
||||
resp.update(text_visualizations)
|
||||
|
||||
return resp
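For example, with visualization disabled the helper returns only the flat dict of textstat readability metrics (textstat, wandb, and spacy must still be importable, since the helper imports all three up front):

# Illustrative call; exact metric values depend on the text.
metrics = analyze_text(
    "LangChain callbacks make it easy to instrument LLM applications.",
    complexity_metrics=True,
    visualize=False,
)
# metrics -> {"flesch_reading_ease": ..., "flesch_kincaid_grade": ..., "smog_index": ..., ...}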
|
||||
|
||||
|
||||
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
|
||||
"""Construct an html element from a prompt and a generation.
|
||||
|
||||
Parameters:
|
||||
prompt (str): The prompt.
|
||||
generation (str): The generation.
|
||||
|
||||
Returns:
|
||||
(wandb.Html): The html element."""
|
||||
wandb = import_wandb()
|
||||
formatted_prompt = prompt.replace("\n", "<br>")
|
||||
formatted_generation = generation.replace("\n", "<br>")
|
||||
|
||||
return wandb.Html(
|
||||
f"""
|
||||
<p style="color:black;">{formatted_prompt}:</p>
|
||||
<blockquote>
|
||||
<p style="color:green;">
|
||||
{formatted_generation}
|
||||
</p>
|
||||
</blockquote>
|
||||
""",
|
||||
inject=False,
|
||||
)
|
||||
|
||||
|
||||
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Weights and Biases.
|
||||
|
||||
Parameters:
|
||||
job_type (str): The type of job.
|
||||
project (str): The project to log to.
|
||||
entity (str): The entity to log to.
|
||||
tags (list): The tags to log.
|
||||
group (str): The group to log to.
|
||||
name (str): The name of the run.
|
||||
notes (str): The notes to log.
|
||||
visualize (bool): Whether to visualize the run.
|
||||
complexity_metrics (bool): Whether to log complexity metrics.
|
||||
stream_logs (bool): Whether to stream callback actions to W&B
|
||||
|
||||
This handler formats the input of each callback method with metadata about the
state of the LLM run, appends the result both to the corresponding
{method}_records list and to the action records, and then logs it to Weights
and Biases using the run.log() method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job_type: Optional[str] = None,
|
||||
project: Optional[str] = "langchain_callback_demo",
|
||||
entity: Optional[str] = None,
|
||||
tags: Optional[Sequence] = None,
|
||||
group: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
notes: Optional[str] = None,
|
||||
visualize: bool = False,
|
||||
complexity_metrics: bool = False,
|
||||
stream_logs: bool = False,
|
||||
) -> None:
|
||||
"""Initialize callback handler."""
|
||||
|
||||
wandb = import_wandb()
|
||||
import_pandas()
|
||||
import_textstat()
|
||||
spacy = import_spacy()
|
||||
super().__init__()
|
||||
|
||||
self.job_type = job_type
|
||||
self.project = project
|
||||
self.entity = entity
|
||||
self.tags = tags
|
||||
self.group = group
|
||||
self.name = name
|
||||
self.notes = notes
|
||||
self.visualize = visualize
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.stream_logs = stream_logs
|
||||
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore
|
||||
job_type=self.job_type,
|
||||
project=self.project,
|
||||
entity=self.entity,
|
||||
tags=self.tags,
|
||||
group=self.group,
|
||||
name=self.name,
|
||||
notes=self.notes,
|
||||
)
|
||||
warning = (
|
||||
"DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
|
||||
"of the `WandbTracer`. Please update your code to use the `WandbTracer` "
|
||||
"instead."
|
||||
)
|
||||
wandb.termwarn(
|
||||
warning,
|
||||
repeat=False,
|
||||
)
|
||||
self.callback_columns: list = []
|
||||
self.action_records: list = []
|
||||
self.complexity_metrics = complexity_metrics
|
||||
self.visualize = visualize
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
def _init_resp(self) -> Dict:
|
||||
return {k: None for k in self.callback_columns}
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts."""
|
||||
self.step += 1
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for prompt in prompts:
|
||||
prompt_resp = deepcopy(resp)
|
||||
prompt_resp["prompts"] = prompt
|
||||
self.on_llm_start_records.append(prompt_resp)
|
||||
self.action_records.append(prompt_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(prompt_resp)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.step += 1
|
||||
self.llm_streams += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_new_token", "token": token})
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
self.on_llm_token_records.append(resp)
|
||||
self.action_records.append(resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(resp)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.step += 1
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_llm_end"})
|
||||
resp.update(flatten_dict(response.llm_output or {}))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
generation_resp = deepcopy(resp)
|
||||
generation_resp.update(flatten_dict(generation.dict()))
|
||||
generation_resp.update(
|
||||
analyze_text(
|
||||
generation.text,
|
||||
complexity_metrics=self.complexity_metrics,
|
||||
visualize=self.visualize,
|
||||
nlp=self.nlp,
|
||||
output_dir=self.temp_dir.name,
|
||||
)
|
||||
)
|
||||
self.on_llm_end_records.append(generation_resp)
|
||||
self.action_records.append(generation_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(generation_resp)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.step += 1
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.step += 1
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
resp = self._init_resp()
|
||||
resp.update({"action": "on_chain_start"})
|
||||
resp.update(flatten_dict(serialized))
|
||||
resp.update(self.get_custom_callback_meta())
|
||||
|
||||
chain_input = inputs["input"]
|
||||
|
||||
if isinstance(chain_input, str):
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp["input"] = chain_input
|
||||
self.on_chain_start_records.append(input_resp)
|
||||
self.action_records.append(input_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(input_resp)
|
||||
elif isinstance(chain_input, list):
|
||||
for inp in chain_input:
|
||||
input_resp = deepcopy(resp)
|
||||
input_resp.update(inp)
|
||||
self.on_chain_start_records.append(input_resp)
|
||||
self.action_records.append(input_resp)
|
||||
if self.stream_logs:
|
||||
self.run.log(input_resp)
|
||||
else:
|
||||
raise ValueError("Unexpected data format provided!")
|
||||
|
||||

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
        resp.update(self.get_custom_callback_meta())

        self.on_chain_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self.on_tool_start_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self.on_tool_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when agent is ending."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self.on_text_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_finish_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())
        self.on_agent_action_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_input_prompts_df = (
            on_llm_start_records_df[["step", "prompts", "name"]]
            .dropna(axis=1)
            .rename({"step": "prompt_step"}, axis=1)
        )
        complexity_metrics_columns = []
        visualizations_columns = []

        if self.complexity_metrics:
            complexity_metrics_columns = [
                "flesch_reading_ease",
                "flesch_kincaid_grade",
                "smog_index",
                "coleman_liau_index",
                "automated_readability_index",
                "dale_chall_readability_score",
                "difficult_words",
                "linsear_write_formula",
                "gunning_fog",
                "text_standard",
                "fernandez_huerta",
                "szigriszt_pazos",
                "gutierrez_polini",
                "crawford",
                "gulpease_index",
                "osman",
            ]

        if self.visualize:
            visualizations_columns = ["dependency_tree", "entities"]

        llm_outputs_df = (
            on_llm_end_records_df[
                [
                    "step",
                    "text",
                    "token_usage_total_tokens",
                    "token_usage_prompt_tokens",
                    "token_usage_completion_tokens",
                ]
                + complexity_metrics_columns
                + visualizations_columns
            ]
            .dropna(axis=1)
            .rename({"step": "output_step", "text": "output"}, axis=1)
        )
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        session_analysis_df["chat_html"] = session_analysis_df[
            ["prompts", "output"]
        ].apply(
            lambda row: construct_html_from_prompt_and_generation(
                row["prompts"], row["output"]
            ),
            axis=1,
        )
        return session_analysis_df
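
    # Editor's note (illustrative, not part of the original file): the resulting
    # dataframe pairs prompts and generations positionally (pd.concat along
    # axis=1), so a typical row carries prompt_step, prompts, name, output_step,
    # output, the token_usage_* counts, any enabled complexity-metric columns,
    # and a rendered chat_html cell for display in the W&B UI.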

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
        job_type: Optional[str] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: Optional[bool] = None,
        complexity_metrics: Optional[bool] = None,
    ) -> None:
        """Flush the tracker and reset the session.

        Args:
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session.
            finish: Whether to finish the run.
            job_type: The job type.
            project: The project.
            entity: The entity.
            tags: The tags.
            group: The group.
            name: The name.
            notes: The notes.
            visualize: Whether to visualize.
            complexity_metrics: Whether to compute complexity metrics.

        Returns:
            None
        """
        pd = import_pandas()
        wandb = import_wandb()
        action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
        session_analysis_table = wandb.Table(
            dataframe=self._create_session_analysis_df()
        )
        self.run.log(
            {
                "action_records": action_records_table,
                "session_analysis": session_analysis_table,
            }
        )

        if langchain_asset:
            langchain_asset_path = Path(self.temp_dir.name, "model.json")
            model_artifact = wandb.Artifact(name="model", type="model")
            model_artifact.add(action_records_table, name="action_records")
            model_artifact.add(session_analysis_table, name="session_analysis")
            try:
                langchain_asset.save(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except ValueError:
                langchain_asset.save_agent(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except NotImplementedError as e:
                print("Could not save model.")
                print(repr(e))
                pass
            self.run.log_artifact(model_artifact)

        if finish or reset:
            self.run.finish()
            self.temp_dir.cleanup()
            self.reset_callback_meta()
        if reset:
            self.__init__(  # type: ignore
                job_type=job_type if job_type else self.job_type,
                project=project if project else self.project,
                entity=entity if entity else self.entity,
                tags=tags if tags else self.tags,
                group=group if group else self.group,
                name=name if name else self.name,
                notes=notes if notes else self.notes,
                visualize=visualize if visualize else self.visualize,
                complexity_metrics=complexity_metrics
                if complexity_metrics
                else self.complexity_metrics,
            )
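
For orientation, a minimal usage sketch of the W&B handler above (an editor's illustration, not part of the diff; it assumes WandbCallbackHandler is re-exported from langchain_community.callbacks like the other handlers, that the constructor accepts the same project/name/complexity_metrics parameters that flush_tracker re-initializes with, and that wandb and OpenAI credentials are configured):

from langchain_community.callbacks import WandbCallbackHandler
from langchain_community.llms import OpenAI

# Records every callback event in memory and, if stream_logs is enabled,
# streams each record to W&B as it happens.
wandb_handler = WandbCallbackHandler(project="langchain-demo", complexity_metrics=True)
llm = OpenAI(temperature=0, callbacks=[wandb_handler])
llm.invoke("Explain gradient descent in one sentence.")

# Log the accumulated action records and session-analysis table, attach the
# serialized LLM as a W&B artifact, and start a fresh session.
wandb_handler.flush_tracker(langchain_asset=llm, reset=True)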
192
libs/community/langchain_community/callbacks/whylabs_callback.py
Normal file
@@ -0,0 +1,192 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.utils import get_from_env

if TYPE_CHECKING:
    from whylogs.api.logger.logger import Logger

diagnostic_logger = logging.getLogger(__name__)


def import_langkit(
    sentiment: bool = False,
    toxicity: bool = False,
    themes: bool = False,
) -> Any:
    """Import the langkit python package and raise an error if it is not installed.

    Args:
        sentiment: Whether to import the langkit.sentiment module. Defaults to False.
        toxicity: Whether to import the langkit.toxicity module. Defaults to False.
        themes: Whether to import the langkit.themes module. Defaults to False.

    Returns:
        The imported langkit module.
    """
    try:
        import langkit  # noqa: F401
        import langkit.regexes  # noqa: F401
        import langkit.textstat  # noqa: F401

        if sentiment:
            import langkit.sentiment  # noqa: F401
        if toxicity:
            import langkit.toxicity  # noqa: F401
        if themes:
            import langkit.themes  # noqa: F401
    except ImportError:
        raise ImportError(
            "To use the whylabs callback manager you need to have the `langkit` python "
            "package installed. Please install it with `pip install langkit`."
        )
    return langkit

class WhyLabsCallbackHandler(BaseCallbackHandler):
    """
    Callback Handler for logging to WhyLabs. This callback handler utilizes
    `langkit` to extract features from the prompts & responses when interacting with
    an LLM. These features can be used to guardrail, evaluate, and observe interactions
    over time to detect issues relating to hallucinations, prompt engineering,
    or output validation. LangKit is an LLM monitoring toolkit developed by WhyLabs.

    Here are some examples of what can be monitored with LangKit:
    * Text Quality
      - readability score
      - complexity and grade scores
    * Text Relevance
      - Similarity scores between prompt/responses
      - Similarity scores against user-defined themes
      - Topic classification
    * Security and Privacy
      - patterns - count of strings matching a user-defined regex pattern group
      - jailbreaks - similarity scores with respect to known jailbreak attempts
      - prompt injection - similarity scores with respect to known prompt attacks
      - refusals - similarity scores with respect to known LLM refusal responses
    * Sentiment and Toxicity
      - sentiment analysis
      - toxicity analysis

    For more information, see https://docs.whylabs.ai/docs/language-model-monitoring
    or check out the LangKit repo here: https://github.com/whylabs/langkit

    ---
    Args:
        api_key (Optional[str]): WhyLabs API key. Optional because the preferred
            way to specify the API key is with environment variable
            WHYLABS_API_KEY.
        org_id (Optional[str]): WhyLabs organization id to write profiles to.
            Optional because the preferred way to specify the organization id is
            with environment variable WHYLABS_DEFAULT_ORG_ID.
        dataset_id (Optional[str]): WhyLabs dataset id to write profiles to.
            Optional because the preferred way to specify the dataset id is
            with environment variable WHYLABS_DEFAULT_DATASET_ID.
        sentiment (bool): Whether to enable sentiment analysis. Defaults to False.
        toxicity (bool): Whether to enable toxicity analysis. Defaults to False.
        themes (bool): Whether to enable theme analysis. Defaults to False.
    """
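
    # Editor's note (illustrative, not part of the original file): instances are
    # normally built via WhyLabsCallbackHandler.from_params(...) below, which
    # constructs the langkit callback handler and the rolling whylogs logger
    # that this __init__ expects, rather than by calling __init__ directly.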

    def __init__(self, logger: Logger, handler: Any):
        """Initiate the rolling logger."""
        super().__init__()
        if hasattr(handler, "init"):
            handler.init(self)
        if hasattr(handler, "_get_callbacks"):
            self._callbacks = handler._get_callbacks()
        else:
            self._callbacks = dict()
            diagnostic_logger.warning("initialized handler without callbacks.")
        self._logger = logger

    def flush(self) -> None:
        """Explicitly write current profile if using a rolling logger."""
        if self._logger and hasattr(self._logger, "_do_rollover"):
            self._logger._do_rollover()
            diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")

    def close(self) -> None:
        """Close any loggers to allow writing out of any profiles before exiting."""
        if self._logger and hasattr(self._logger, "close"):
            self._logger.close()
            diagnostic_logger.info("Closing WhyLabs logger, see you next time!")

    def __enter__(self) -> WhyLabsCallbackHandler:
        return self

    def __exit__(
        self, exception_type: Any, exception_value: Any, traceback: Any
    ) -> None:
        self.close()

    @classmethod
    def from_params(
        cls,
        *,
        api_key: Optional[str] = None,
        org_id: Optional[str] = None,
        dataset_id: Optional[str] = None,
        sentiment: bool = False,
        toxicity: bool = False,
        themes: bool = False,
        logger: Optional[Logger] = None,
    ) -> WhyLabsCallbackHandler:
        """Instantiate whylogs Logger from params.

        Args:
            api_key (Optional[str]): WhyLabs API key. Optional because the preferred
                way to specify the API key is with environment variable
                WHYLABS_API_KEY.
            org_id (Optional[str]): WhyLabs organization id to write profiles to.
                If not set must be specified in environment variable
                WHYLABS_DEFAULT_ORG_ID.
            dataset_id (Optional[str]): The model or dataset this callback is gathering
                telemetry for. If not set must be specified in environment variable
                WHYLABS_DEFAULT_DATASET_ID.
            sentiment (bool): If True will initialize a model to perform
                sentiment analysis compound score. Defaults to False and will not gather
                this metric.
            toxicity (bool): If True will initialize a model to score
                toxicity. Defaults to False and will not gather this metric.
            themes (bool): If True will initialize a model to calculate
                distance to configured themes. Defaults to False and will not gather
                this metric.
            logger (Optional[Logger]): If specified will bind the configured logger as
                the telemetry gathering agent. Defaults to LangKit schema with periodic
                WhyLabs writer.
        """
        # langkit library will import necessary whylogs libraries
        import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)

        import whylogs as why
        from langkit.callback_handler import get_callback_instance
        from whylogs.api.writer.whylabs import WhyLabsWriter
        from whylogs.experimental.core.udf_schema import udf_schema

        if logger is None:
            api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
            org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
            dataset_id = dataset_id or get_from_env(
                "dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
            )
            whylabs_writer = WhyLabsWriter(
                api_key=api_key, org_id=org_id, dataset_id=dataset_id
            )

            whylabs_logger = why.logger(
                mode="rolling", interval=5, when="M", schema=udf_schema()
            )

            whylabs_logger.append_writer(writer=whylabs_writer)
        else:
            diagnostic_logger.info(f"Using passed in whylogs logger {logger}")
            whylabs_logger = logger

        callback_handler_cls = get_callback_instance(logger=whylabs_logger, impl=cls)
        diagnostic_logger.info(
            "Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝"
        )
        return callback_handler_cls
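
A minimal usage sketch for the WhyLabs handler (an editor's illustration, not part of the diff; it assumes WhyLabsCallbackHandler is re-exported from langchain_community.callbacks, that langkit and whylogs are installed, and that WhyLabs and OpenAI credentials are available in the environment):

from langchain_community.callbacks import WhyLabsCallbackHandler
from langchain_community.llms import OpenAI

# Builds a rolling whylogs logger with a WhyLabs writer; credentials are read
# from WHYLABS_API_KEY / WHYLABS_DEFAULT_ORG_ID / WHYLABS_DEFAULT_DATASET_ID
# when not passed explicitly.
whylabs = WhyLabsCallbackHandler.from_params(sentiment=True, toxicity=True)

llm = OpenAI(temperature=0, callbacks=[whylabs])
llm.invoke("Hello! How are you today?")

whylabs.close()  # write any pending profiles before exiting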