mirror of https://github.com/hwchase17/langchain.git (synced 2026-04-24 04:36:46 +00:00)
refactor
@@ -11,7 +11,8 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import AgentAction, BaseLanguageModel
+from langchain.schema import AgentAction
+from langchain.base_language_model import BaseLanguageModel
 from langchain.tools import BaseTool

 FINAL_ANSWER_ACTION = "Final Answer:"
langchain/base_language_model.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+"""Base class for language models."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+from pydantic import BaseModel, Field, validator, Extra
+
+from langchain.callbacks import BaseCallbackManager, get_callback_manager
+from langchain.schema import PromptValue, LLMResult
+
+import langchain
+
+
+def _get_verbosity() -> bool:
+    return langchain.verbose
+
+
+class BaseLanguageModel(BaseModel, ABC):
+    """Base class for language models."""
+    verbose: bool = Field(default_factory=_get_verbosity)
+    """Whether to print out response text."""
+    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @validator("callback_manager", pre=True, always=True)
+    def set_callback_manager(
+        cls, callback_manager: Optional[BaseCallbackManager]
+    ) -> BaseCallbackManager:
+        """If callback manager is None, set it.
+
+        This allows users to pass in None as callback manager, which is a nice UX.
+        """
+        return callback_manager or get_callback_manager()
+
+    @validator("verbose", pre=True, always=True)
+    def set_verbose(cls, verbose: Optional[bool]) -> bool:
+        """If verbose is None, set it.
+
+        This allows users to pass in None as verbose to access the global setting.
+        """
+        if verbose is None:
+            return _get_verbosity()
+        else:
+            return verbose
+
+    def generate_prompt(self, prompts: List[PromptValue], stop: Optional[List[str]] = None) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+        self.callback_manager.on_llm_start_prompt_value(
+            {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+        )
+        try:
+            output = self._generate_prompt(prompts, stop=stop)
+        except (KeyboardInterrupt, Exception) as e:
+            self.callback_manager.on_llm_error(e, verbose=self.verbose)
+            raise e
+        self.callback_manager.on_llm_end(output, verbose=self.verbose)
+        return output
+
+    async def agenerate_prompt(self, prompts: List[PromptValue], stop: Optional[List[str]] = None) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_llm_start_prompt_value(
+                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+            )
+        else:
+            self.callback_manager.on_llm_start_prompt_value(
+                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+            )
+        try:
+            output = await self._agenerate_prompt(prompts, stop=stop)
+        except (KeyboardInterrupt, Exception) as e:
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_llm_error(e, verbose=self.verbose)
+            else:
+                self.callback_manager.on_llm_error(e, verbose=self.verbose)
+            raise e
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_llm_end(output, verbose=self.verbose)
+        else:
+            self.callback_manager.on_llm_end(output, verbose=self.verbose)
+        return output
+
+    @abstractmethod
+    def _generate_prompt(
+        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+
+    @abstractmethod
+    async def _agenerate_prompt(
+        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text."""
+        # TODO: this method may not be exact.
+        # TODO: this method may differ based on model (eg codex).
+        try:
+            from transformers import GPT2TokenizerFast
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install transformers`."
+            )
+        # create a GPT-2 tokenizer instance
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

+        # tokenize the text using the GPT-2 tokenizer
+        tokenized_text = tokenizer.tokenize(text)

+        # calculate the number of tokens in the tokenized text
+        return len(tokenized_text)
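Aside, not part of the commit: with this template-method layout, the public generate_prompt/agenerate_prompt wrap callback dispatch and error reporting around private hooks, so a concrete model only overrides _generate_prompt/_agenerate_prompt. A minimal sketch of a subclass against this commit's code; the FakeLanguageModel name and its echo behavior are invented for illustration:

    from typing import List, Optional

    from langchain.base_language_model import BaseLanguageModel
    from langchain.schema import Generation, LLMResult, PromptValue


    class FakeLanguageModel(BaseLanguageModel):
        """Hypothetical model that echoes each prompt back."""

        def _generate_prompt(
            self, prompts: List[PromptValue], stop: Optional[List[str]] = None
        ) -> LLMResult:
            # generate_prompt in the base class has already fired the start
            # callback by the time this hook runs; it only builds the result.
            return LLMResult(
                generations=[[Generation(text=p.to_string())] for p in prompts]
            )

        async def _agenerate_prompt(
            self, prompts: List[PromptValue], stop: Optional[List[str]] = None
        ) -> LLMResult:
            return self._generate_prompt(prompts, stop=stop)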
@@ -30,7 +30,7 @@ class BaseCallbackHandler(ABC):
         """Whether to ignore agent callbacks."""
         return False

-    def on_llm_start_prompt_value(self, serialized: Dict[str, Any], prompt: PromptValue, **kwargs: Any) -> Any:
+    def on_llm_start_prompt_value(self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any) -> Any:
         """Run when LLM starts running."""
         pass

@@ -132,13 +132,13 @@ class CallbackManager(BaseCallbackManager):
         self.handlers: List[BaseCallbackHandler] = handlers

     def on_llm_start_prompt_value(
-        self, serialized: Dict[str, Any], prompt: PromptValue, verbose: bool = False, **kwargs: Any
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], verbose: bool = False, **kwargs: Any
     ) -> Any:
         """Run when LLM starts running."""
         for handler in self.handlers:
             if not handler.ignore_llm:
                 if verbose or handler.always_verbose:
-                    handler.on_llm_start_prompt_value(serialized, prompt, **kwargs)
+                    handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)

     def on_llm_start(
         self,
@@ -290,7 +290,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""

     async def on_llm_start_prompt_value(
-        self, serialized: Dict[str, Any], prompt: PromptValue, **kwargs: Any
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
     ) -> Any:
         """Run when LLM starts running."""

@@ -359,19 +359,19 @@ class AsyncCallbackManager(BaseCallbackManager):
         self.handlers: List[BaseCallbackHandler] = handlers

     async def on_llm_start_prompt_value(
-        self, serialized: Dict[str, Any], prompt: PromptValue, verbose: bool = False, **kwargs: Any
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], verbose: bool = False, **kwargs: Any
     ) -> Any:
         """Run when LLM starts running."""
         for handler in self.handlers:
             if not handler.ignore_llm:
                 if verbose or handler.always_verbose:
                     if asyncio.iscoroutinefunction(handler.on_llm_start_prompt_value):
-                        return await handler.on_llm_start_prompt_value(serialized, prompt, **kwargs)
+                        return await handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)
                     else:
                         return await asyncio.get_event_loop().run_in_executor(
                             None,
                             functools.partial(
-                                handler.on_llm_start_prompt_value, serialized, prompt, **kwargs
+                                handler.on_llm_start_prompt_value, serialized, prompts, **kwargs
                             ),
                         )
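Aside, not part of the commit: AsyncCallbackManager probes each handler with asyncio.iscoroutinefunction and pushes plain-sync handlers onto the default executor, so sync and async handlers can be mixed in one manager. A sketch of a handler that takes the coroutine branch; the class name is invented and the import path assumes AsyncCallbackHandler lives beside BaseCallbackHandler:

    from typing import Any, Dict, List

    from langchain.callbacks.base import AsyncCallbackHandler
    from langchain.schema import PromptValue


    class LoggingAsyncHandler(AsyncCallbackHandler):
        """Hypothetical handler; awaited directly on the event loop."""

        async def on_llm_start_prompt_value(
            self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
        ) -> Any:
            # serialized carries {"name": <model class name>} per the new
            # BaseLanguageModel.generate_prompt above.
            for p in prompts:
                print(f"{serialized['name']} starting on: {p.to_string()}")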
@@ -12,7 +12,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language_model import BaseLanguageModel
 from langchain.vectorstores.base import VectorStore


@@ -9,7 +9,8 @@ from langchain.chains.base import Chain
 from langchain.input import get_colored_text
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
+from langchain.schema import LLMResult, PromptValue
+from langchain.base_language_model import BaseLanguageModel


 class LLMChain(Chain, BaseModel):
@@ -59,12 +60,12 @@ class LLMChain(Chain, BaseModel):
     def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list)
-        return self.llm.generate_prompt(prompts, stop)
+        return self.llm._generate_prompt(prompts, stop)

     async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list)
-        return await self.llm.agenerate_prompt(prompts, stop)
+        return await self.llm._agenerate_prompt(prompts, stop)

     def prep_prompts(
         self, input_list: List[Dict[str, Any]]
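Aside, not part of the commit: from the caller's side the chain API is unchanged; prep_prompts renders one PromptValue per input dict and the batch goes to the model in a single call. A hedged usage sketch, with the model choice and template invented and OpenAI credentials assumed to be configured:

    from langchain.chains.llm import LLMChain
    from langchain.llms.openai import OpenAI
    from langchain.prompts.prompt import PromptTemplate

    prompt = PromptTemplate(
        input_variables=["product"],
        template="Name one company that makes {product}.",
    )
    chain = LLMChain(llm=OpenAI(), prompt=prompt)

    result = chain.generate([{"product": "socks"}, {"product": "skis"}])
    print(result.generations[0][0].text)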
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language_model import BaseLanguageModel


 class BasePromptSelector(BaseModel, ABC):
@@ -15,7 +15,7 @@ from langchain.chains.question_answering import (
     stuff_prompt,
 )
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language_model import BaseLanguageModel


 class LoadingCallable(Protocol):
@@ -8,13 +8,13 @@ from langchain.callbacks import get_callback_manager
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.schema import (
     AIMessage,
-    BaseLanguageModel,
     BaseMessage,
     ChatGeneration,
     ChatResult,
     LLMResult,
     PromptValue,
 )
+from langchain.base_language_model import BaseLanguageModel


 def _get_verbosity() -> bool:
@@ -22,26 +22,7 @@ def _get_verbosity() -> bool:

 class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
-    verbose: bool = Field(default_factory=_get_verbosity)
-    """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
-
+    """Base class for chat models."""

     def generate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
@@ -56,13 +37,13 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         results = [await self._agenerate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])

-    def generate_prompt(
+    def _generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop)

-    async def agenerate_prompt(
+    async def _agenerate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
@@ -5,16 +5,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

 import yaml
-from pydantic import BaseModel, Extra, Field, validator
+from pydantic import BaseModel, Extra

 import langchain
 from langchain.callbacks import get_callback_manager
 from langchain.callbacks.base import BaseCallbackManager
-from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
-
-
-def _get_verbosity() -> bool:
-    return langchain.verbose
+from langchain.schema import Generation, LLMResult, PromptValue
+from langchain.base_language_model import BaseLanguageModel


 def get_prompts(
@@ -57,9 +52,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""

     cache: Optional[bool] = None
-    verbose: bool = Field(default_factory=_get_verbosity)
-    """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)

     class Config:
         """Configuration for this pydantic object."""
@@ -67,27 +59,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True

-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
-
-    @validator("verbose", pre=True, always=True)
-    def set_verbose(cls, verbose: Optional[bool]) -> bool:
-        """If verbose is None, set it.
-
-        This allows users to pass in None as verbose to access the global setting.
-        """
-        if verbose is None:
-            return _get_verbosity()
-        else:
-            return verbose
-
     @abstractmethod
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
@@ -100,13 +71,13 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
     ) -> LLMResult:
         """Run the LLM on the given prompts."""

-    def generate_prompt(
+    def _generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop)

-    async def agenerate_prompt(
+    async def _agenerate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
@@ -10,7 +10,8 @@ from langchain.memory.prompt import (
 )
 from langchain.memory.utils import get_buffer_string, get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage
+from langchain.schema import BaseMessage
+from langchain.base_language_model import BaseLanguageModel


 class ConversationEntityMemory(BaseChatMemory, BaseModel):
@@ -12,7 +12,8 @@ from langchain.memory.prompt import (
 )
 from langchain.memory.utils import get_buffer_string, get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, SystemMessage
+from langchain.schema import SystemMessage
+from langchain.base_language_model import BaseLanguageModel


 class ConversationKGMemory(BaseChatMemory, BaseModel):
@@ -7,7 +7,8 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.memory.utils import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
+from langchain.schema import BaseMessage, SystemMessage
+from langchain.base_language_model import BaseLanguageModel


 class SummarizerMixin(BaseModel):
@@ -100,41 +100,6 @@ class PromptValue(BaseModel, ABC):
         """Return prompt as messages."""


-class BaseLanguageModel(BaseModel, ABC):
-    @abstractmethod
-    def generate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
-    ) -> LLMResult:
-        """Take in a list of prompt values and return an LLMResult."""
-
-    @abstractmethod
-    async def agenerate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
-    ) -> LLMResult:
-        """Take in a list of prompt values and return an LLMResult."""
-
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text."""
-        # TODO: this method may not be exact.
-        # TODO: this method may differ based on model (eg codex).
-        try:
-            from transformers import GPT2TokenizerFast
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please it install it with `pip install transformers`."
-            )
-        # create a GPT-3 tokenizer instance
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-
-        # tokenize the text using the GPT-3 tokenizer
-        tokenized_text = tokenizer.tokenize(text)
-
-        # calculate the number of tokens in the tokenized text
-        return len(tokenized_text)
-
-
 class BaseMemory(BaseModel, ABC):
     """Base interface for memory in chains."""