switch up defaults (#485)

i kinda like this just because we call `self.callback_manager` so many
times, and thats nicer than `self._get_callback_manager()`?
This commit is contained in:
Harrison Chase 2022-12-29 23:07:55 -05:00 committed by GitHub
parent 52490e2dcd
commit e3edd74eab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 25 deletions

View File

@ -220,7 +220,7 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
if self.verbose: if self.verbose:
self._get_callback_manager().on_agent_end(output.log, color="green") self.callback_manager.on_agent_end(output.log, color="green")
final_output = output.return_values final_output = output.return_values
if self.return_intermediate_steps: if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps final_output["intermediate_steps"] = intermediate_steps
@ -230,7 +230,7 @@ class AgentExecutor(Chain, BaseModel):
if output.tool in name_to_tool_map: if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool] chain = name_to_tool_map[output.tool]
if self.verbose: if self.verbose:
self._get_callback_manager().on_tool_start( self.callback_manager.on_tool_start(
{"name": str(chain)[:60] + "..."}, output, color="green" {"name": str(chain)[:60] + "..."}, output, color="green"
) )
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
@ -238,13 +238,13 @@ class AgentExecutor(Chain, BaseModel):
color = color_mapping[output.tool] color = color_mapping[output.tool]
else: else:
if self.verbose: if self.verbose:
self._get_callback_manager().on_tool_start( self.callback_manager.on_tool_start(
{"name": "N/A"}, output, color="green" {"name": "N/A"}, output, color="green"
) )
observation = f"{output.tool} is not a valid tool, try another one." observation = f"{output.tool} is not a valid tool, try another one."
color = None color = None
if self.verbose: if self.verbose:
self._get_callback_manager().on_tool_end( self.callback_manager.on_tool_end(
observation, observation,
color=color, color=color,
observation_prefix=self.agent.observation_prefix, observation_prefix=self.agent.observation_prefix,

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field from pydantic import BaseModel, Extra, Field, validator
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
@ -44,7 +44,7 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
memory: Optional[Memory] = None memory: Optional[Memory] = None
callback_manager: Optional[BaseCallbackManager] = None callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
verbose: bool = Field( verbose: bool = Field(
default_factory=_get_verbosity default_factory=_get_verbosity
) # Whether to print the response text ) # Whether to print the response text
@ -54,11 +54,15 @@ class Chain(BaseModel, ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def _get_callback_manager(self) -> BaseCallbackManager: @validator("callback_manager", pre=True, always=True)
"""Get the callback manager.""" def set_callback_manager(
if self.callback_manager is not None: cls, callback_manager: Optional[BaseCallbackManager]
return self.callback_manager ) -> BaseCallbackManager:
return get_callback_manager() """If callback manager is None, set it.
This allows users to pass in None as context manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@property @property
@abstractmethod @abstractmethod
@ -120,12 +124,12 @@ class Chain(BaseModel, ABC):
inputs = dict(inputs, **external_context) inputs = dict(inputs, **external_context)
self._validate_inputs(inputs) self._validate_inputs(inputs)
if self.verbose: if self.verbose:
self._get_callback_manager().on_chain_start( self.callback_manager.on_chain_start(
{"name": self.__class__.__name__}, inputs {"name": self.__class__.__name__}, inputs
) )
outputs = self._call(inputs) outputs = self._call(inputs)
if self.verbose: if self.verbose:
self._get_callback_manager().on_chain_end(outputs) self.callback_manager.on_chain_end(outputs)
self._validate_outputs(outputs) self._validate_outputs(outputs)
if self.memory is not None: if self.memory is not None:
self.memory.save_context(inputs, outputs) self.memory.save_context(inputs, outputs)

View File

@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union from typing import Any, Dict, List, Mapping, Optional, Union
import yaml import yaml
from pydantic import BaseModel, Extra, Field from pydantic import BaseModel, Extra, Field, validator
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
@ -23,7 +23,7 @@ class BaseLLM(BaseModel, ABC):
cache: Optional[bool] = None cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity) verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text.""" """Whether to print out response text."""
callback_manager: Optional[BaseCallbackManager] = None callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -31,18 +31,22 @@ class BaseLLM(BaseModel, ABC):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as context manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@abstractmethod @abstractmethod
def _generate( def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
def _get_callback_manager(self) -> BaseCallbackManager:
"""Get the callback manager."""
if self.callback_manager is not None:
return self.callback_manager
return get_callback_manager()
def generate( def generate(
self, prompts: List[str], stop: Optional[List[str]] = None self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
@ -55,12 +59,12 @@ class BaseLLM(BaseModel, ABC):
"Asked to cache, but no cache found at `langchain.cache`." "Asked to cache, but no cache found at `langchain.cache`."
) )
if self.verbose: if self.verbose:
self._get_callback_manager().on_llm_start( self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts {"name": self.__class__.__name__}, prompts
) )
output = self._generate(prompts, stop=stop) output = self._generate(prompts, stop=stop)
if self.verbose: if self.verbose:
self._get_callback_manager().on_llm_end(output) self.callback_manager.on_llm_end(output)
return output return output
params = self._llm_dict() params = self._llm_dict()
params["stop"] = stop params["stop"] = stop
@ -75,11 +79,11 @@ class BaseLLM(BaseModel, ABC):
else: else:
missing_prompts.append(prompt) missing_prompts.append(prompt)
missing_prompt_idxs.append(i) missing_prompt_idxs.append(i)
self._get_callback_manager().on_llm_start( self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts {"name": self.__class__.__name__}, missing_prompts
) )
new_results = self._generate(missing_prompts, stop=stop) new_results = self._generate(missing_prompts, stop=stop)
self._get_callback_manager().on_llm_end(new_results) self.callback_manager.on_llm_end(new_results)
for i, result in enumerate(new_results.generations): for i, result in enumerate(new_results.generations):
existing_prompts[i] = result existing_prompts[i] = result
prompt = prompts[i] prompt = prompts[i]