mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Update arize_callback.py - bug fix (#6784)
- Description: Bug Fix - Added a step variable to keep track of prompts - Issue: Bug from internal Arize testing - The prompts and responses that are ingested were not mapped correctly - Dependencies: N/A
This commit is contained in:
parent
c460b04c64
commit
2928b080f6
@ -1,4 +1,3 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
@ -33,6 +32,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
self.prompt_tokens = 0
|
||||
self.completion_tokens = 0
|
||||
self.total_tokens = 0
|
||||
self.step = 0
|
||||
|
||||
from arize.pandas.embeddings import EmbeddingGenerator, UseCases
|
||||
from arize.pandas.logger import Client
|
||||
@ -84,11 +84,10 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
self.total_tokens
|
||||
) = self.completion_tokens = 0 # assign default value
|
||||
|
||||
i = 0
|
||||
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
prompt = self.prompt_records[i]
|
||||
prompt = self.prompt_records[self.step]
|
||||
self.step = self.step + 1
|
||||
prompt_embedding = pd.Series(
|
||||
self.generator.generate_embeddings(
|
||||
text_col=pd.Series(prompt.replace("\n", " "))
|
||||
@ -102,7 +101,6 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
text_col=pd.Series(generation.text.replace("\n", " "))
|
||||
).reset_index(drop=True)
|
||||
)
|
||||
str(uuid.uuid4())
|
||||
pred_timestamp = datetime.now().timestamp()
|
||||
|
||||
# Define the columns and data
|
||||
@ -165,8 +163,6 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
||||
else:
|
||||
print(f'❌ Logging failed "{response_from_arize.text}"')
|
||||
|
||||
i = i + 1
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user