This commit is contained in:
Ankush Gola
2023-06-10 16:33:10 -07:00
parent 8af3534170
commit b3475d6b50
3 changed files with 18 additions and 3 deletions

View File

@@ -1,8 +1,8 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult
MODEL_COST_PER_1K_TOKENS = {
"gpt-4": 0.03,

View File

@@ -200,7 +200,7 @@ class LLMResult(BaseModel):
else:
if self.llm_output is not None:
llm_output = self.llm_output.copy()
llm_output["token_usage"] = None
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(

View File

@@ -38,6 +38,21 @@ async def test_openai_callback() -> None:
assert cb.total_tokens == total_tokens
def test_openai_callback_batch_llm() -> None:
llm = OpenAI(temperature=0)
with get_openai_callback() as cb:
llm.generate(["What is the square root of 4?", "What is the square root of 4?"])
assert cb.total_tokens > 0
total_tokens = cb.total_tokens
with get_openai_callback() as cb:
llm("What is the square root of 4?")
llm("What is the square root of 4?")
assert cb.total_tokens == total_tokens
def test_openai_callback_agent() -> None:
llm = OpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)