From 5fa7ea21464420aeb1fd834a944614d925eb4fd5 Mon Sep 17 00:00:00 2001
From: Ankush Gola
Date: Tue, 6 Jun 2023 22:50:30 -0700
Subject: [PATCH] fix chat models and tests

---
 langchain/callbacks/manager.py                  | 86 ++++++++++++-------
 langchain/chat_models/base.py                   | 41 ++++++---
 .../callbacks/test_callback_manager.py          | 29 ++++---
 .../callbacks/tracers/test_base_tracer.py       | 20 ++---
 .../callbacks/tracers/test_langchain_v1.py      | 24 +++---
 5 files changed, 120 insertions(+), 80 deletions(-)

diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py
index 31a0a232984..f05db249761 100644
--- a/langchain/callbacks/manager.py
+++ b/langchain/callbacks/manager.py
@@ -674,26 +674,33 @@ class CallbackManager(BaseCallbackManager):
         messages: List[List[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-        _handle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
-        # Re-use the LLM Run Manager since the outputs are treated
-        # the same for now
-        return CallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        managers = []
+        for message in messages:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_chat_model_start",
+                "ignore_chat_model",
+                serialized,
+                [message],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                **kwargs,
+            )
+
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+
+        return managers
 
     def on_chain_start(
         self,
@@ -815,23 +822,36 @@ class AsyncCallbackManager(BaseCallbackManager):
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        if run_id is None:
-            run_id = uuid4()
+        tasks = []
+        managers = []
 
-        await _ahandle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+        for message in messages:
+            run_id_ = uuid4()
 
-        return AsyncCallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_chat_model_start",
+                    "ignore_chat_model",
+                    serialized,
+                    [message],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+        return managers
 
     async def on_chain_start(
         self,
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index dcb4ebebcfa..fc77993de11 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -73,7 +73,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        run_manager = callback_manager.on_chat_model_start(
+        run_managers = callback_manager.on_chat_model_start(
             {"name": self.__class__.__name__}, messages, invocation_params=params
         )
 
@@ -82,20 +82,25 @@
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_managers[0])
                 if new_arg_supported
                 else self._generate(m, stop=stop)
                 for m in messages
             ]
         except (KeyboardInterrupt, Exception) as e:
-            run_manager.on_llm_error(e)
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
             raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        flattened_outputs = output.flatten()
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output
 
     async def agenerate(
@@ -111,7 +116,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        run_manager = await callback_manager.on_chat_model_start(
+        run_managers = await callback_manager.on_chat_model_start(
             {"name": self.__class__.__name__}, messages, invocation_params=params
         )
 
@@ -121,21 +126,33 @@
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_managers[0])
                     if new_arg_supported
                     else self._agenerate(m, stop=stop)
                     for m in messages
                 ]
             )
         except (KeyboardInterrupt, Exception) as e:
-            await run_manager.on_llm_error(e)
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
             raise e
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        await run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        flattened_outputs = output.flatten()
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output
 
     def generate_prompt(
diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py
index 6a215985741..f53dfdab263 100644
--- a/tests/unit_tests/callbacks/test_callback_manager.py
+++ b/tests/unit_tests/callbacks/test_callback_manager.py
@@ -18,11 +18,12 @@ def _test_callback_manager(
     manager: CallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
-    run_manager.on_llm_new_token("foo")
-    run_manager.on_text("foo")
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
+        run_manager.on_llm_new_token("foo")
+        run_manager.on_text("foo")
 
     run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
     run_manager_chain.on_chain_end({})
@@ -42,11 +43,12 @@ async def _test_callback_manager_async(
     manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = await manager.on_llm_start({}, [])
-    await run_manager.on_llm_end(LLMResult(generations=[]))
-    await run_manager.on_llm_error(Exception())
-    await run_manager.on_llm_new_token("foo")
-    await run_manager.on_text("foo")
+    run_managers = await manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        await run_manager.on_llm_end(LLMResult(generations=[]))
+        await run_manager.on_llm_error(Exception())
+        await run_manager.on_llm_new_token("foo")
+        await run_manager.on_text("foo")
 
     run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
     await run_manager_chain.on_chain_end({})
@@ -95,9 +97,10 @@ def test_ignore_llm() -> None:
     handler1 = FakeCallbackHandler(ignore_llm_=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
diff --git a/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/tests/unit_tests/callbacks/tracers/test_base_tracer.py
index 4ff2e342c5f..cfcb9a23ef3 100644
--- a/tests/unit_tests/callbacks/tracers/test_base_tracer.py
+++ b/tests/unit_tests/callbacks/tracers/test_base_tracer.py
@@ -11,7 +11,7 @@ from freezegun import freeze_time
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult
 
 
 class FakeTracer(BaseTracer):
@@ -55,9 +55,13 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
+    )
     compare_run = Run(
-        id=str(uuid),
+        id=str(run_managers[0].run_id),
         name="chat_model",
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -65,17 +69,13 @@
         execution_order=1,
         child_execution_order=1,
         serialized={"name": "chat_model"},
-        inputs=dict(prompts=[""]),
+        inputs=dict(prompts=["Human: "]),
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
     )
-    tracer = FakeTracer()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
 
 
diff --git a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py
index ab655ac6312..92fc5c54d3b 100644
--- a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py
+++ b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py
@@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
     TracerSessionV1,
 )
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult
 
 TEST_SESSION_ID = 2023
 
@@ -124,9 +124,15 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+
+    tracer.new_session()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized={"name": "chat_model"}, messages=[[HumanMessage(content="")]]
+    )
     compare_run = LLMRun(
-        uuid=str(uuid),
+        uuid=str(run_managers[0].run_id),
         parent_uuid=None,
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -134,19 +140,13 @@
         execution_order=1,
         child_execution_order=1,
         serialized={"name": "chat_model"},
-        prompts=[""],
+        prompts=["Human: "],
         response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
         error=None,
     )
-    tracer = FakeTracer()
-
-    tracer.new_session()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
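
Reviewer note (not part of the patch): after this change, on_chat_model_start returns one
run manager per inner message list, each with its own run_id, and generate/agenerate use
LLMResult.flatten() to report one flattened result per run. Below is a minimal sketch of
that calling convention, assuming the post-patch CallbackManager API; the EchoHandler
class and the sample messages are illustrative assumptions, not code from this commit.

    # Sketch only: exercises the post-patch API where on_chat_model_start
    # returns one CallbackManagerForLLMRun per message batch.
    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.callbacks.manager import CallbackManager
    from langchain.schema import HumanMessage, LLMResult


    class EchoHandler(BaseCallbackHandler):
        """Hypothetical handler that logs each chat-model run as it starts."""

        def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
            # After this patch, this fires once per batch, each with its own run_id.
            print(f"run {run_id}: {len(messages[0])} message(s)")


    manager = CallbackManager(handlers=[EchoHandler()])
    run_managers = manager.on_chat_model_start(
        {"name": "chat_model"},
        [[HumanMessage(content="hello")], [HumanMessage(content="goodbye")]],
    )
    assert len(run_managers) == 2  # one manager per batch, not one overall
    for run_manager in run_managers:
        run_manager.on_llm_end(LLMResult(generations=[[]]))

Per-batch run IDs are what let tracers record each chat-model call as its own run, which
is what the updated tracer tests above assert.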