Compare commits

...

18 Commits

Author SHA1 Message Date
Ankush Gola
3d2578303b lint 2023-06-10 17:12:34 -07:00
Ankush Gola
0083ff9c7d change _agenerate in to use gather 2023-06-10 16:41:43 -07:00
Ankush Gola
b3475d6b50 cr 2023-06-10 16:33:10 -07:00
Ankush Gola
8af3534170 avoid double counting for openai callback 2023-06-10 15:40:31 -07:00
Ankush Gola
1724366b12 Merge branch 'master' into ankush/batch-llm-fix 2023-06-10 14:52:27 -07:00
Ankush Gola
1c001aace2 format 2023-06-07 23:44:57 -07:00
Ankush Gola
70f0f337df refactor 2023-06-07 23:36:36 -07:00
Ankush Gola
aea44eb246 format 2023-06-07 22:56:57 -07:00
Ankush Gola
f458da845a fix tests 2023-06-07 22:56:46 -07:00
Ankush Gola
960d0f6646 recalc tokens 2023-06-07 22:41:05 -07:00
Ankush Gola
0197abf2b3 fix chat model case 2023-06-07 18:01:39 -07:00
Ankush Gola
ab7a8e14df rm extraneous run id field 2023-06-07 15:33:12 -07:00
Ankush Gola
858c46bd87 cr 2023-06-07 15:22:51 -07:00
Ankush Gola
928fb3f719 cr 2023-06-07 15:22:00 -07:00
Ankush Gola
725e4a41f2 Update langchain/callbacks/manager.py
Co-authored-by: Nuno Campos <nuno@boringbits.io>
2023-06-07 14:58:38 -07:00
Ankush Gola
40409b4dbf fix case for empty lists 2023-06-06 22:55:34 -07:00
Ankush Gola
5fa7ea2146 fix chat models and tests 2023-06-06 22:50:30 -07:00
Ankush Gola
207a7b7bbd split up batch llm calls into separate runs 2023-06-06 19:54:43 -07:00
15 changed files with 361 additions and 242 deletions

View File

@@ -631,54 +631,66 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
managers = []
for prompt in prompts:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
**kwargs,
)
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
managers.append(
CallbackManagerForLLMRun(
run_id_,
self.handlers,
self.inheritable_handlers,
self.parent_run_id,
)
)
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
return managers
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
# Re-use the LLM Run Manager since the outputs are treated
# the same for now
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
managers = []
for message_list in messages:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
[message_list],
run_id=run_id_,
parent_run_id=self.parent_run_id,
**kwargs,
)
managers.append(
CallbackManagerForLLMRun(
run_id_,
self.handlers,
self.inheritable_handlers,
self.parent_run_id,
)
)
return managers
def on_chain_start(
self,
@@ -756,52 +768,78 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForLLMRun:
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
tasks = []
managers = []
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
for prompt in prompts:
run_id_ = uuid4()
tasks.append(
_ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
run_id_,
self.handlers,
self.inheritable_handlers,
self.parent_run_id,
)
)
await asyncio.gather(*tasks)
return managers
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if run_id is None:
run_id = uuid4()
tasks = []
managers = []
await _ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
for message_list in messages:
run_id_ = uuid4()
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
tasks.append(
_ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
[message_list],
run_id=run_id_,
parent_run_id=self.parent_run_id,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
run_id_,
self.handlers,
self.inheritable_handlers,
self.parent_run_id,
)
)
await asyncio.gather(*tasks)
return managers
async def on_chain_start(
self,

View File

@@ -1,8 +1,8 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult
MODEL_COST_PER_1K_TOKENS = {
"gpt-4": 0.03,
@@ -110,64 +110,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
pass
def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
return self

View File

@@ -73,7 +73,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chat_model_start(
run_managers = callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages, invocation_params=params
)
@@ -82,20 +82,30 @@ class BaseChatModel(BaseLanguageModel, ABC):
)
try:
results = [
self._generate(m, stop=stop, run_manager=run_manager)
self._generate(
m, stop=stop, run_manager=run_managers[0] if run_managers else None
)
if new_arg_supported
else self._generate(m, stop=stop)
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
for run_manager in run_managers:
run_manager.on_llm_error(e)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
async def agenerate(
@@ -111,7 +121,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chat_model_start(
run_managers = await callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages, invocation_params=params
)
@@ -121,21 +131,40 @@ class BaseChatModel(BaseLanguageModel, ABC):
try:
results = await asyncio.gather(
*[
self._agenerate(m, stop=stop, run_manager=run_manager)
self._agenerate(
m,
stop=stop,
run_manager=run_managers[0] if run_managers else None,
)
if new_arg_supported
else self._agenerate(m, stop=stop)
for m in messages
]
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
await asyncio.gather(
*[run_manager.on_llm_error(e) for run_manager in run_managers]
)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate_prompt(

View File

@@ -1,4 +1,5 @@
"""Base interface for large language models to expose."""
import asyncio
import inspect
import json
import warnings
@@ -143,6 +144,36 @@ class BaseLLM(BaseLanguageModel, ABC):
prompt_strings = [p.to_string() for p in prompts]
return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
def _generate_helper(
    self,
    prompts: List[str],
    stop: Optional[List[str]],
    run_managers: List[CallbackManagerForLLMRun],
    new_arg_supported: bool,
) -> LLMResult:
    """Call ``_generate`` once for the batch and notify every run manager.

    Args:
        prompts: Prompts handed to ``_generate`` in a single call.
        stop: Optional stop sequences forwarded to ``_generate``.
        run_managers: Run managers to notify. They are paired positionally
            with the entries of ``output.flatten()``. Only the first one
            (or ``None`` when the list is empty) is passed to ``_generate``.
        new_arg_supported: Whether ``_generate`` accepts ``run_manager``.

    Returns:
        The result of ``_generate``; ``run`` is set to one ``RunInfo`` per
        run manager when at least one run manager was supplied.

    Raises:
        Whatever ``_generate`` raised, after ``on_llm_error`` has been
        called on every run manager.
    """
    try:
        if new_arg_supported:
            first_manager = run_managers[0] if run_managers else None
            output = self._generate(prompts, stop=stop, run_manager=first_manager)
        else:
            output = self._generate(prompts, stop=stop)
    except (KeyboardInterrupt, Exception) as e:
        for manager in run_managers:
            manager.on_llm_error(e)
        raise e
    # One flattened result per prompt, matched to the manager at that index.
    for manager, single_result in zip(run_managers, output.flatten()):
        manager.on_llm_end(single_result)
    if run_managers:
        output.run = [RunInfo(run_id=manager.run_id) for manager in run_managers]
    return output
def generate(
self,
prompts: List[str],
@@ -150,8 +181,6 @@ class BaseLLM(BaseLanguageModel, ABC):
callbacks: Callbacks = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
@@ -173,55 +202,77 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = callback_manager.on_llm_start(
run_managers = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, invocation_params=params
)
try:
output = (
self._generate(prompts, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
output = self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported)
)
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
run_managers = callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
invocation_params=params,
)
try:
new_results = (
self._generate(missing_prompts, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._generate(missing_prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(new_results)
new_results = self._generate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported)
)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
if run_managers
else None
)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
async def _agenerate_helper(
    self,
    prompts: List[str],
    stop: Optional[List[str]],
    run_managers: List[AsyncCallbackManagerForLLMRun],
    new_arg_supported: bool,
) -> LLMResult:
    """Await ``_agenerate`` once for the batch and notify every run manager.

    Args:
        prompts: Prompts handed to ``_agenerate`` in a single call.
        stop: Optional stop sequences forwarded to ``_agenerate``.
        run_managers: Async run managers to notify. They are paired
            positionally with the entries of ``output.flatten()``. Only the
            first one (or ``None`` when the list is empty) is passed to
            ``_agenerate``.
        new_arg_supported: Whether ``_agenerate`` accepts ``run_manager``.

    Returns:
        The result of ``_agenerate``; ``run`` is set to one ``RunInfo`` per
        run manager when at least one run manager was supplied.

    Raises:
        Whatever ``_agenerate`` raised, after the ``on_llm_error`` callbacks
        of all run managers have been gathered.
    """
    try:
        if new_arg_supported:
            first_manager = run_managers[0] if run_managers else None
            output = await self._agenerate(
                prompts, stop=stop, run_manager=first_manager
            )
        else:
            output = await self._agenerate(prompts, stop=stop)
    except (KeyboardInterrupt, Exception) as e:
        error_callbacks = [manager.on_llm_error(e) for manager in run_managers]
        await asyncio.gather(*error_callbacks)
        raise e
    # One flattened result per prompt, matched to the manager at that index.
    end_callbacks = [
        manager.on_llm_end(single_result)
        for manager, single_result in zip(run_managers, output.flatten())
    ]
    await asyncio.gather(*end_callbacks)
    if run_managers:
        output.run = [RunInfo(run_id=manager.run_id) for manager in run_managers]
    return output
async def agenerate(
self,
prompts: List[str],
@@ -245,51 +296,34 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = await callback_manager.on_llm_start(
run_managers = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, invocation_params=params
)
try:
output = (
await self._agenerate(prompts, stop=stop, run_manager=run_manager)
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e, verbose=self.verbose)
raise e
await run_manager.on_llm_end(output, verbose=self.verbose)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
output = await self._agenerate_helper(
prompts, stop, run_managers, bool(new_arg_supported)
)
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
run_managers = await callback_manager.on_llm_start(
{"name": self.__class__.__name__},
missing_prompts,
invocation_params=params,
)
try:
new_results = (
await self._agenerate(
missing_prompts, stop=stop, run_manager=run_manager
)
if new_arg_supported
else await self._agenerate(missing_prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
await run_manager.on_llm_end(new_results)
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported)
)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
if run_managers
else None
)
else:
llm_output = {}
run_info = None
@@ -460,13 +494,17 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
for prompt in prompts:
text = (
async def get_completion(prompt: str) -> str:
return (
await self._acall(prompt, stop=stop, run_manager=run_manager)
if new_arg_supported
else await self._acall(prompt, stop=stop)
)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
generations = await asyncio.gather(
*(get_completion(prompt) for prompt in prompts)
)
return LLMResult(generations=[[Generation(text=text)] for text in generations])

View File

@@ -182,9 +182,35 @@ class LLMResult(BaseModel):
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
run: Optional[RunInfo] = None
run: Optional[List[RunInfo]] = None
"""Run metadata."""
def flatten(self) -> List[LLMResult]:
    """Split this result into one ``LLMResult`` per entry of ``generations``.

    The first result carries ``llm_output`` unchanged. Every later result
    gets a shallow copy of ``llm_output`` whose ``"token_usage"`` entry is
    replaced by an empty dict (or ``None`` when there is no ``llm_output``).

    Returns:
        A list with one single-prompt ``LLMResult`` per generation list.
    """
    flattened = []
    for index, gen_list in enumerate(self.generations):
        if index == 0 or self.llm_output is None:
            per_prompt_output = self.llm_output
        else:
            # Avoid double counting tokens in OpenAICallback
            per_prompt_output = self.llm_output.copy()
            per_prompt_output["token_usage"] = dict()
        flattened.append(
            LLMResult(generations=[gen_list], llm_output=per_prompt_output)
        )
    return flattened
def __eq__(self, other: object) -> bool:
if not isinstance(other, LLMResult):
return NotImplemented

View File

@@ -38,6 +38,21 @@ async def test_openai_callback() -> None:
assert cb.total_tokens == total_tokens
def test_openai_callback_batch_llm() -> None:
    """Check a batched ``generate`` reports the same tokens as separate calls."""
    question = "What is the square root of 4?"
    llm = OpenAI(temperature=0)

    # One batched request containing the same prompt twice.
    with get_openai_callback() as batch_cb:
        llm.generate([question, question])
    assert batch_cb.total_tokens > 0
    batch_total = batch_cb.total_tokens

    # The same two prompts issued as independent calls.
    with get_openai_callback() as single_cb:
        for _ in range(2):
            llm(question)
    assert single_cb.total_tokens == batch_total
def test_openai_callback_agent() -> None:
llm = OpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)

View File

@@ -96,6 +96,15 @@ def test_openai_streaming() -> None:
assert isinstance(token["choices"][0]["text"], str)
def test_openai_multiple_prompts() -> None:
    """Generate for two identical prompts and expect two generation lists."""
    prompts = ["I'm Pickle Rick", "I'm Pickle Rick"]
    result = OpenAI(max_tokens=10).generate(prompts)
    assert isinstance(result, LLMResult)
    generations = result.generations
    assert isinstance(generations, list)
    assert len(generations) == 2
def test_openai_streaming_error() -> None:
"""Test error handling in stream."""
llm = OpenAI(best_of=2)

View File

@@ -27,6 +27,10 @@ class FakeListLLM(LLM):
print(self.responses[self.i])
return self.responses[self.i]
def get_num_tokens(self, text: str) -> int:
    """Count the whitespace-separated words in ``text``."""
    words = text.split()
    return len(words)
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}

View File

@@ -43,6 +43,10 @@ class FakeListLLM(LLM):
self.i += 1
return self.responses[self.i]
def get_num_tokens(self, text: str) -> int:
    """Return how many whitespace-delimited pieces ``text`` splits into."""
    pieces = text.split()
    return len(pieces)
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}

View File

@@ -18,11 +18,12 @@ def _test_callback_manager(
manager: CallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
run_manager = manager.on_llm_start({}, [])
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_manager.on_llm_new_token("foo")
run_manager.on_text("foo")
run_managers = manager.on_llm_start({}, ["prompt"])
for run_manager in run_managers:
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_manager.on_llm_new_token("foo")
run_manager.on_text("foo")
run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
run_manager_chain.on_chain_end({})
@@ -42,11 +43,12 @@ async def _test_callback_manager_async(
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
run_manager = await manager.on_llm_start({}, [])
await run_manager.on_llm_end(LLMResult(generations=[]))
await run_manager.on_llm_error(Exception())
await run_manager.on_llm_new_token("foo")
await run_manager.on_text("foo")
run_managers = await manager.on_llm_start({}, ["prompt"])
for run_manager in run_managers:
await run_manager.on_llm_end(LLMResult(generations=[]))
await run_manager.on_llm_error(Exception())
await run_manager.on_llm_new_token("foo")
await run_manager.on_text("foo")
run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
await run_manager_chain.on_chain_end({})
@@ -95,9 +97,10 @@ def test_ignore_llm() -> None:
handler1 = FakeCallbackHandler(ignore_llm_=True)
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
run_manager = manager.on_llm_start({}, [])
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_managers = manager.on_llm_start({}, ["prompt"])
for run_manager in run_managers:
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0

View File

@@ -11,7 +11,7 @@ from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import LLMResult
from langchain.schema import HumanMessage, LLMResult
class FakeTracer(BaseTracer):
@@ -55,9 +55,13 @@ def test_tracer_llm_run() -> None:
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
uuid = uuid4()
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
run_managers = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
)
compare_run = Run(
id=str(uuid),
id=str(run_managers[0].run_id),
name="chat_model",
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
@@ -65,17 +69,13 @@ def test_tracer_chat_model_run() -> None:
execution_order=1,
child_execution_order=1,
serialized={"name": "chat_model"},
inputs=dict(prompts=[""]),
inputs=dict(prompts=["Human: "]),
outputs=LLMResult(generations=[[]]),
error=None,
run_type="llm",
)
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
for run_manager in run_managers:
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]

View File

@@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.schema import LLMResult
from langchain.schema import HumanMessage, LLMResult
TEST_SESSION_ID = 2023
@@ -124,9 +124,15 @@ def test_tracer_llm_run() -> None:
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
uuid = uuid4()
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_managers = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
)
compare_run = LLMRun(
uuid=str(uuid),
uuid=str(run_managers[0].run_id),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
@@ -134,19 +140,13 @@ def test_tracer_chat_model_run() -> None:
execution_order=1,
child_execution_order=1,
serialized={"name": "chat_model"},
prompts=[""],
prompts=["Human: "],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
for run_manager in run_managers:
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]

View File

@@ -47,6 +47,10 @@ class FakeLLM(BaseLLM):
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
def get_num_tokens(self, text: str) -> int:
    """Approximate the token count as the number of whitespace-split words."""
    word_list = text.split()
    return len(word_list)
@property
def _llm_type(self) -> str:
"""Return type of llm."""

View File

@@ -27,6 +27,9 @@ class FakeLLM(LLM):
"""Return type of llm."""
return "fake"
def get_num_tokens(self, text: str) -> int:
    """Return the number of whitespace-separated words in ``text``."""
    tokens = text.split()
    return len(tokens)
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}

View File

@@ -24,6 +24,10 @@ class FakeLLM(LLM):
)
return queries
def get_num_tokens(self, text: str) -> int:
    """Count tokens by splitting ``text`` on whitespace."""
    parts = text.split()
    return len(parts)
@property
def _llm_type(self) -> str:
"""Return type of llm."""