From 2928b080f60fdbdcf76f7b0f08e8c845da79bd71 Mon Sep 17 00:00:00 2001 From: Hakan Tekgul Date: Mon, 26 Jun 2023 16:49:46 -0700 Subject: [PATCH] Update arize_callback.py - bug fix (#6784) - Description: Bug Fix - Added a step variable to keep track of prompts - Issue: Bug from internal Arize testing - The prompts and responses that are ingested were not mapped correctly - Dependencies: N/A --- langchain/callbacks/arize_callback.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/langchain/callbacks/arize_callback.py b/langchain/callbacks/arize_callback.py index 7e1196e7c0d..62f952588a9 100644 --- a/langchain/callbacks/arize_callback.py +++ b/langchain/callbacks/arize_callback.py @@ -1,4 +1,3 @@ -import uuid from datetime import datetime from typing import Any, Dict, List, Optional, Union @@ -33,6 +32,7 @@ class ArizeCallbackHandler(BaseCallbackHandler): self.prompt_tokens = 0 self.completion_tokens = 0 self.total_tokens = 0 + self.step = 0 from arize.pandas.embeddings import EmbeddingGenerator, UseCases from arize.pandas.logger import Client @@ -84,11 +84,10 @@ class ArizeCallbackHandler(BaseCallbackHandler): self.total_tokens ) = self.completion_tokens = 0 # assign default value - i = 0 - for generations in response.generations: for generation in generations: - prompt = self.prompt_records[i] + prompt = self.prompt_records[self.step] + self.step = self.step + 1 prompt_embedding = pd.Series( self.generator.generate_embeddings( text_col=pd.Series(prompt.replace("\n", " ")) @@ -102,7 +101,6 @@ class ArizeCallbackHandler(BaseCallbackHandler): text_col=pd.Series(generation.text.replace("\n", " ")) ).reset_index(drop=True) ) - str(uuid.uuid4()) pred_timestamp = datetime.now().timestamp() # Define the columns and data @@ -165,8 +163,6 @@ class ArizeCallbackHandler(BaseCallbackHandler): else: print(f'❌ Logging failed "{response_from_arize.text}"') - i = i + 1 - def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: