community[minor] : adds callback handler for Fiddler AI (#17708)

**Description:**  Callback handler to integrate fiddler with langchain. 
This PR adds the following -

1. `FiddlerCallbackHandler` implementation into langchain/community
2. Example notebook `fiddler.ipynb` for usage documentation

[Internal Tracker : FDL-14305]

**Issue:** 
NA

**Dependencies:** 
- Installation of langchain-community is unaffected.
- Usage of FiddlerCallbackHandler requires installation of latest
fiddler-client (2.5+)

**Twitter handle:** @fiddlerlabs @behalder

Co-authored-by: Barun Halder <barun@fiddler.ai>
This commit is contained in:
Barun Amalkumar Halder 2024-02-25 18:17:03 -08:00 committed by GitHub
parent b8b5ce0c8c
commit cc69976860
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 496 additions and 0 deletions

View File

@ -0,0 +1,215 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0cebf93b",
"metadata": {},
"source": [
"## Fiddler Langchain integration Quick Start Guide\n",
"\n",
"Fiddler is the pioneer in enterprise Generative and Predictive system ops, offering a unified platform that enables Data Science, MLOps, Risk, Compliance, Analytics, and other LOB teams to monitor, explain, analyze, and improve ML deployments at enterprise scale. "
]
},
{
"cell_type": "markdown",
"id": "38d746c2",
"metadata": {},
"source": [
"## 1. Installation and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0151955",
"metadata": {},
"outputs": [],
"source": [
"# langchain>=0.1.7 langchain-openai fiddler-client"
]
},
{
"cell_type": "markdown",
"id": "5662f2e5-d510-4eef-b44b-fa929e5b4ad4",
"metadata": {},
"source": [
"## 2. Fiddler connection details "
]
},
{
"cell_type": "markdown",
"id": "64fac323",
"metadata": {},
"source": [
"*Before you can add information about your model with Fiddler*\n",
"\n",
"1. The URL you're using to connect to Fiddler\n",
"2. Your organization ID\n",
"3. Your authorization token\n",
"\n",
"These can be found by navigating to the *Settings* page of your Fiddler environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6f8b73e-d350-40f0-b7a4-fb1e68a65a22",
"metadata": {},
"outputs": [],
"source": [
"URL = \"\" # Your Fiddler instance URL, Make sure to include the full URL (including https://). For example: https://demo.fiddler.ai\n",
"ORG_NAME = \"\"\n",
"AUTH_TOKEN = \"\" # Your Fiddler instance auth token\n",
"\n",
"# Fiddler project and model names, used for model registration\n",
"PROJECT_NAME = \"\"\n",
"MODEL_NAME = \"\" # Model name in Fiddler"
]
},
{
"cell_type": "markdown",
"id": "0645805a",
"metadata": {},
"source": [
"## 3. Create a fiddler callback handler instance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13de4f9a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks.fiddler_callback import FiddlerCallback\n",
"\n",
"fiddler_handler = FiddlerCallback(\n",
" url=URL,\n",
" org_name=ORG_NAME,\n",
" project_name=PROJECT_NAME,\n",
" model_name=MODEL_NAME,\n",
" auth_token=AUTH_TOKEN,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2276368e-f1dc-46be-afe3-18796e7a66f2",
"metadata": {},
"source": [
"## Example 1 : Basic Chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9de0fd1",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_openai import OpenAI\n",
"\n",
"# Note : Make sure openai API key is set in the environment variable OPENAI_API_KEY\n",
"llm = OpenAI(temperature=0, streaming=True, callbacks=[fiddler_handler])\n",
"output_parser = StrOutputParser()\n",
"\n",
"chain = llm | output_parser\n",
"\n",
"# Invoke the chain. Invocation will be logged to Fiddler, and metrics automatically generated\n",
"chain.invoke(\"How far is moon from earth?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "309bde0b-e1ce-446c-98ac-3690c26a2676",
"metadata": {},
"outputs": [],
"source": [
"# Few more invocations\n",
"chain.invoke(\"What is the temperature on Mars?\")\n",
"chain.invoke(\"How much is 2 + 200000?\")\n",
"chain.invoke(\"Which movie won the oscars this year?\")\n",
"chain.invoke(\"Can you write me a poem about insomnia?\")\n",
"chain.invoke(\"How are you doing today?\")\n",
"chain.invoke(\"What is the meaning of life?\")"
]
},
{
"cell_type": "markdown",
"id": "48fa4782-c867-4510-9430-4ffa3de3b5eb",
"metadata": {},
"source": [
"## Example 2 : Chain with prompt templates"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2aa2c220-8946-4844-8d3c-8f69d744d13f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import (\n",
" ChatPromptTemplate,\n",
" FewShotChatMessagePromptTemplate,\n",
")\n",
"\n",
"examples = [\n",
" {\"input\": \"2+2\", \"output\": \"4\"},\n",
" {\"input\": \"2+3\", \"output\": \"5\"},\n",
"]\n",
"\n",
"example_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"human\", \"{input}\"),\n",
" (\"ai\", \"{output}\"),\n",
" ]\n",
")\n",
"\n",
"few_shot_prompt = FewShotChatMessagePromptTemplate(\n",
" example_prompt=example_prompt,\n",
" examples=examples,\n",
")\n",
"\n",
"final_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You are a wondrous wizard of math.\"),\n",
" few_shot_prompt,\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"# Note : Make sure openai API key is set in the environment variable OPENAI_API_KEY\n",
"llm = OpenAI(temperature=0, streaming=True, callbacks=[fiddler_handler])\n",
"\n",
"chain = final_prompt | llm\n",
"\n",
"# Invoke the chain. Invocation will be logged to Fiddler, and metrics automatically generated\n",
"chain.invoke({\"input\": \"What's the square of a triangle?\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -14,6 +14,7 @@ from langchain_community.callbacks.arthur_callback import ArthurCallbackHandler
from langchain_community.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain_community.callbacks.comet_ml_callback import CometCallbackHandler
from langchain_community.callbacks.context_callback import ContextCallbackHandler
from langchain_community.callbacks.fiddler_callback import FiddlerCallbackHandler
from langchain_community.callbacks.flyte_callback import FlyteCallbackHandler
from langchain_community.callbacks.human import HumanApprovalCallbackHandler
from langchain_community.callbacks.infino_callback import InfinoCallbackHandler
@ -63,4 +64,5 @@ __all__ = [
"SageMakerCallbackHandler",
"LabelStudioCallbackHandler",
"TrubricsCallbackHandler",
"FiddlerCallbackHandler",
]

View File

@ -0,0 +1,278 @@
import time
from typing import Any, Dict, List
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.utils import import_pandas
# Define constants
# LLMResult keys
TOKEN_USAGE = "token_usage"
TOTAL_TOKENS = "total_tokens"
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
RUN_ID = "run_id"
MODEL_NAME = "model_name"
# Default values
DEFAULT_MAX_TOKEN = 65536
DEFAULT_MAX_DURATION = 120
# Fiddler specific constants
PROMPT = "prompt"
RESPONSE = "response"
DURATION = "duration"
# Define a dataset dictionary
_dataset_dict = {
PROMPT: ["fiddler"] * 10,
RESPONSE: ["fiddler"] * 10,
MODEL_NAME: ["fiddler"] * 10,
RUN_ID: ["123e4567-e89b-12d3-a456-426614174000"] * 10,
TOTAL_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
PROMPT_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
COMPLETION_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
DURATION: [1, DEFAULT_MAX_DURATION] * 5,
}
def import_fiddler() -> Any:
"""Import the fiddler python package and raise an error if it is not installed."""
try:
import fiddler # noqa: F401
except ImportError:
raise ImportError(
"To use fiddler callback handler you need to have `fiddler-client`"
"package installed. Please install it with `pip install fiddler-client`"
)
return fiddler
# First, define custom callback handler implementations
class FiddlerCallbackHandler(BaseCallbackHandler):
def __init__(
self,
url: str,
org: str,
project: str,
model: str,
api_key: str,
) -> None:
"""
Initialize Fiddler callback handler.
Args:
url: Fiddler URL (e.g. https://demo.fiddler.ai).
Make sure to include the protocol (http/https).
org: Fiddler organization id
project: Fiddler project name to publish events to
model: Fiddler model name to publish events to
api_key: Fiddler authentication token
"""
super().__init__()
# Initialize Fiddler client and other necessary properties
self.fdl = import_fiddler()
self.pd = import_pandas()
self.url = url
self.org = org
self.project = project
self.model = model
self.api_key = api_key
self._df = self.pd.DataFrame(_dataset_dict)
self.run_id_prompts: Dict[str, List[str]] = {}
self.run_id_starttime: Dict[str, int] = {}
# Initialize Fiddler client here
self.fiddler_client = self.fdl.FiddlerApi(url, org_id=org, auth_token=api_key)
if self.project not in self.fiddler_client.get_project_names():
print( # noqa: T201
f"adding project {self.project}." "This only has to be done once."
)
try:
self.fiddler_client.add_project(self.project)
except Exception as e:
print( # noqa: T201
f"Error adding project {self.project}:"
"{e}. Fiddler integration will not work."
)
raise e
dataset_info = self.fdl.DatasetInfo.from_dataframe(
self._df, max_inferred_cardinality=0
)
if self.model not in self.fiddler_client.get_dataset_names(self.project):
print( # noqa: T201
f"adding dataset {self.model} to project {self.project}."
"This only has to be done once."
)
try:
self.fiddler_client.upload_dataset(
project_id=self.project,
dataset_id=self.model,
dataset={"train": self._df},
info=dataset_info,
)
except Exception as e:
print( # noqa: T201
f"Error adding dataset {self.model}: {e}."
"Fiddler integration will not work."
)
raise e
model_info = self.fdl.ModelInfo.from_dataset_info(
dataset_info=dataset_info,
dataset_id="train",
model_task=self.fdl.ModelTask.LLM,
features=[PROMPT, RESPONSE],
metadata_cols=[
RUN_ID,
TOTAL_TOKENS,
PROMPT_TOKENS,
COMPLETION_TOKENS,
MODEL_NAME,
],
custom_features=self.custom_features,
)
if self.model not in self.fiddler_client.get_model_names(self.project):
print( # noqa: T201
f"adding model {self.model} to project {self.project}."
"This only has to be done once." # noqa: T201
)
try:
self.fiddler_client.add_model(
project_id=self.project,
dataset_id=self.model,
model_id=self.model,
model_info=model_info,
)
except Exception as e:
print( # noqa: T201
f"Error adding model {self.model}: {e}."
"Fiddler integration will not work." # noqa: T201
)
raise e
@property
def custom_features(self) -> list:
"""
Define custom features for the model to automatically enrich the data with.
Here, we enable the following enrichments:
- Automatic Embedding generation for prompt and response
- Text Statistics such as:
- Automated Readability Index
- Coleman Liau Index
- Dale Chall Readability Score
- Difficult Words
- Flesch Reading Ease
- Flesch Kincaid Grade
- Gunning Fog
- Linsear Write Formula
- PII - Personal Identifiable Information
- Sentiment Analysis
"""
return [
self.fdl.Enrichment(
name="Prompt Embedding",
enrichment="embedding",
columns=[PROMPT],
),
self.fdl.TextEmbedding(
name="Prompt CF",
source_column=PROMPT,
column="Prompt Embedding",
),
self.fdl.Enrichment(
name="Response Embedding",
enrichment="embedding",
columns=[RESPONSE],
),
self.fdl.TextEmbedding(
name="Response CF",
source_column=RESPONSE,
column="Response Embedding",
),
self.fdl.Enrichment(
name="Text Statistics",
enrichment="textstat",
columns=[PROMPT, RESPONSE],
config={
"statistics": [
"automated_readability_index",
"coleman_liau_index",
"dale_chall_readability_score",
"difficult_words",
"flesch_reading_ease",
"flesch_kincaid_grade",
"gunning_fog",
"linsear_write_formula",
]
},
),
self.fdl.Enrichment(
name="PII",
enrichment="pii",
columns=[PROMPT, RESPONSE],
),
self.fdl.Enrichment(
name="Sentiment",
enrichment="sentiment",
columns=[PROMPT, RESPONSE],
),
]
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
run_id = kwargs[RUN_ID]
self.run_id_prompts[run_id] = prompts
self.run_id_starttime[run_id] = int(time.time())
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
flattened_llmresult = response.flatten()
token_usage_dict = {}
run_id = kwargs[RUN_ID]
run_duration = self.run_id_starttime[run_id] - int(time.time())
prompt_responses = []
model_name = ""
if isinstance(response.llm_output, dict):
if TOKEN_USAGE in response.llm_output:
token_usage_dict = response.llm_output[TOKEN_USAGE]
if MODEL_NAME in response.llm_output:
model_name = response.llm_output[MODEL_NAME]
for llmresult in flattened_llmresult:
prompt_responses.append(llmresult.generations[0][0].text)
df = self.pd.DataFrame(
{
PROMPT: self.run_id_prompts[run_id],
RESPONSE: prompt_responses,
}
)
if TOTAL_TOKENS in token_usage_dict:
df[PROMPT_TOKENS] = int(token_usage_dict[TOTAL_TOKENS])
if PROMPT_TOKENS in token_usage_dict:
df[TOTAL_TOKENS] = int(token_usage_dict[PROMPT_TOKENS])
if COMPLETION_TOKENS in token_usage_dict:
df[COMPLETION_TOKENS] = token_usage_dict[COMPLETION_TOKENS]
df[MODEL_NAME] = model_name
df[RUN_ID] = str(run_id)
df[DURATION] = run_duration
try:
self.fiddler_client.publish_events_batch(self.project, self.model, df)
except Exception as e:
print(f"Error publishing events to fiddler: {e}. continuing...") # noqa: T201

View File

@ -24,6 +24,7 @@ EXPECTED_ALL = [
"SageMakerCallbackHandler",
"LabelStudioCallbackHandler",
"TrubricsCallbackHandler",
"FiddlerCallbackHandler",
]