WhyLabsCallbackHandler updates (#7621)

Updates to the WhyLabsCallbackHandler and example notebook:
- Update the langkit dependency to 0.0.6, which defines new helper methods
for callback integrations
- Update WhyLabsCallbackHandler to use the new `get_callback_instance`
so that the callback is mostly defined in langkit
- Remove much of the implementation of the WhyLabsCallbackHandler here
in favor of the callback instance

This does not change the behavior of the WhyLabs callback handler. It is
a reorganization that moves part of the implementation into our optional
dependency package (langkit), which should make future updates easier.

@agola11
This commit is contained in:
Jamie Broomall 2023-07-12 22:46:56 -05:00 committed by GitHub
parent 53722dcfdc
commit 0e1d7a27c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 64 additions and 129 deletions

View File

@ -1,6 +1,7 @@
{ {
"cells": [ "cells": [
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@ -16,6 +17,7 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@ -28,10 +30,11 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install langkit -q" "%pip install langkit openai langchain"
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@ -54,6 +57,7 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -63,6 +67,7 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
@ -125,16 +130,7 @@
" ]\n", " ]\n",
")\n", ")\n",
"print(result)\n", "print(result)\n",
"# you don't need to call flush, this will occur periodically, but to demo let's not wait.\n", "# you don't need to call close to write profiles to WhyLabs, upload will occur periodically, but to demo let's not wait.\n",
"whylabs.flush()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"whylabs.close()" "whylabs.close()"
] ]
} }
@ -155,7 +151,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.6" "version": "3.8.10"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
from langchain.utils import get_from_env from langchain.utils import get_from_env
if TYPE_CHECKING: if TYPE_CHECKING:
@ -91,99 +90,29 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
themes (bool): Whether to enable theme analysis. Defaults to False. themes (bool): Whether to enable theme analysis. Defaults to False.
""" """
def __init__(self, logger: Logger): def __init__(self, logger: Logger, handler: Any):
"""Initiate the rolling logger""" """Initiate the rolling logger."""
super().__init__() super().__init__()
self.logger = logger if hasattr(handler, "init"):
diagnostic_logger.info( handler.init(self)
"Initialized WhyLabs callback handler with configured whylogs Logger." if hasattr(handler, "_get_callbacks"):
) self._callbacks = handler._get_callbacks()
else:
def _profile_generations(self, generations: List[Generation]) -> None: self._callbacks = dict()
for gen in generations: diagnostic_logger.warning("initialized handler without callbacks.")
self.logger.log({"response": gen.text}) self._logger = logger
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Pass the input prompts to the logger"""
for prompt in prompts:
self.logger.log({"prompt": prompt})
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Pass the generated response to the logger."""
for generations in response.generations:
self._profile_generations(generations)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing."""
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Do nothing."""
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing."""
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing."""
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
pass
def flush(self) -> None: def flush(self) -> None:
self.logger._do_rollover() """Explicitly write current profile if using a rolling logger."""
diagnostic_logger.info("Flushing WhyLabs logger, writing profile...") if self._logger and hasattr(self._logger, "_do_rollover"):
self._logger._do_rollover()
diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")
def close(self) -> None: def close(self) -> None:
self.logger.close() """Close any loggers to allow writing out of any profiles before exiting."""
diagnostic_logger.info("Closing WhyLabs logger, see you next time!") if self._logger and hasattr(self._logger, "close"):
self._logger.close()
diagnostic_logger.info("Closing WhyLabs logger, see you next time!")
def __enter__(self) -> WhyLabsCallbackHandler: def __enter__(self) -> WhyLabsCallbackHandler:
return self return self
@ -203,7 +132,8 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
sentiment: bool = False, sentiment: bool = False,
toxicity: bool = False, toxicity: bool = False,
themes: bool = False, themes: bool = False,
) -> Logger: logger: Optional[Logger] = None,
) -> WhyLabsCallbackHandler:
"""Instantiate whylogs Logger from params. """Instantiate whylogs Logger from params.
Args: Args:
@ -224,31 +154,39 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
themes (bool): If True will initialize a model to calculate themes (bool): If True will initialize a model to calculate
distance to configured themes. Defaults to None and will not gather this distance to configured themes. Defaults to None and will not gather this
metric. metric.
logger (Optional[Logger]): If specified will bind the configured logger as
the telemetry gathering agent. Defaults to LangKit schema with periodic
WhyLabs writer.
""" """
# langkit library will import necessary whylogs libraries # langkit library will import necessary whylogs libraries
import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes) import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
import whylogs as why import whylogs as why
from langkit.callback_handler import get_callback_instance
from whylogs.api.writer.whylabs import WhyLabsWriter from whylogs.api.writer.whylabs import WhyLabsWriter
from whylogs.core.schema import DeclarativeSchema from whylogs.experimental.core.udf_schema import udf_schema
from whylogs.experimental.core.metrics.udf_metric import generate_udf_schema
api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY") if logger is None:
org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID") api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
dataset_id = dataset_id or get_from_env( org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
"dataset_id", "WHYLABS_DEFAULT_DATASET_ID" dataset_id = dataset_id or get_from_env(
) "dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
whylabs_writer = WhyLabsWriter( )
api_key=api_key, org_id=org_id, dataset_id=dataset_id whylabs_writer = WhyLabsWriter(
) api_key=api_key, org_id=org_id, dataset_id=dataset_id
)
langkit_schema = DeclarativeSchema(generate_udf_schema()) whylabs_logger = why.logger(
whylabs_logger = why.logger( mode="rolling", interval=5, when="M", schema=udf_schema()
mode="rolling", interval=5, when="M", schema=langkit_schema )
)
whylabs_logger.append_writer(writer=whylabs_writer) whylabs_logger.append_writer(writer=whylabs_writer)
else:
diagnostic_logger.info("Using passed in whylogs logger {logger}")
whylabs_logger = logger
callback_handler_cls = get_callback_instance(logger=whylabs_logger, impl=cls)
diagnostic_logger.info( diagnostic_logger.info(
"Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝" "Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝"
) )
return cls(whylabs_logger) return callback_handler_cls

19
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
[[package]] [[package]]
name = "absl-py" name = "absl-py"
@ -4382,6 +4382,7 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [ files = [
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
] ]
[[package]] [[package]]
@ -4708,20 +4709,20 @@ data = ["language-data (>=1.1,<2.0)"]
[[package]] [[package]]
name = "langkit" name = "langkit"
version = "0.0.1" version = "0.0.6"
description = "A collection of text metric udfs for whylogs profiling and monitoring in WhyLabs" description = "A collection of text metric udfs for whylogs profiling and monitoring in WhyLabs"
category = "main" category = "main"
optional = true optional = true
python-versions = ">=3.8,<4.0" python-versions = ">=3.8,<4.0"
files = [ files = [
{file = "langkit-0.0.1-py3-none-any.whl", hash = "sha256:361a593cafd1611d054dd92dda8c3f5532232e3465e88ab32347232078f8ccd3"}, {file = "langkit-0.0.6-py3-none-any.whl", hash = "sha256:5b36830e9094934c933f8756177b5a8a5c7d6dc014ca49076a358c9c8fb5ddbc"},
{file = "langkit-0.0.1.tar.gz", hash = "sha256:d5ed28f21d6f641208f5da5f7d7ffd8a536b018868a004fb1fad54b83091e985"}, {file = "langkit-0.0.6.tar.gz", hash = "sha256:08421bb0799fc831b0d1e431e600cad8acab7d7bdbf6aa6c7535291172a66343"},
] ]
[package.dependencies] [package.dependencies]
pandas = "*" pandas = "*"
textstat = ">=0.7.3,<0.8.0" textstat = ">=0.7.3,<0.8.0"
whylogs = "1.1.45.dev6" whylogs = ">=1.2.3,<2.0.0"
[package.extras] [package.extras]
all = ["datasets (>=2.12.0,<3.0.0)", "nltk (>=3.8.1,<4.0.0)", "openai (>=0.27.6,<0.28.0)", "sentence-transformers (>=2.2.2,<3.0.0)", "torch"] all = ["datasets (>=2.12.0,<3.0.0)", "nltk (>=3.8.1,<4.0.0)", "openai (>=0.27.6,<0.28.0)", "sentence-transformers (>=2.2.2,<3.0.0)", "torch"]
@ -12166,14 +12167,14 @@ urllib3 = ">=1.25.3"
[[package]] [[package]]
name = "whylogs" name = "whylogs"
version = "1.1.45.dev6" version = "1.2.3"
description = "Profile and monitor your ML data pipeline end-to-end" description = "Profile and monitor your ML data pipeline end-to-end"
category = "main" category = "main"
optional = true optional = true
python-versions = ">=3.7.1,<4" python-versions = ">=3.7.1,<4"
files = [ files = [
{file = "whylogs-1.1.45.dev6-py3-none-any.whl", hash = "sha256:8d5a96ecf7b181b5034168b7ba979d463a2dd33a8b7a728ffd911b493de89857"}, {file = "whylogs-1.2.3-py3-none-any.whl", hash = "sha256:92cfe02985760c52d25b88bad69001901844ff51c76b62537bce1c31d12c271e"},
{file = "whylogs-1.1.45.dev6.tar.gz", hash = "sha256:a4b81bea7adf407de200ae6e8c7b53316b3304b3922ae4518a0f4234086db7ba"}, {file = "whylogs-1.2.3.tar.gz", hash = "sha256:d0000f502b1b30c48a5ad9535488370e961e85825dafdd75421447ffff0516e7"},
] ]
[package.dependencies] [package.dependencies]
@ -12718,4 +12719,4 @@ text-helpers = ["chardet"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "e700e2ae2c9a9f7f6efd3bfbec6063650864d45bf8439ebfd14dcf0683d0f17a" content-hash = "a77a3b8ac071e8ae9cd4004e577dbe4fd39552a69adb3277b06ab91f3fd0c77b"

View File

@ -92,7 +92,7 @@ pandas = {version = "^2.0.1", optional = true}
telethon = {version = "^1.28.5", optional = true} telethon = {version = "^1.28.5", optional = true}
neo4j = {version = "^5.8.1", optional = true} neo4j = {version = "^5.8.1", optional = true}
zep-python = {version=">=0.32", optional=true} zep-python = {version=">=0.32", optional=true}
langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true} langkit = {version = ">=0.0.6, <0.1.0", optional = true}
chardet = {version="^5.1.0", optional=true} chardet = {version="^5.1.0", optional=true}
requests-toolbelt = {version = "^1.0.0", optional = true} requests-toolbelt = {version = "^1.0.0", optional = true}
openlm = {version = "^0.0.5", optional = true} openlm = {version = "^0.0.5", optional = true}