mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-23 02:15:42 +00:00 
			
		
		
		
	### Add on_chat_message_start to callback manager and base tracer. Goal: trace messages directly to permit reloading them as chat messages (stored in an integration-agnostic way). Add an `on_chat_message_start` method, falling back to `on_llm_start()` for handlers that don't implement it. This is done in a way that does not break backwards compatibility (for now).
		
			
				
	
	
		
			333 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			333 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """A fake callback handler for testing purposes."""
 | |
| from itertools import chain
 | |
| from typing import Any, Dict, List, Optional
 | |
| from uuid import UUID
 | |
| 
 | |
| from pydantic import BaseModel
 | |
| 
 | |
| from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 | |
| from langchain.schema import BaseMessage
 | |
| 
 | |
| 
 | |
class BaseFakeCallbackHandler(BaseModel):
    """Pydantic model holding the counters and flags shared by fake handlers.

    Every counter starts at zero and is bumped by the ``*_common`` helpers of
    ``BaseFakeCallbackHandlerMixin``; tests read the counters back to check
    which callbacks fired and how often.
    """

    # --- aggregate counters (any start / end / error, plus on_text calls) ---
    starts: int = 0
    ends: int = 0
    errors: int = 0
    text: int = 0

    # --- flags that concrete handlers may expose via ``ignore_*`` properties ---
    ignore_llm_: bool = False
    ignore_chain_: bool = False
    ignore_agent_: bool = False
    ignore_chat_model_: bool = False

    # --- per-event counters, kept separate so a failing test shows exactly
    # which callback was (or was not) invoked ---
    chain_starts: int = 0
    chain_ends: int = 0
    llm_starts: int = 0
    llm_ends: int = 0
    llm_streams: int = 0
    tool_starts: int = 0
    tool_ends: int = 0
    agent_actions: int = 0
    agent_ends: int = 0
    chat_model_starts: int = 0
 | |
| 
 | |
| 
 | |
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
    """Shared bookkeeping for the fake sync and async callback handlers.

    Each ``on_*_common`` helper bumps the fine-grained counter for one event
    and, where applicable, the matching aggregate ``starts`` / ``ends`` /
    ``errors`` counter. The helpers intentionally produce no output so that
    test logs stay clean.
    """

    def on_llm_start_common(self) -> None:
        """Record an LLM start."""
        self.llm_starts += 1
        self.starts += 1

    def on_llm_end_common(self) -> None:
        """Record an LLM end."""
        self.llm_ends += 1
        self.ends += 1

    def on_llm_error_common(self) -> None:
        """Record an LLM error (only the aggregate counter exists for errors)."""
        self.errors += 1

    def on_llm_new_token_common(self) -> None:
        """Record a streamed LLM token; not counted as a start or an end."""
        self.llm_streams += 1

    def on_chain_start_common(self) -> None:
        """Record a chain start."""
        self.chain_starts += 1
        self.starts += 1

    def on_chain_end_common(self) -> None:
        """Record a chain end."""
        self.chain_ends += 1
        self.ends += 1

    def on_chain_error_common(self) -> None:
        """Record a chain error (only the aggregate counter exists for errors)."""
        self.errors += 1

    def on_tool_start_common(self) -> None:
        """Record a tool start."""
        self.tool_starts += 1
        self.starts += 1

    def on_tool_end_common(self) -> None:
        """Record a tool end."""
        self.tool_ends += 1
        self.ends += 1

    def on_tool_error_common(self) -> None:
        """Record a tool error (only the aggregate counter exists for errors)."""
        self.errors += 1

    def on_agent_action_common(self) -> None:
        """Record an agent action; counted as a start."""
        self.agent_actions += 1
        self.starts += 1

    def on_agent_finish_common(self) -> None:
        """Record an agent finish; counted as an end."""
        self.agent_ends += 1
        self.ends += 1

    def on_chat_model_start_common(self) -> None:
        """Record a chat model start."""
        self.chat_model_starts += 1
        self.starts += 1

    def on_text_common(self) -> None:
        """Record an ``on_text`` call."""
        self.text += 1
 | |
| 
 | |
| 
 | |
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
    """Fake callback handler for testing.

    Every callback accepts arbitrary arguments, ignores them, and only bumps
    the counters defined on ``BaseFakeCallbackHandler`` through the matching
    ``*_common`` helper.
    """

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks.

        Exposes the ``ignore_chat_model_`` field the same way the other
        ``ignore_*`` flags are exposed; without this property the field had no
        effect.
        """
        return self.ignore_chat_model_

    def on_llm_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count an LLM start."""
        self.on_llm_start_common()

    def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any:
        """Count a streamed LLM token."""
        self.on_llm_new_token_common()

    def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count an LLM end."""
        self.on_llm_end_common()

    def on_llm_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count an LLM error."""
        self.on_llm_error_common()

    def on_chain_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count a chain start."""
        self.on_chain_start_common()

    def on_chain_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count a chain end."""
        self.on_chain_end_common()

    def on_chain_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count a chain error."""
        self.on_chain_error_common()

    def on_tool_start(self, *args: Any, **kwargs: Any) -> Any:
        """Count a tool start."""
        self.on_tool_start_common()

    def on_tool_end(self, *args: Any, **kwargs: Any) -> Any:
        """Count a tool end."""
        self.on_tool_end_common()

    def on_tool_error(self, *args: Any, **kwargs: Any) -> Any:
        """Count a tool error."""
        self.on_tool_error_common()

    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
        """Count an agent action."""
        self.on_agent_action_common()

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any:
        """Count an agent finish."""
        self.on_agent_finish_common()

    def on_text(self, *args: Any, **kwargs: Any) -> Any:
        """Count an ``on_text`` call."""
        self.on_text_common()

    def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
        # Return the same instance so that any deep copy shares its counters
        # with the handler object the test holds.
        return self
 | |
| 
 | |
| 
 | |
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
    """Fake handler that additionally implements ``on_chat_model_start``.

    Handlers without this method are expected to fall back to
    ``on_llm_start``; this one instead counts chat model starts under
    ``chat_model_starts``.
    """

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Verify the incoming messages and count the chat model start.

        Args:
            serialized: Serialized description of the model (unused).
            messages: One list of messages per prompt; every element must be
                a ``BaseMessage`` or the assertion fails the test.
            run_id: Identifier of this run (unused).
            parent_run_id: Identifier of the parent run, if any (unused).
            **kwargs: Any extra callback arguments (ignored).
        """
        every_message = chain.from_iterable(messages)
        assert all(isinstance(message, BaseMessage) for message in every_message)
        self.on_chat_model_start_common()
 | |
| 
 | |
| 
 | |
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
    """Fake async callback handler for testing.

    Async counterpart of ``FakeCallbackHandler``: every coroutine accepts
    arbitrary arguments, ignores them, and only bumps the counters defined on
    ``BaseFakeCallbackHandler`` through the matching ``*_common`` helper.
    """

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_chat_model(self) -> bool:
        """Whether to ignore chat model callbacks.

        Exposes the ``ignore_chat_model_`` field the same way the other
        ``ignore_*`` flags are exposed; without this property the field had no
        effect.
        """
        return self.ignore_chat_model_

    async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        """Count an LLM start."""
        self.on_llm_start_common()

    async def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
        """Count a streamed LLM token."""
        self.on_llm_new_token_common()

    async def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        """Count an LLM end."""
        self.on_llm_end_common()

    async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        """Count an LLM error."""
        self.on_llm_error_common()

    async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        """Count a chain start."""
        self.on_chain_start_common()

    async def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        """Count a chain end."""
        self.on_chain_end_common()

    async def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
        """Count a chain error."""
        self.on_chain_error_common()

    async def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
        """Count a tool start."""
        self.on_tool_start_common()

    async def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        """Count a tool end."""
        self.on_tool_end_common()

    async def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
        """Count a tool error."""
        self.on_tool_error_common()

    async def on_agent_action(self, *args: Any, **kwargs: Any) -> None:
        """Count an agent action."""
        self.on_agent_action_common()

    async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        """Count an agent finish."""
        self.on_agent_finish_common()

    async def on_text(self, *args: Any, **kwargs: Any) -> None:
        """Count an ``on_text`` call."""
        self.on_text_common()

    def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
        # Return the same instance so that any deep copy shares its counters
        # with the handler object the test holds.
        return self
 |