Ankush Gola
2023-03-09 20:41:50 -08:00
parent ecdfbfe1c7
commit 7d465cbc2f
13 changed files with 151 additions and 109 deletions

langchain/agents/chat/base.py

@@ -11,7 +11,8 @@ from langchain.prompts.chat import (
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
-from langchain.schema import AgentAction, BaseLanguageModel
+from langchain.schema import AgentAction
+from langchain.base_language_model import BaseLanguageModel
from langchain.tools import BaseTool

FINAL_ANSWER_ACTION = "Final Answer:"

langchain/base_language_model.py

@@ -0,0 +1,120 @@
"""Base class for language models."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import BaseModel, Field, validator, Extra
from langchain.callbacks import BaseCallbackManager, get_callback_manager
from langchain.schema import PromptValue, LLMResult
import langchain
def _get_verbosity() -> bool:
return langchain.verbose
class BaseLanguageModel(BaseModel, ABC):
"""Base class for language models."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
def generate_prompt(self, prompts: List[PromptValue], stop: Optional[List[str]] = None) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = self._generate_prompt(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
async def agenerate_prompt(self, prompts: List[PromptValue], stop: Optional[List[str]] = None) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate_prompt(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
@abstractmethod
def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please it install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
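Since the wrappers above centralize the callback bookkeeping, a concrete model only has to implement the two private hooks. Below is a minimal sketch of such a subclass; `EchoLanguageModel` and its canned behavior are hypothetical, not part of this commit.

# Hypothetical sketch, not part of this commit: a minimal BaseLanguageModel
# subclass. The inherited generate_prompt/agenerate_prompt wrap these two
# hooks with the callback handling shown above.
from typing import List, Optional

from langchain.base_language_model import BaseLanguageModel
from langchain.schema import Generation, LLMResult, PromptValue


class EchoLanguageModel(BaseLanguageModel):
    """Toy model that echoes each prompt back as its completion."""

    def _generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return LLMResult(
            generations=[[Generation(text=p.to_string())] for p in prompts]
        )

    async def _agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        # No real I/O in this toy, so reuse the sync path.
        return self._generate_prompt(prompts, stop=stop)

Calling `generate_prompt` on this stub fires `on_llm_start_prompt_value` and `on_llm_end` on the configured callback manager, which is exactly the logic the base class now owns.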

langchain/callbacks/base.py

@@ -30,7 +30,7 @@ class BaseCallbackHandler(ABC):
        """Whether to ignore agent callbacks."""
        return False

-    def on_llm_start_prompt_value(self, serialized: Dict[str, Any], prompt: PromptValue, **kwargs: Any) -> Any:
+    def on_llm_start_prompt_value(self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any) -> Any:
        """Run when LLM starts running."""
        pass

@@ -132,13 +132,13 @@ class CallbackManager(BaseCallbackManager):
        self.handlers: List[BaseCallbackHandler] = handlers

    def on_llm_start_prompt_value(
-        self, serialized: Dict[str, Any], prompt: PromptValue, verbose: bool = False, **kwargs: Any
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], verbose: bool = False, **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
-                    handler.on_llm_start_prompt_value(serialized, prompt, **kwargs)
+                    handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)

    def on_llm_start(
        self,

@@ -290,7 +290,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_start_prompt_value(
-        self, serialized: Dict[str, Any], prompt: PromptValue, **kwargs: Any
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""

@@ -359,19 +359,19 @@ class AsyncCallbackManager(BaseCallbackManager):
        self.handlers: List[BaseCallbackHandler] = handlers

    async def on_llm_start_prompt_value(
-        self, serialized: Dict[str, Any], prompt: PromptValue, verbose: bool = False, **kwargs: Any
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], verbose: bool = False, **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""
        for handler in self.handlers:
            if not handler.ignore_llm:
                if verbose or handler.always_verbose:
                    if asyncio.iscoroutinefunction(handler.on_llm_start_prompt_value):
-                        return await handler.on_llm_start_prompt_value(serialized, prompt, **kwargs)
+                        return await handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)
                    else:
                        return await asyncio.get_event_loop().run_in_executor(
                            None,
                            functools.partial(
-                                handler.on_llm_start_prompt_value, serialized, prompt, **kwargs
+                                handler.on_llm_start_prompt_value, serialized, prompts, **kwargs
                            ),
                        )
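The async manager awaits coroutine hooks directly and pushes plain-function hooks onto the default executor. The standalone sketch below illustrates that dispatch pattern; the `dispatch` helper and hook names are illustrative, not this module's API.

# Illustrative sketch of the dispatch pattern used by AsyncCallbackManager:
# await coroutine hooks directly, run sync hooks in the default executor.
import asyncio
import functools


def sync_hook(serialized: dict, prompts: list) -> None:
    print(f"sync hook saw {len(prompts)} prompt value(s)")


async def async_hook(serialized: dict, prompts: list) -> None:
    print(f"async hook saw {len(prompts)} prompt value(s)")


async def dispatch(hook, serialized: dict, prompts: list) -> None:
    if asyncio.iscoroutinefunction(hook):
        await hook(serialized, prompts)
    else:
        await asyncio.get_event_loop().run_in_executor(
            None, functools.partial(hook, serialized, prompts)
        )


asyncio.run(dispatch(sync_hook, {"name": "demo"}, ["a", "b"]))
asyncio.run(dispatch(async_hook, {"name": "demo"}, ["a", "b"]))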

langchain/chains/chat_vector_db/base.py

@@ -12,7 +12,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language_model import BaseLanguageModel
from langchain.vectorstores.base import VectorStore

langchain/chains/llm.py

@@ -9,7 +9,8 @@ from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
+from langchain.schema import LLMResult, PromptValue
+from langchain.base_language_model import BaseLanguageModel
class LLMChain(Chain, BaseModel):
@@ -59,12 +60,12 @@ class LLMChain(Chain, BaseModel):
    def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list)
-        return self.llm.generate_prompt(prompts, stop)
+        return self.llm._generate_prompt(prompts, stop)

    async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = await self.aprep_prompts(input_list)
-        return await self.llm.agenerate_prompt(prompts, stop)
+        return await self.llm._agenerate_prompt(prompts, stop)

    def prep_prompts(
        self, input_list: List[Dict[str, Any]]
langchain/chains/prompt_selector.py

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language_model import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):

langchain/chains/question_answering/__init__.py

@@ -15,7 +15,7 @@ from langchain.chains.question_answering import (
stuff_prompt,
)
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
+from langchain.base_language_model import BaseLanguageModel
class LoadingCallable(Protocol):

langchain/chat_models/base.py

@@ -8,13 +8,13 @@ from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import (
    AIMessage,
-    BaseLanguageModel,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    LLMResult,
    PromptValue,
)
+from langchain.base_language_model import BaseLanguageModel
def _get_verbosity() -> bool:
@@ -22,26 +22,7 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
-    verbose: bool = Field(default_factory=_get_verbosity)
-    """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
+    """Base class for chat models."""
    def generate(
        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
@@ -56,13 +37,13 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
        results = [await self._agenerate(m, stop=stop) for m in messages]
        return LLMResult(generations=[res.generations for res in results])

-    def generate_prompt(
+    def _generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return self.generate(prompt_messages, stop=stop)

-    async def agenerate_prompt(
+    async def _agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
langchain/llms/base.py

@@ -5,16 +5,11 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import yaml
-from pydantic import BaseModel, Extra, Field, validator
+from pydantic import BaseModel, Extra

import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
-from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
-
-
-def _get_verbosity() -> bool:
-    return langchain.verbose
+from langchain.schema import Generation, LLMResult, PromptValue
+from langchain.base_language_model import BaseLanguageModel
def get_prompts(
@@ -57,9 +52,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
@@ -67,27 +59,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
        extra = Extra.forbid
        arbitrary_types_allowed = True

-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
-
-    @validator("verbose", pre=True, always=True)
-    def set_verbose(cls, verbose: Optional[bool]) -> bool:
-        """If verbose is None, set it.
-
-        This allows users to pass in None as verbose to access the global setting.
-        """
-        if verbose is None:
-            return _get_verbosity()
-        else:
-            return verbose
-
    @abstractmethod
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
@@ -100,13 +71,13 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

-    def generate_prompt(
+    def _generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
        return self.generate(prompt_strings, stop=stop)

-    async def agenerate_prompt(
+    async def _agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
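Note the symmetry with the chat-model variant: `BaseLLM._generate_prompt` flattens each value with `to_string()`, while `BaseChatModel._generate_prompt` expands it with `to_messages()`. The sketch below is a minimal hypothetical `PromptValue` supporting both renderings; langchain ships its own prompt-value classes, so this is purely for illustration.

# Illustrative PromptValue: one prompt object, two renderings, so the same
# value can feed either a string-completion LLM or a chat model.
from typing import List

from langchain.schema import BaseMessage, HumanMessage, PromptValue


class TextPromptValue(PromptValue):
    text: str

    def to_string(self) -> str:
        return self.text

    def to_messages(self) -> List[BaseMessage]:
        return [HumanMessage(content=self.text)]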

langchain/memory/entity.py

@@ -10,7 +10,8 @@ from langchain.memory.prompt import (
)
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage
+from langchain.schema import BaseMessage
+from langchain.base_language_model import BaseLanguageModel
class ConversationEntityMemory(BaseChatMemory, BaseModel):

langchain/memory/kg.py

@@ -12,7 +12,8 @@ from langchain.memory.prompt import (
)
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, SystemMessage
+from langchain.schema import SystemMessage
+from langchain.base_language_model import BaseLanguageModel
class ConversationKGMemory(BaseChatMemory, BaseModel):

langchain/memory/summary.py

@@ -7,7 +7,8 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
+from langchain.schema import BaseMessage, SystemMessage
+from langchain.base_language_model import BaseLanguageModel
class SummarizerMixin(BaseModel):

langchain/schema.py

@@ -100,41 +100,6 @@ class PromptValue(BaseModel, ABC):
"""Return prompt as messages."""
class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please it install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
class BaseMemory(BaseModel, ABC):
"""Base interface for memory in chains."""