From e3edd74eabb807349b520d002b75f505421e8aa1 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 29 Dec 2022 23:07:55 -0500 Subject: [PATCH] switch up defaults (#485) i kinda like this just because we call `self.callback_manager` so many times, and thats nicer than `self._get_callback_manager()`? --- langchain/agents/agent.py | 8 ++++---- langchain/chains/base.py | 22 +++++++++++++--------- langchain/llms/base.py | 28 ++++++++++++++++------------ 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index d6f6db29870..847633ab1c8 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -220,7 +220,7 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: - self._get_callback_manager().on_agent_end(output.log, color="green") + self.callback_manager.on_agent_end(output.log, color="green") final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps @@ -230,7 +230,7 @@ class AgentExecutor(Chain, BaseModel): if output.tool in name_to_tool_map: chain = name_to_tool_map[output.tool] if self.verbose: - self._get_callback_manager().on_tool_start( + self.callback_manager.on_tool_start( {"name": str(chain)[:60] + "..."}, output, color="green" ) # We then call the tool on the tool input to get an observation @@ -238,13 +238,13 @@ class AgentExecutor(Chain, BaseModel): color = color_mapping[output.tool] else: if self.verbose: - self._get_callback_manager().on_tool_start( + self.callback_manager.on_tool_start( {"name": "N/A"}, output, color="green" ) observation = f"{output.tool} is not a valid tool, try another one." color = None if self.verbose: - self._get_callback_manager().on_tool_end( + self.callback_manager.on_tool_end( observation, color=color, observation_prefix=self.agent.observation_prefix, diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 87bc0f1baef..c830d0b1b7f 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Extra, Field, validator import langchain from langchain.callbacks import get_callback_manager @@ -44,7 +44,7 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" memory: Optional[Memory] = None - callback_manager: Optional[BaseCallbackManager] = None + callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) verbose: bool = Field( default_factory=_get_verbosity ) # Whether to print the response text @@ -54,11 +54,15 @@ class Chain(BaseModel, ABC): arbitrary_types_allowed = True - def _get_callback_manager(self) -> BaseCallbackManager: - """Get the callback manager.""" - if self.callback_manager is not None: - return self.callback_manager - return get_callback_manager() + @validator("callback_manager", pre=True, always=True) + def set_callback_manager( + cls, callback_manager: Optional[BaseCallbackManager] + ) -> BaseCallbackManager: + """If callback manager is None, set it. + + This allows users to pass in None as context manager, which is a nice UX. + """ + return callback_manager or get_callback_manager() @property @abstractmethod @@ -120,12 +124,12 @@ class Chain(BaseModel, ABC): inputs = dict(inputs, **external_context) self._validate_inputs(inputs) if self.verbose: - self._get_callback_manager().on_chain_start( + self.callback_manager.on_chain_start( {"name": self.__class__.__name__}, inputs ) outputs = self._call(inputs) if self.verbose: - self._get_callback_manager().on_chain_end(outputs) + self.callback_manager.on_chain_end(outputs) self._validate_outputs(outputs) if self.memory is not None: self.memory.save_context(inputs, outputs) diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 8b50cfcaf9b..d3b9990112a 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Union import yaml -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Extra, Field, validator import langchain from langchain.callbacks import get_callback_manager @@ -23,7 +23,7 @@ class BaseLLM(BaseModel, ABC): cache: Optional[bool] = None verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" - callback_manager: Optional[BaseCallbackManager] = None + callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) class Config: """Configuration for this pydantic object.""" @@ -31,18 +31,22 @@ class BaseLLM(BaseModel, ABC): extra = Extra.forbid arbitrary_types_allowed = True + @validator("callback_manager", pre=True, always=True) + def set_callback_manager( + cls, callback_manager: Optional[BaseCallbackManager] + ) -> BaseCallbackManager: + """If callback manager is None, set it. + + This allows users to pass in None as context manager, which is a nice UX. + """ + return callback_manager or get_callback_manager() + @abstractmethod def _generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: """Run the LLM on the given prompts.""" - def _get_callback_manager(self) -> BaseCallbackManager: - """Get the callback manager.""" - if self.callback_manager is not None: - return self.callback_manager - return get_callback_manager() - def generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: @@ -55,12 +59,12 @@ class BaseLLM(BaseModel, ABC): "Asked to cache, but no cache found at `langchain.cache`." ) if self.verbose: - self._get_callback_manager().on_llm_start( + self.callback_manager.on_llm_start( {"name": self.__class__.__name__}, prompts ) output = self._generate(prompts, stop=stop) if self.verbose: - self._get_callback_manager().on_llm_end(output) + self.callback_manager.on_llm_end(output) return output params = self._llm_dict() params["stop"] = stop @@ -75,11 +79,11 @@ class BaseLLM(BaseModel, ABC): else: missing_prompts.append(prompt) missing_prompt_idxs.append(i) - self._get_callback_manager().on_llm_start( + self.callback_manager.on_llm_start( {"name": self.__class__.__name__}, missing_prompts ) new_results = self._generate(missing_prompts, stop=stop) - self._get_callback_manager().on_llm_end(new_results) + self.callback_manager.on_llm_end(new_results) for i, result in enumerate(new_results.generations): existing_prompts[i] = result prompt = prompts[i]