mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Harrison/improve cache (#368)
make it so everything goes through generate, which removes the need for two types of caches
This commit is contained in:
parent
8d0869c6d3
commit
3474f39e21
@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
@ -87,7 +87,9 @@ class Agent(Chain, BaseModel, ABC):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
|
||||
def from_llm_and_tools(
|
||||
cls, llm: BaseLLM, tools: List[Tool], **kwargs: Any
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
||||
|
@ -6,7 +6,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
AGENT_TO_CLASS = {
|
||||
"zero-shot-react-description": ZeroShotAgent,
|
||||
@ -17,7 +17,7 @@ AGENT_TO_CLASS = {
|
||||
|
||||
def initialize_agent(
|
||||
tools: List[Tool],
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
agent: str = "zero-shot-react-description",
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer: "
|
||||
@ -116,7 +116,9 @@ class MRKLChain(ZeroShotAgent):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
|
||||
def from_chains(
|
||||
cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any
|
||||
) -> Agent:
|
||||
"""User friendly way to initialize the MRKL chain.
|
||||
|
||||
This is intended to be an easy way to get up and running with the
|
||||
|
@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
@ -123,7 +123,7 @@ class ReActChain(ReActDocstoreAgent):
|
||||
react = ReAct(llm=OpenAI())
|
||||
"""
|
||||
|
||||
def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
|
||||
def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any):
|
||||
"""Initialize with the LLM and a docstore."""
|
||||
docstore_explorer = DocstoreExplorer(docstore)
|
||||
tools = [
|
||||
|
@ -5,7 +5,7 @@ from langchain.agents.agent import Agent
|
||||
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.serpapi import SerpAPIWrapper
|
||||
|
||||
@ -72,7 +72,7 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
|
||||
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
|
||||
"""
|
||||
|
||||
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
|
||||
def __init__(self, llm: BaseLLM, search_chain: SerpAPIWrapper, **kwargs: Any):
|
||||
"""Initialize with just an LLM and a search chain."""
|
||||
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
|
||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Beta Feature: base interface for cache."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||
from sqlalchemy.engine.base import Engine
|
||||
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from langchain.schema import Generation
|
||||
|
||||
RETURN_VAL_TYPE = Union[List[Generation], str]
|
||||
RETURN_VAL_TYPE = List[Generation]
|
||||
|
||||
|
||||
class BaseCache(ABC):
|
||||
@ -43,15 +43,6 @@ class InMemoryCache(BaseCache):
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class LLMCache(Base): # type: ignore
|
||||
"""SQLite table for simple LLM cache (string only)."""
|
||||
|
||||
__tablename__ = "llm_cache"
|
||||
prompt = Column(String, primary_key=True)
|
||||
llm = Column(String, primary_key=True)
|
||||
response = Column(String)
|
||||
|
||||
|
||||
class FullLLMCache(Base): # type: ignore
|
||||
"""SQLite table for full LLM Cache (all generations)."""
|
||||
|
||||
@ -84,29 +75,16 @@ class SQLAlchemyCache(BaseCache):
|
||||
generations.append(Generation(text=row[0]))
|
||||
if len(generations) > 0:
|
||||
return generations
|
||||
stmt = (
|
||||
select(LLMCache.response)
|
||||
.where(LLMCache.prompt == prompt)
|
||||
.where(LLMCache.llm == llm_string)
|
||||
)
|
||||
with Session(self.engine) as session:
|
||||
for row in session.execute(stmt):
|
||||
return row[0]
|
||||
return None
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
if isinstance(return_val, str):
|
||||
item = LLMCache(prompt=prompt, llm=llm_string, response=return_val)
|
||||
for i, generation in enumerate(return_val):
|
||||
item = FullLLMCache(
|
||||
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
||||
)
|
||||
with Session(self.engine) as session, session.begin():
|
||||
session.add(item)
|
||||
else:
|
||||
for i, generation in enumerate(return_val):
|
||||
item = FullLLMCache(
|
||||
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
||||
)
|
||||
with Session(self.engine) as session, session.begin():
|
||||
session.add(item)
|
||||
|
||||
|
||||
class SQLiteCache(SQLAlchemyCache):
|
||||
|
@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.requests import RequestsWrapper
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ class APIChain(Chain, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_api_docs(
|
||||
cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
|
||||
cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
|
||||
) -> APIChain:
|
||||
"""Load chain from just an LLM and the api docs."""
|
||||
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
|
||||
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, root_validator
|
||||
from langchain.chains.base import Memory
|
||||
from langchain.chains.conversation.prompt import SUMMARY_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
@ -88,7 +88,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
|
||||
"""Conversation summarizer to memory."""
|
||||
|
||||
buffer: str = ""
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
|
||||
|
||||
import langchain
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ class LLMChain(Chain, BaseModel):
|
||||
|
||||
prompt: BasePromptTemplate
|
||||
"""Prompt object to use."""
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
output_key: str = "text" #: :meta private:
|
||||
|
||||
|
@ -7,7 +7,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
|
||||
llm_bash = LLMBashChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
@ -14,7 +14,7 @@ from langchain.chains.llm_checker.prompt import (
|
||||
REVISED_ANSWER_PROMPT,
|
||||
)
|
||||
from langchain.chains.sequential import SequentialChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ class LLMCheckerChain(Chain, BaseModel):
|
||||
checker_chain = LLMCheckerChain(llm=llm)
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
|
||||
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
|
||||
|
@ -7,7 +7,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class LLMMathChain(Chain, BaseModel):
|
||||
llm_math = LLMMathChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
@ -15,7 +15,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
@ -32,7 +32,7 @@ class MapReduceChain(Chain, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
|
||||
cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
|
||||
) -> MapReduceChain:
|
||||
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.natbot.prompt import PROMPT
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.openai import OpenAI
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ class NatBotChain(Chain, BaseModel):
|
||||
natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
objective: str
|
||||
"""Objective that NatBot is tasked with completing."""
|
||||
|
@ -13,7 +13,7 @@ from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
@ -21,7 +21,7 @@ from langchain.python import PythonREPL
|
||||
class PALChain(Chain, BaseModel):
|
||||
"""Implements Program-Aided Language Models."""
|
||||
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
prompt: BasePromptTemplate
|
||||
stop: str = "\n\n"
|
||||
get_answer_expr: str = "print(solution())"
|
||||
@ -59,7 +59,7 @@ class PALChain(Chain, BaseModel):
|
||||
return {self.output_key: res.strip()}
|
||||
|
||||
@classmethod
|
||||
def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
|
||||
def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||
"""Load PAL from math prompt."""
|
||||
return cls(
|
||||
llm=llm,
|
||||
@ -70,7 +70,7 @@ class PALChain(Chain, BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
|
||||
def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||
"""Load PAL from colored object prompt."""
|
||||
return cls(
|
||||
llm=llm,
|
||||
|
@ -11,19 +11,19 @@ from langchain.chains.qa_with_sources import (
|
||||
refine_prompts,
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "summaries",
|
||||
**kwargs: Any,
|
||||
@ -38,7 +38,7 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||
@ -72,7 +72,7 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
||||
@ -93,7 +93,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_qa_with_sources_chain(
|
||||
llm: LLM, chain_type: str = "stuff", **kwargs: Any
|
||||
llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering with sources chain.
|
||||
|
||||
|
@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
QUESTION_PROMPT,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
|
||||
question_prompt: BasePromptTemplate = QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
|
||||
|
@ -11,19 +11,19 @@ from langchain.chains.question_answering import (
|
||||
refine_prompts,
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "context",
|
||||
**kwargs: Any,
|
||||
@ -36,7 +36,7 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||
combine_document_variable_name: str = "summaries",
|
||||
@ -67,7 +67,7 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||
document_variable_name: str = "context_str",
|
||||
@ -86,7 +86,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_qa_chain(
|
||||
llm: LLM, chain_type: str = "stuff", **kwargs: Any
|
||||
llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering chain.
|
||||
|
||||
|
@ -7,7 +7,7 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import PROMPT
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
llm: BaseLLM
|
||||
"""LLM wrapper to use."""
|
||||
database: SQLDatabase
|
||||
"""SQL Database to connect to."""
|
||||
|
@ -7,19 +7,19 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
**kwargs: Any,
|
||||
@ -32,7 +32,7 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_document_variable_name: str = "text",
|
||||
@ -63,7 +63,7 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: LLM,
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
@ -82,7 +82,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_summarize_chain(
|
||||
llm: LLM, chain_type: str = "stuff", **kwargs: Any
|
||||
llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load summarizing chain.
|
||||
|
||||
|
@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.vector_db_qa.prompt import PROMPT
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
@ -84,7 +84,7 @@ class VectorDBQA(Chain, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: LLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
|
||||
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
|
||||
) -> VectorDBQA:
|
||||
"""Initialize from LLM."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
|
@ -2,7 +2,7 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
|
||||
|
||||
|
||||
def generate_example(
|
||||
examples: List[dict], llm: LLM, prompt_template: PromptTemplate
|
||||
examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate
|
||||
) -> str:
|
||||
"""Return another example given a list of examples for a prompt."""
|
||||
prompt = FewShotPromptTemplate(
|
||||
|
@ -2,7 +2,7 @@
|
||||
from typing import Dict, Type
|
||||
|
||||
from langchain.llms.ai21 import AI21
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.cohere import Cohere
|
||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
@ -18,7 +18,7 @@ __all__ = [
|
||||
"AI21",
|
||||
]
|
||||
|
||||
type_to_cls_dict: Dict[str, Type[LLM]] = {
|
||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"ai21": AI21,
|
||||
"cohere": Cohere,
|
||||
"huggingface_hub": HuggingFaceHub,
|
||||
|
@ -21,7 +21,7 @@ class LLMResult(NamedTuple):
|
||||
"""For arbitrary LLM provider specific output."""
|
||||
|
||||
|
||||
class LLM(BaseModel, ABC):
|
||||
class BaseLLM(BaseModel, ABC):
|
||||
"""LLM wrapper should take in a prompt and return a string."""
|
||||
|
||||
class Config:
|
||||
@ -29,16 +29,11 @@ class LLM(BaseModel, ABC):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# TODO: add caching here.
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
text = self(prompt, stop=stop)
|
||||
generations.append([Generation(text=text)])
|
||||
return LLMResult(generations=generations)
|
||||
"""Run the LLM on the given prompts."""
|
||||
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
@ -88,28 +83,9 @@ class LLM(BaseModel, ABC):
|
||||
# calculate the number of tokens in the tokenized text
|
||||
return len(tokenized_text)
|
||||
|
||||
@abstractmethod
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
|
||||
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input."""
|
||||
if langchain.llm_cache is None:
|
||||
return self._call(prompt, stop=stop)
|
||||
params = self._llm_dict()
|
||||
params["stop"] = stop
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
if langchain.cache is not None:
|
||||
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
|
||||
if cache_val is not None:
|
||||
if isinstance(cache_val, str):
|
||||
return cache_val
|
||||
else:
|
||||
return cache_val[0].text
|
||||
return_val = self._call(prompt, stop=stop)
|
||||
if langchain.cache is not None:
|
||||
langchain.llm_cache.update(prompt, llm_string, return_val)
|
||||
return return_val
|
||||
return self.generate([prompt], stop=stop).generations[0][0].text
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
@ -163,3 +139,26 @@ class LLM(BaseModel, ABC):
|
||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class LLM(BaseLLM):
|
||||
"""LLM class that expect subclasses to implement a simpler call method.
|
||||
|
||||
The purpose of this class is to expose a simpler interface for working
|
||||
with LLMs, rather than expect the user to implement the full _generate method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# TODO: add caching here.
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
text = self._call(prompt, stop=stop)
|
||||
generations.append([Generation(text=text)])
|
||||
return LLMResult(generations=generations)
|
||||
|
@ -6,10 +6,10 @@ from typing import Union
|
||||
import yaml
|
||||
|
||||
from langchain.llms import type_to_cls_dict
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
def load_llm_from_config(config: dict) -> LLM:
|
||||
def load_llm_from_config(config: dict) -> BaseLLM:
|
||||
"""Load LLM from Config Dict."""
|
||||
if "_type" not in config:
|
||||
raise ValueError("Must specify an LLM Type in config")
|
||||
@ -22,7 +22,7 @@ def load_llm_from_config(config: dict) -> LLM:
|
||||
return llm_cls(**config)
|
||||
|
||||
|
||||
def load_llm(file: Union[str, Path]) -> LLM:
|
||||
def load_llm(file: Union[str, Path]) -> BaseLLM:
|
||||
"""Load LLM from file."""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file, str):
|
||||
|
@ -4,12 +4,12 @@ from typing import Any, Dict, Generator, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.llms.base import LLM, LLMResult
|
||||
from langchain.llms.base import BaseLLM, LLMResult
|
||||
from langchain.schema import Generation
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class OpenAI(LLM, BaseModel):
|
||||
class OpenAI(BaseLLM, BaseModel):
|
||||
"""Wrapper around OpenAI large language models.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
@ -197,23 +197,6 @@ class OpenAI(LLM, BaseModel):
|
||||
"""Return type of llm."""
|
||||
return "openai"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Call out to OpenAI's create endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = openai("Tell me a joke.")
|
||||
"""
|
||||
return self.generate([prompt], stop=stop).generations[0][0].text
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate num tokens with tiktoken package."""
|
||||
# tiktoken NOT supported for Python 3.8 or below
|
||||
|
@ -6,7 +6,7 @@ from typing import List, Optional, Sequence
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ class ModelLaboratory:
|
||||
|
||||
@classmethod
|
||||
def from_llms(
|
||||
cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None
|
||||
cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
|
||||
) -> ModelLaboratory:
|
||||
"""Initialize with LLMs to experiment with and optional prompt.
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
"""Utils for LLM Tests."""
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
|
||||
def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
|
||||
"""Assert LLM Equality for tests."""
|
||||
# Check that they are the same type.
|
||||
assert type(llm) == type(loaded_llm)
|
||||
|
Loading…
Reference in New Issue
Block a user