From e1b801be36d2d492eea071dcb594e966414b31cf Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Sat, 24 Jun 2023 21:03:31 -0700 Subject: [PATCH] split up batch llm calls into separate runs (#5804) --- langchain/callbacks/manager.py | 212 ++++++++++-------- langchain/callbacks/openai_info.py | 62 +---- langchain/chat_models/base.py | 105 ++++++--- langchain/llms/base.py | 175 +++++++++------ langchain/schema.py | 28 ++- .../callbacks/test_openai_callback.py | 15 ++ tests/integration_tests/llms/test_openai.py | 9 + tests/unit_tests/agents/test_agent.py | 4 + .../callbacks/test_callback_manager.py | 29 +-- .../callbacks/tracers/test_base_tracer.py | 20 +- .../callbacks/tracers/test_langchain_v1.py | 24 +- tests/unit_tests/chains/test_hyde.py | 4 + tests/unit_tests/chains/test_natbot.py | 3 + tests/unit_tests/llms/fake_llm.py | 4 + 14 files changed, 401 insertions(+), 293 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 3e899234e61..cbc58c4db80 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -672,66 +672,72 @@ class CallbackManager(BaseCallbackManager): self, serialized: Dict[str, Any], prompts: List[str], - run_id: Optional[UUID] = None, **kwargs: Any, - ) -> CallbackManagerForLLMRun: + ) -> List[CallbackManagerForLLMRun]: """Run when LLM starts running.""" - if run_id is None: - run_id = uuid4() + managers = [] + for prompt in prompts: + run_id_ = uuid4() + _handle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) - _handle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - prompts, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) + managers.append( + CallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + ) + ) - return CallbackManagerForLLMRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - ) + return managers def on_chat_model_start( self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], - run_id: Optional[UUID] = None, **kwargs: Any, - ) -> CallbackManagerForLLMRun: + ) -> List[CallbackManagerForLLMRun]: """Run when LLM starts running.""" - if run_id is None: - run_id = uuid4() - _handle_event( - self.handlers, - "on_chat_model_start", - "ignore_chat_model", - serialized, - messages, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) - # Re-use the LLM Run Manager since the outputs are treated - # the same for now - return CallbackManagerForLLMRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - ) + managers = [] + for message_list in messages: + run_id_ = uuid4() + _handle_event( + self.handlers, + "on_chat_model_start", + "ignore_chat_model", + serialized, + [message_list], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + + managers.append( + CallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + ) + ) + + return managers def on_chain_start( self, @@ -830,64 +836,84 @@ class AsyncCallbackManager(BaseCallbackManager): self, serialized: Dict[str, Any], prompts: List[str], - run_id: Optional[UUID] = None, **kwargs: Any, - ) -> AsyncCallbackManagerForLLMRun: + ) -> List[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running.""" - if run_id is None: - run_id = uuid4() - await _ahandle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - prompts, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) + tasks = [] + managers = [] - return AsyncCallbackManagerForLLMRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - ) + for prompt in prompts: + run_id_ = uuid4() + + tasks.append( + _ahandle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + ) + ) + + await asyncio.gather(*tasks) + + return managers async def on_chat_model_start( self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], - run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: - if run_id is None: - run_id = uuid4() + tasks = [] + managers = [] - await _ahandle_event( - self.handlers, - "on_chat_model_start", - "ignore_chat_model", - serialized, - messages, - run_id=run_id, - parent_run_id=self.parent_run_id, - tags=self.tags, - **kwargs, - ) + for message_list in messages: + run_id_ = uuid4() - return AsyncCallbackManagerForLLMRun( - run_id=run_id, - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, - parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - ) + tasks.append( + _ahandle_event( + self.handlers, + "on_chat_model_start", + "ignore_chat_model", + serialized, + [message_list], + run_id=run_id_, + parent_run_id=self.parent_run_id, + tags=self.tags, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id=run_id_, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + ) + ) + + await asyncio.gather(*tasks) + return managers async def on_chain_start( self, diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py index 6dc1a73195c..c66ec7ce873 100644 --- a/langchain/callbacks/openai_info.py +++ b/langchain/callbacks/openai_info.py @@ -1,8 +1,8 @@ """Callback Handler that prints to std out.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema import LLMResult MODEL_COST_PER_1K_TOKENS = { # GPT-4 input @@ -152,64 +152,6 @@ class OpenAICallbackHandler(BaseCallbackHandler): self.prompt_tokens += prompt_tokens self.completion_tokens += completion_tokens - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Print out that we are entering a chain.""" - pass - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Print out that we finished a chain.""" - pass - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - **kwargs: Any, - ) -> None: - """Print out the log in specified color.""" - pass - - def on_tool_end( - self, - output: str, - color: Optional[str] = None, - observation_prefix: Optional[str] = None, - llm_prefix: Optional[str] = None, - **kwargs: Any, - ) -> None: - """If not the final action, print out observation.""" - pass - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Do nothing.""" - pass - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - pass - - def on_agent_finish( - self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any - ) -> None: - """Run on agent end.""" - pass - def __copy__(self) -> "OpenAICallbackHandler": """Return a copy of the callback handler.""" return self diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index c4bb81305af..51b3a5a9bbb 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -101,26 +101,37 @@ class BaseChatModel(BaseLanguageModel, ABC): tags, self.tags, ) - run_manager = callback_manager.on_chat_model_start( + run_managers = callback_manager.on_chat_model_start( dumpd(self), messages, invocation_params=params, options=options ) - - try: - results = [ - self._generate_with_cache( - m, stop=stop, run_manager=run_manager, **kwargs + results = [] + for i, m in enumerate(messages): + try: + results.append( + self._generate_with_cache( + m, + stop=stop, + run_manager=run_managers[i] if run_managers else None, + **kwargs, + ) ) - for m in messages - ] - except (KeyboardInterrupt, Exception) as e: - run_manager.on_llm_error(e) - raise e + except (KeyboardInterrupt, Exception) as e: + if run_managers: + run_managers[i].on_llm_error(e) + raise e + flattened_outputs = [ + LLMResult(generations=[res.generations], llm_output=res.llm_output) + for res in results + ] llm_output = self._combine_llm_outputs([res.llm_output for res in results]) generations = [res.generations for res in results] output = LLMResult(generations=generations, llm_output=llm_output) - run_manager.on_llm_end(output) - if run_manager: - output.run = RunInfo(run_id=run_manager.run_id) + if run_managers: + run_infos = [] + for manager, flattened_output in zip(run_managers, flattened_outputs): + manager.on_llm_end(flattened_output) + run_infos.append(RunInfo(run_id=manager.run_id)) + output.run = run_infos return output async def agenerate( @@ -143,28 +154,62 @@ class BaseChatModel(BaseLanguageModel, ABC): tags, self.tags, ) - run_manager = await callback_manager.on_chat_model_start( + + run_managers = await callback_manager.on_chat_model_start( dumpd(self), messages, invocation_params=params, options=options ) - try: - results = await asyncio.gather( - *[ - self._agenerate_with_cache( - m, stop=stop, run_manager=run_manager, **kwargs - ) - for m in messages - ] - ) - except (KeyboardInterrupt, Exception) as e: - await run_manager.on_llm_error(e) - raise e + results = await asyncio.gather( + *[ + self._agenerate_with_cache( + m, + stop=stop, + run_manager=run_managers[i] if run_managers else None, + **kwargs, + ) + for i, m in enumerate(messages) + ], + return_exceptions=True, + ) + exceptions = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + if run_managers: + await run_managers[i].on_llm_error(res) + exceptions.append(res) + if exceptions: + if run_managers: + await asyncio.gather( + *[ + run_manager.on_llm_end( + LLMResult( + generations=[res.generations], llm_output=res.llm_output + ) + ) + for run_manager, res in zip(run_managers, results) + if not isinstance(res, Exception) + ] + ) + raise exceptions[0] + flattened_outputs = [ + LLMResult(generations=[res.generations], llm_output=res.llm_output) + for res in results + ] llm_output = self._combine_llm_outputs([res.llm_output for res in results]) generations = [res.generations for res in results] output = LLMResult(generations=generations, llm_output=llm_output) - await run_manager.on_llm_end(output) - if run_manager: - output.run = RunInfo(run_id=run_manager.run_id) + await asyncio.gather( + *[ + run_manager.on_llm_end(flattened_output) + for run_manager, flattened_output in zip( + run_managers, flattened_outputs + ) + ] + ) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] return output def generate_prompt( diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 46e3a81fd89..a25d65f659f 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -1,4 +1,5 @@ """Base interface for large language models to expose.""" +import asyncio import inspect import json import warnings @@ -151,6 +152,39 @@ class BaseLLM(BaseLanguageModel, ABC): prompt_strings, stop=stop, callbacks=callbacks, **kwargs ) + def _generate_helper( + self, + prompts: List[str], + stop: Optional[List[str]], + run_managers: List[CallbackManagerForLLMRun], + new_arg_supported: bool, + **kwargs: Any, + ) -> LLMResult: + try: + output = ( + self._generate( + prompts, + stop=stop, + # TODO: support multiple run managers + run_manager=run_managers[0] if run_managers else None, + **kwargs, + ) + if new_arg_supported + else self._generate(prompts, stop=stop) + ) + except (KeyboardInterrupt, Exception) as e: + for run_manager in run_managers: + run_manager.on_llm_error(e) + raise e + flattened_outputs = output.flatten() + for manager, flattened_output in zip(run_managers, flattened_outputs): + manager.on_llm_end(flattened_output) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] + return output + def generate( self, prompts: List[str], @@ -161,8 +195,6 @@ class BaseLLM(BaseLanguageModel, ABC): **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - # If string is passed in directly no errors will be raised but outputs will - # not make sense. if not isinstance(prompts, list): raise ValueError( "Argument 'prompts' is expected to be of type List[str], received" @@ -185,60 +217,77 @@ class BaseLLM(BaseLanguageModel, ABC): "run_manager" ) if langchain.llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True if self.cache is not None and self.cache: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - run_manager = callback_manager.on_llm_start( + run_managers = callback_manager.on_llm_start( dumpd(self), prompts, invocation_params=params, options=options ) - try: - output = ( - self._generate( - prompts, stop=stop, run_manager=run_manager, **kwargs - ) - if new_arg_supported - else self._generate(prompts, stop=stop, **kwargs) - ) - except (KeyboardInterrupt, Exception) as e: - run_manager.on_llm_error(e) - raise e - run_manager.on_llm_end(output) - if run_manager: - output.run = RunInfo(run_id=run_manager.run_id) + output = self._generate_helper( + prompts, stop, run_managers, bool(new_arg_supported), **kwargs + ) return output if len(missing_prompts) > 0: - run_manager = callback_manager.on_llm_start( - dumpd(self), - missing_prompts, - invocation_params=params, - options=options, + run_managers = callback_manager.on_llm_start( + dumpd(self), missing_prompts, invocation_params=params, options=options + ) + new_results = self._generate_helper( + missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) - try: - new_results = ( - self._generate( - missing_prompts, stop=stop, run_manager=run_manager, **kwargs - ) - if new_arg_supported - else self._generate(missing_prompts, stop=stop, **kwargs) - ) - except (KeyboardInterrupt, Exception) as e: - run_manager.on_llm_error(e) - raise e - run_manager.on_llm_end(new_results) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) - run_info = None - if run_manager: - run_info = RunInfo(run_id=run_manager.run_id) + run_info = ( + [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] + if run_managers + else None + ) else: llm_output = {} run_info = None generations = [existing_prompts[i] for i in range(len(prompts))] return LLMResult(generations=generations, llm_output=llm_output, run=run_info) + async def _agenerate_helper( + self, + prompts: List[str], + stop: Optional[List[str]], + run_managers: List[AsyncCallbackManagerForLLMRun], + new_arg_supported: bool, + **kwargs: Any, + ) -> LLMResult: + try: + output = ( + await self._agenerate( + prompts, + stop=stop, + run_manager=run_managers[0] if run_managers else None, + **kwargs, + ) + if new_arg_supported + else await self._agenerate(prompts, stop=stop) + ) + except (KeyboardInterrupt, Exception) as e: + await asyncio.gather( + *[run_manager.on_llm_error(e) for run_manager in run_managers] + ) + raise e + flattened_outputs = output.flatten() + await asyncio.gather( + *[ + run_manager.on_llm_end(flattened_output) + for run_manager, flattened_output in zip( + run_managers, flattened_outputs + ) + ] + ) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] + return output + async def agenerate( self, prompts: List[str], @@ -266,54 +315,32 @@ class BaseLLM(BaseLanguageModel, ABC): "run_manager" ) if langchain.llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True if self.cache is not None and self.cache: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - run_manager = await callback_manager.on_llm_start( + run_managers = await callback_manager.on_llm_start( dumpd(self), prompts, invocation_params=params, options=options ) - try: - output = ( - await self._agenerate( - prompts, stop=stop, run_manager=run_manager, **kwargs - ) - if new_arg_supported - else await self._agenerate(prompts, stop=stop, **kwargs) - ) - except (KeyboardInterrupt, Exception) as e: - await run_manager.on_llm_error(e, verbose=self.verbose) - raise e - await run_manager.on_llm_end(output, verbose=self.verbose) - if run_manager: - output.run = RunInfo(run_id=run_manager.run_id) + output = await self._agenerate_helper( + prompts, stop, run_managers, bool(new_arg_supported), **kwargs + ) return output if len(missing_prompts) > 0: - run_manager = await callback_manager.on_llm_start( - dumpd(self), - missing_prompts, - invocation_params=params, - options=options, + run_managers = await callback_manager.on_llm_start( + dumpd(self), missing_prompts, invocation_params=params, options=options + ) + new_results = await self._agenerate_helper( + missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs ) - try: - new_results = ( - await self._agenerate( - missing_prompts, stop=stop, run_manager=run_manager, **kwargs - ) - if new_arg_supported - else await self._agenerate(missing_prompts, stop=stop, **kwargs) - ) - except (KeyboardInterrupt, Exception) as e: - await run_manager.on_llm_error(e) - raise e - await run_manager.on_llm_end(new_results) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) - run_info = None - if run_manager: - run_info = RunInfo(run_id=run_manager.run_id) + run_info = ( + [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers] + if run_managers + else None + ) else: llm_output = {} run_info = None diff --git a/langchain/schema.py b/langchain/schema.py index 88baff3c202..07106699b58 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -227,9 +227,35 @@ class LLMResult(BaseModel): each input could have multiple generations.""" llm_output: Optional[dict] = None """For arbitrary LLM provider specific output.""" - run: Optional[RunInfo] = None + run: Optional[List[RunInfo]] = None """Run metadata.""" + def flatten(self) -> List[LLMResult]: + """Flatten generations into a single list.""" + llm_results = [] + for i, gen_list in enumerate(self.generations): + # Avoid double counting tokens in OpenAICallback + if i == 0: + llm_results.append( + LLMResult( + generations=[gen_list], + llm_output=self.llm_output, + ) + ) + else: + if self.llm_output is not None: + llm_output = self.llm_output.copy() + llm_output["token_usage"] = dict() + else: + llm_output = None + llm_results.append( + LLMResult( + generations=[gen_list], + llm_output=llm_output, + ) + ) + return llm_results + def __eq__(self, other: object) -> bool: if not isinstance(other, LLMResult): return NotImplemented diff --git a/tests/integration_tests/callbacks/test_openai_callback.py b/tests/integration_tests/callbacks/test_openai_callback.py index 9704cb5612f..e32dbaa0472 100644 --- a/tests/integration_tests/callbacks/test_openai_callback.py +++ b/tests/integration_tests/callbacks/test_openai_callback.py @@ -38,6 +38,21 @@ async def test_openai_callback() -> None: assert cb.total_tokens == total_tokens +def test_openai_callback_batch_llm() -> None: + llm = OpenAI(temperature=0) + with get_openai_callback() as cb: + llm.generate(["What is the square root of 4?", "What is the square root of 4?"]) + + assert cb.total_tokens > 0 + total_tokens = cb.total_tokens + + with get_openai_callback() as cb: + llm("What is the square root of 4?") + llm("What is the square root of 4?") + + assert cb.total_tokens == total_tokens + + def test_openai_callback_agent() -> None: llm = OpenAI(temperature=0) tools = load_tools(["serpapi", "llm-math"], llm=llm) diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index 7b7253ab357..f1a146da44f 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -96,6 +96,15 @@ def test_openai_streaming() -> None: assert isinstance(token["choices"][0]["text"], str) +def test_openai_multiple_prompts() -> None: + """Test completion with multiple prompts.""" + llm = OpenAI(max_tokens=10) + output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"]) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) + assert len(output.generations) == 2 + + def test_openai_streaming_error() -> None: """Test error handling in stream.""" llm = OpenAI(best_of=2) diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index ac89aa4a768..be3818611b2 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -28,6 +28,10 @@ class FakeListLLM(LLM): print(self.responses[self.i]) return self.responses[self.i] + def get_num_tokens(self, text: str) -> int: + """Return number of tokens in text.""" + return len(text.split()) + @property def _identifying_params(self) -> Mapping[str, Any]: return {} diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 24877b81d7f..73d38233ba5 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -18,11 +18,12 @@ def _test_callback_manager( manager: CallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" - run_manager = manager.on_llm_start({}, []) - run_manager.on_llm_end(LLMResult(generations=[])) - run_manager.on_llm_error(Exception()) - run_manager.on_llm_new_token("foo") - run_manager.on_text("foo") + run_managers = manager.on_llm_start({}, ["prompt"]) + for run_manager in run_managers: + run_manager.on_llm_end(LLMResult(generations=[])) + run_manager.on_llm_error(Exception()) + run_manager.on_llm_new_token("foo") + run_manager.on_text("foo") run_manager_chain = manager.on_chain_start({"name": "foo"}, {}) run_manager_chain.on_chain_end({}) @@ -42,11 +43,12 @@ async def _test_callback_manager_async( manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" - run_manager = await manager.on_llm_start({}, []) - await run_manager.on_llm_end(LLMResult(generations=[])) - await run_manager.on_llm_error(Exception()) - await run_manager.on_llm_new_token("foo") - await run_manager.on_text("foo") + run_managers = await manager.on_llm_start({}, ["prompt"]) + for run_manager in run_managers: + await run_manager.on_llm_end(LLMResult(generations=[])) + await run_manager.on_llm_error(Exception()) + await run_manager.on_llm_new_token("foo") + await run_manager.on_text("foo") run_manager_chain = await manager.on_chain_start({"name": "foo"}, {}) await run_manager_chain.on_chain_end({}) @@ -95,9 +97,10 @@ def test_ignore_llm() -> None: handler1 = FakeCallbackHandler(ignore_llm_=True) handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - run_manager = manager.on_llm_start({}, []) - run_manager.on_llm_end(LLMResult(generations=[])) - run_manager.on_llm_error(Exception()) + run_managers = manager.on_llm_start({}, ["prompt"]) + for run_manager in run_managers: + run_manager.on_llm_end(LLMResult(generations=[])) + run_manager.on_llm_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 diff --git a/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/tests/unit_tests/callbacks/tracers/test_base_tracer.py index c0736c62eb6..ca7735467d2 100644 --- a/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -11,7 +11,7 @@ from freezegun import freeze_time from langchain.callbacks.manager import CallbackManager from langchain.callbacks.tracers.base import BaseTracer, TracerException from langchain.callbacks.tracers.schemas import Run -from langchain.schema import LLMResult +from langchain.schema import HumanMessage, LLMResult SERIALIZED = {"id": ["llm"]} SERIALIZED_CHAT = {"id": ["chat_model"]} @@ -58,9 +58,13 @@ def test_tracer_llm_run() -> None: @freeze_time("2023-01-01") def test_tracer_chat_model_run() -> None: """Test tracer on a Chat Model run.""" - uuid = uuid4() + tracer = FakeTracer() + manager = CallbackManager(handlers=[tracer]) + run_managers = manager.on_chat_model_start( + serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]] + ) compare_run = Run( - id=str(uuid), + id=str(run_managers[0].run_id), name="chat_model", start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -68,17 +72,13 @@ def test_tracer_chat_model_run() -> None: execution_order=1, child_execution_order=1, serialized=SERIALIZED_CHAT, - inputs=dict(prompts=[""]), + inputs=dict(prompts=["Human: "]), outputs=LLMResult(generations=[[]]), error=None, run_type="llm", ) - tracer = FakeTracer() - manager = CallbackManager(handlers=[tracer]) - run_manager = manager.on_chat_model_start( - serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid - ) - run_manager.on_llm_end(response=LLMResult(generations=[[]])) + for run_manager in run_managers: + run_manager.on_llm_end(response=LLMResult(generations=[[]])) assert tracer.runs == [compare_run] diff --git a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py index 782f3fbcfe7..f57b93a074e 100644 --- a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py +++ b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py @@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import ( TracerSessionV1, ) from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base -from langchain.schema import LLMResult +from langchain.schema import HumanMessage, LLMResult TEST_SESSION_ID = 2023 @@ -127,9 +127,15 @@ def test_tracer_llm_run() -> None: @freeze_time("2023-01-01") def test_tracer_chat_model_run() -> None: """Test tracer on a Chat Model run.""" - uuid = uuid4() + tracer = FakeTracer() + + tracer.new_session() + manager = CallbackManager(handlers=[tracer]) + run_managers = manager.on_chat_model_start( + serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]] + ) compare_run = LLMRun( - uuid=str(uuid), + uuid=str(run_managers[0].run_id), parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), @@ -137,19 +143,13 @@ def test_tracer_chat_model_run() -> None: execution_order=1, child_execution_order=1, serialized=SERIALIZED_CHAT, - prompts=[""], + prompts=["Human: "], response=LLMResult(generations=[[]]), session_id=TEST_SESSION_ID, error=None, ) - tracer = FakeTracer() - - tracer.new_session() - manager = CallbackManager(handlers=[tracer]) - run_manager = manager.on_chat_model_start( - serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid - ) - run_manager.on_llm_end(response=LLMResult(generations=[[]])) + for run_manager in run_managers: + run_manager.on_llm_end(response=LLMResult(generations=[[]])) assert tracer.runs == [compare_run] diff --git a/tests/unit_tests/chains/test_hyde.py b/tests/unit_tests/chains/test_hyde.py index e189c84e898..c7eeb32e621 100644 --- a/tests/unit_tests/chains/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -49,6 +49,10 @@ class FakeLLM(BaseLLM): ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) + def get_num_tokens(self, text: str) -> int: + """Return number of tokens.""" + return len(text.split()) + @property def _llm_type(self) -> str: """Return type of llm.""" diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index e5f68ab5006..0cc5a649701 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -28,6 +28,9 @@ class FakeLLM(LLM): """Return type of llm.""" return "fake" + def get_num_tokens(self, text: str) -> int: + return len(text.split()) + @property def _identifying_params(self) -> Mapping[str, Any]: return {} diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index 71c2f0b3657..33e88d46702 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -24,6 +24,10 @@ class FakeLLM(LLM): ) return queries + def get_num_tokens(self, text: str) -> int: + """Return number of tokens.""" + return len(text.split()) + @property def _llm_type(self) -> str: """Return type of llm."""