langchain/libs/community/langchain_community/callbacks/arize_callback.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path

import toml
import subprocess
import re

ROOT_DIR = Path(__file__).parents[1]


def main():
    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            pyproject["tool"]["poetry"]["group"]["typing"]["dependencies"]["mypy"] = (
                "^1.10"
            )
            pyproject["tool"]["poetry"]["group"]["lint"]["dependencies"]["ruff"] = (
                "^0.5"
            )
        except KeyError:
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)
        cwd = "/".join(path.split("/")[:-1])
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )
        logs = completed.stdout.split("\n")

        to_ignore = {}
        for l in logs:
            if re.match("^(.*)\:(\d+)\: error:.*\[(.*)\]", l):
                path, line_no, error_type = re.match(
                    "^(.*)\:(\d+)\: error:.*\[(.*)\]", l
                ).groups()
                if (path, line_no) in to_ignore:
                    to_ignore[(path, line_no)].append(error_type)
                else:
                    to_ignore[(path, line_no)] = [error_type]
        print(len(to_ignore))
        for (error_path, line_no), error_types in to_ignore.items():
            all_errors = ", ".join(error_types)
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            file_lines[int(line_no) - 1] = (
                file_lines[int(line_no) - 1][:-1] + f"  # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.write("".join(file_lines))

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00

214 lines
7.3 KiB
Python

from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.utils import import_pandas
class ArizeCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Arize.

    Prompts seen in ``on_llm_start`` are buffered in order. In ``on_llm_end``
    each generation is paired with the next buffered prompt. Both texts are
    embedded with an Arize ``EmbeddingGenerator``, and one row per generation
    is logged through the Arize pandas ``Client``. Each row includes the token
    usage reported by the LLM, or 0 when it is not reported.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        SPACE_KEY: Optional[str] = None,
        API_KEY: Optional[str] = None,
    ) -> None:
        """Initialize callback handler.

        Args:
            model_id: Model id passed to ``Client.log`` for every record.
            model_version: Model version passed to ``Client.log``.
            SPACE_KEY: Arize space key used to build the client.
            API_KEY: Arize API key used to build the client.

        Raises:
            ValueError: If ``SPACE_KEY`` or ``API_KEY`` is still the literal
                placeholder ``"SPACE_KEY"`` / ``"API_KEY"``.
        """
        super().__init__()
        self.model_id = model_id
        self.model_version = model_version
        self.space_key = SPACE_KEY
        self.api_key = API_KEY
        # Prompts buffered by on_llm_start, consumed in order by on_llm_end.
        self.prompt_records: List[str] = []
        # Not populated by this class; kept as public attributes.
        self.response_records: List[str] = []
        self.prediction_ids: List[str] = []
        self.pred_timestamps: List[int] = []
        self.response_embeddings: List[float] = []
        self.prompt_embeddings: List[float] = []
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        # Index of the next buffered prompt to pair with a generation.
        self.step = 0

        from arize.pandas.embeddings import EmbeddingGenerator, UseCases
        from arize.pandas.logger import Client

        # Check for placeholder keys first so a misconfiguration fails fast,
        # before the embedding generator and the Arize client are constructed.
        if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
            raise ValueError("❌ CHANGE SPACE AND API KEYS")

        self.generator = EmbeddingGenerator.from_use_case(
            use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
            model_name="distilbert-base-uncased",
            tokenizer_max_length=512,
            batch_size=256,
        )
        self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
        print("✅ Arize client setup done! Now you can start using Arize!")  # noqa: T201

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Buffer each prompt so ``on_llm_end`` can pair it with its generation."""
        # Replace newlines with a space rather than "" so that words on
        # adjacent lines are not glued together in the logged prompt.
        self.prompt_records.extend(prompt.replace("\n", " ") for prompt in prompts)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Embed every generation and its prompt, and log them to Arize.

        Each generation consumes the next buffered prompt (``self.step``).
        One single-row DataFrame is logged per generation, and a success or
        failure message is printed based on the client's HTTP status code.
        """
        pd = import_pandas()
        from arize.utils.types import (
            EmbeddingColumnNames,
            Environments,
            ModelTypes,
            Schema,
        )

        # ``llm_output`` and its ``token_usage`` entry may be missing or None.
        # In that case every count defaults to 0.
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        self.prompt_tokens = token_usage.get("prompt_tokens", 0)
        self.total_tokens = token_usage.get("total_tokens", 0)
        self.completion_tokens = token_usage.get("completion_tokens", 0)

        # The schema is identical for every row, so build it once.
        schema = Schema(
            timestamp_column_name="prediction_ts",
            tag_column_names=[
                "prompt_token",
                "completion_token",
                "total_token",
            ],
            prompt_column_names=EmbeddingColumnNames(
                vector_column_name="prompt_vector", data_column_name="prompt"
            ),
            response_column_names=EmbeddingColumnNames(
                vector_column_name="response_vector", data_column_name="response"
            ),
        )

        for generations in response.generations:
            for generation in generations:
                prompt = self.prompt_records[self.step]
                self.step += 1
                prompt_embedding = pd.Series(
                    self.generator.generate_embeddings(
                        text_col=pd.Series(prompt.replace("\n", " "))
                    ).reset_index(drop=True)
                )
                response_text = generation.text.replace("\n", " ")
                response_embedding = pd.Series(
                    self.generator.generate_embeddings(
                        text_col=pd.Series(response_text)
                    ).reset_index(drop=True)
                )

                # Pair each column name with its value explicitly. The previous
                # parallel lists put total_tokens under "completion_token" and
                # completion_tokens under "total_token".
                row = {
                    "prediction_ts": datetime.now().timestamp(),
                    "response": response_text,
                    "prompt": prompt,
                    "response_vector": response_embedding[0],
                    "prompt_vector": prompt_embedding[0],
                    "prompt_token": self.prompt_tokens,
                    "completion_token": self.completion_tokens,
                    "total_token": self.total_tokens,
                }
                df = pd.DataFrame([list(row.values())], columns=list(row))

                response_from_arize = self.arize_client.log(
                    dataframe=df,
                    schema=schema,
                    model_id=self.model_id,
                    model_version=self.model_version,
                    model_type=ModelTypes.GENERATIVE_LLM,
                    environment=Environments.PRODUCTION,
                )
                if response_from_arize.status_code == 200:
                    print("✅ Successfully logged data to Arize!")  # noqa: T201
                else:
                    print(f'❌ Logging failed "{response_from_arize.text}"')  # noqa: T201

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing."""
        pass

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass