mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
Harrison/improve cache (#368)
make it so everything goes through generate, which removes the need for two types of caches
This commit is contained in:
parent
8d0869c6d3
commit
3474f39e21
@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import get_color_mapping
|
from langchain.input import get_color_mapping
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.schema import AgentAction
|
from langchain.schema import AgentAction
|
||||||
|
|
||||||
@ -87,7 +87,9 @@ class Agent(Chain, BaseModel, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
|
def from_llm_and_tools(
|
||||||
|
cls, llm: BaseLLM, tools: List[Tool], **kwargs: Any
|
||||||
|
) -> Agent:
|
||||||
"""Construct an agent from an LLM and tools."""
|
"""Construct an agent from an LLM and tools."""
|
||||||
cls._validate_tools(tools)
|
cls._validate_tools(tools)
|
||||||
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
||||||
|
@ -6,7 +6,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
|
|||||||
from langchain.agents.react.base import ReActDocstoreAgent
|
from langchain.agents.react.base import ReActDocstoreAgent
|
||||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
AGENT_TO_CLASS = {
|
AGENT_TO_CLASS = {
|
||||||
"zero-shot-react-description": ZeroShotAgent,
|
"zero-shot-react-description": ZeroShotAgent,
|
||||||
@ -17,7 +17,7 @@ AGENT_TO_CLASS = {
|
|||||||
|
|
||||||
def initialize_agent(
|
def initialize_agent(
|
||||||
tools: List[Tool],
|
tools: List[Tool],
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
agent: str = "zero-shot-react-description",
|
agent: str = "zero-shot-react-description",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Agent:
|
) -> Agent:
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
|||||||
from langchain.agents.agent import Agent
|
from langchain.agents.agent import Agent
|
||||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
FINAL_ANSWER_ACTION = "Final Answer: "
|
FINAL_ANSWER_ACTION = "Final Answer: "
|
||||||
@ -116,7 +116,9 @@ class MRKLChain(ZeroShotAgent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
|
def from_chains(
|
||||||
|
cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any
|
||||||
|
) -> Agent:
|
||||||
"""User friendly way to initialize the MRKL chain.
|
"""User friendly way to initialize the MRKL chain.
|
||||||
|
|
||||||
This is intended to be an easy way to get up and running with the
|
This is intended to be an easy way to get up and running with the
|
||||||
|
@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.base import Docstore
|
from langchain.docstore.base import Docstore
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -123,7 +123,7 @@ class ReActChain(ReActDocstoreAgent):
|
|||||||
react = ReAct(llm=OpenAI())
|
react = ReAct(llm=OpenAI())
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
|
def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any):
|
||||||
"""Initialize with the LLM and a docstore."""
|
"""Initialize with the LLM and a docstore."""
|
||||||
docstore_explorer = DocstoreExplorer(docstore)
|
docstore_explorer = DocstoreExplorer(docstore)
|
||||||
tools = [
|
tools = [
|
||||||
|
@ -5,7 +5,7 @@ from langchain.agents.agent import Agent
|
|||||||
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.serpapi import SerpAPIWrapper
|
from langchain.serpapi import SerpAPIWrapper
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
|
|||||||
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
|
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
|
def __init__(self, llm: BaseLLM, search_chain: SerpAPIWrapper, **kwargs: Any):
|
||||||
"""Initialize with just an LLM and a search chain."""
|
"""Initialize with just an LLM and a search chain."""
|
||||||
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
|
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
|
||||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Beta Feature: base interface for cache."""
|
"""Beta Feature: base interface for cache."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from langchain.schema import Generation
|
from langchain.schema import Generation
|
||||||
|
|
||||||
RETURN_VAL_TYPE = Union[List[Generation], str]
|
RETURN_VAL_TYPE = List[Generation]
|
||||||
|
|
||||||
|
|
||||||
class BaseCache(ABC):
|
class BaseCache(ABC):
|
||||||
@ -43,15 +43,6 @@ class InMemoryCache(BaseCache):
|
|||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class LLMCache(Base): # type: ignore
|
|
||||||
"""SQLite table for simple LLM cache (string only)."""
|
|
||||||
|
|
||||||
__tablename__ = "llm_cache"
|
|
||||||
prompt = Column(String, primary_key=True)
|
|
||||||
llm = Column(String, primary_key=True)
|
|
||||||
response = Column(String)
|
|
||||||
|
|
||||||
|
|
||||||
class FullLLMCache(Base): # type: ignore
|
class FullLLMCache(Base): # type: ignore
|
||||||
"""SQLite table for full LLM Cache (all generations)."""
|
"""SQLite table for full LLM Cache (all generations)."""
|
||||||
|
|
||||||
@ -84,23 +75,10 @@ class SQLAlchemyCache(BaseCache):
|
|||||||
generations.append(Generation(text=row[0]))
|
generations.append(Generation(text=row[0]))
|
||||||
if len(generations) > 0:
|
if len(generations) > 0:
|
||||||
return generations
|
return generations
|
||||||
stmt = (
|
|
||||||
select(LLMCache.response)
|
|
||||||
.where(LLMCache.prompt == prompt)
|
|
||||||
.where(LLMCache.llm == llm_string)
|
|
||||||
)
|
|
||||||
with Session(self.engine) as session:
|
|
||||||
for row in session.execute(stmt):
|
|
||||||
return row[0]
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||||
"""Look up based on prompt and llm_string."""
|
"""Look up based on prompt and llm_string."""
|
||||||
if isinstance(return_val, str):
|
|
||||||
item = LLMCache(prompt=prompt, llm=llm_string, response=return_val)
|
|
||||||
with Session(self.engine) as session, session.begin():
|
|
||||||
session.add(item)
|
|
||||||
else:
|
|
||||||
for i, generation in enumerate(return_val):
|
for i, generation in enumerate(return_val):
|
||||||
item = FullLLMCache(
|
item = FullLLMCache(
|
||||||
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
prompt=prompt, llm=llm_string, response=generation.text, idx=i
|
||||||
|
@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.requests import RequestsWrapper
|
from langchain.requests import RequestsWrapper
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ class APIChain(Chain, BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm_and_api_docs(
|
def from_llm_and_api_docs(
|
||||||
cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
|
cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
|
||||||
) -> APIChain:
|
) -> APIChain:
|
||||||
"""Load chain from just an LLM and the api docs."""
|
"""Load chain from just an LLM and the api docs."""
|
||||||
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
|
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
|
||||||
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, root_validator
|
|||||||
from langchain.chains.base import Memory
|
from langchain.chains.base import Memory
|
||||||
from langchain.chains.conversation.prompt import SUMMARY_PROMPT
|
from langchain.chains.conversation.prompt import SUMMARY_PROMPT
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -88,7 +88,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
|
|||||||
"""Conversation summarizer to memory."""
|
"""Conversation summarizer to memory."""
|
||||||
|
|
||||||
buffer: str = ""
|
buffer: str = ""
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||||
memory_key: str = "history" #: :meta private:
|
memory_key: str = "history" #: :meta private:
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
|
|||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ class LLMChain(Chain, BaseModel):
|
|||||||
|
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
"""Prompt object to use."""
|
"""Prompt object to use."""
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
output_key: str = "text" #: :meta private:
|
output_key: str = "text" #: :meta private:
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from langchain.chains.base import Chain
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_bash.prompt import PROMPT
|
from langchain.chains.llm_bash.prompt import PROMPT
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.utilities.bash import BashProcess
|
from langchain.utilities.bash import BashProcess
|
||||||
|
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
|
|||||||
llm_bash = LLMBashChain(llm=OpenAI())
|
llm_bash = LLMBashChain(llm=OpenAI())
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
input_key: str = "question" #: :meta private:
|
input_key: str = "question" #: :meta private:
|
||||||
output_key: str = "answer" #: :meta private:
|
output_key: str = "answer" #: :meta private:
|
||||||
|
@ -14,7 +14,7 @@ from langchain.chains.llm_checker.prompt import (
|
|||||||
REVISED_ANSWER_PROMPT,
|
REVISED_ANSWER_PROMPT,
|
||||||
)
|
)
|
||||||
from langchain.chains.sequential import SequentialChain
|
from langchain.chains.sequential import SequentialChain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class LLMCheckerChain(Chain, BaseModel):
|
|||||||
checker_chain = LLMCheckerChain(llm=llm)
|
checker_chain = LLMCheckerChain(llm=llm)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
|
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
|
||||||
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
|
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
|
||||||
|
@ -7,7 +7,7 @@ from langchain.chains.base import Chain
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_math.prompt import PROMPT
|
from langchain.chains.llm_math.prompt import PROMPT
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.python import PythonREPL
|
from langchain.python import PythonREPL
|
||||||
|
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class LLMMathChain(Chain, BaseModel):
|
|||||||
llm_math = LLMMathChain(llm=OpenAI())
|
llm_math = LLMMathChain(llm=OpenAI())
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
input_key: str = "question" #: :meta private:
|
input_key: str = "question" #: :meta private:
|
||||||
output_key: str = "answer" #: :meta private:
|
output_key: str = "answer" #: :meta private:
|
||||||
|
@ -15,7 +15,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
|
|||||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.text_splitter import TextSplitter
|
from langchain.text_splitter import TextSplitter
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ class MapReduceChain(Chain, BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_params(
|
def from_params(
|
||||||
cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
|
cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
|
||||||
) -> MapReduceChain:
|
) -> MapReduceChain:
|
||||||
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, Extra
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.natbot.prompt import PROMPT
|
from langchain.chains.natbot.prompt import PROMPT
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.llms.openai import OpenAI
|
from langchain.llms.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ class NatBotChain(Chain, BaseModel):
|
|||||||
natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
|
natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
objective: str
|
objective: str
|
||||||
"""Objective that NatBot is tasked with completing."""
|
"""Objective that NatBot is tasked with completing."""
|
||||||
|
@ -13,7 +13,7 @@ from langchain.chains.llm import LLMChain
|
|||||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.python import PythonREPL
|
from langchain.python import PythonREPL
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ from langchain.python import PythonREPL
|
|||||||
class PALChain(Chain, BaseModel):
|
class PALChain(Chain, BaseModel):
|
||||||
"""Implements Program-Aided Language Models."""
|
"""Implements Program-Aided Language Models."""
|
||||||
|
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
prompt: BasePromptTemplate
|
prompt: BasePromptTemplate
|
||||||
stop: str = "\n\n"
|
stop: str = "\n\n"
|
||||||
get_answer_expr: str = "print(solution())"
|
get_answer_expr: str = "print(solution())"
|
||||||
@ -59,7 +59,7 @@ class PALChain(Chain, BaseModel):
|
|||||||
return {self.output_key: res.strip()}
|
return {self.output_key: res.strip()}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
|
def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||||
"""Load PAL from math prompt."""
|
"""Load PAL from math prompt."""
|
||||||
return cls(
|
return cls(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
@ -70,7 +70,7 @@ class PALChain(Chain, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
|
def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||||
"""Load PAL from colored object prompt."""
|
"""Load PAL from colored object prompt."""
|
||||||
return cls(
|
return cls(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
@ -11,19 +11,19 @@ from langchain.chains.qa_with_sources import (
|
|||||||
refine_prompts,
|
refine_prompts,
|
||||||
stuff_prompt,
|
stuff_prompt,
|
||||||
)
|
)
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
class LoadingCallable(Protocol):
|
class LoadingCallable(Protocol):
|
||||||
"""Interface for loading the combine documents chain."""
|
"""Interface for loading the combine documents chain."""
|
||||||
|
|
||||||
def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||||
"""Callable to load the combine documents chain."""
|
"""Callable to load the combine documents chain."""
|
||||||
|
|
||||||
|
|
||||||
def _load_stuff_chain(
|
def _load_stuff_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||||
document_variable_name: str = "summaries",
|
document_variable_name: str = "summaries",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -38,7 +38,7 @@ def _load_stuff_chain(
|
|||||||
|
|
||||||
|
|
||||||
def _load_map_reduce_chain(
|
def _load_map_reduce_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||||
@ -72,7 +72,7 @@ def _load_map_reduce_chain(
|
|||||||
|
|
||||||
|
|
||||||
def _load_refine_chain(
|
def _load_refine_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||||
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
||||||
@ -93,7 +93,7 @@ def _load_refine_chain(
|
|||||||
|
|
||||||
|
|
||||||
def load_qa_with_sources_chain(
|
def load_qa_with_sources_chain(
|
||||||
llm: LLM, chain_type: str = "stuff", **kwargs: Any
|
llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
|
||||||
) -> BaseCombineDocumentsChain:
|
) -> BaseCombineDocumentsChain:
|
||||||
"""Load question answering with sources chain.
|
"""Load question answering with sources chain.
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
|||||||
QUESTION_PROMPT,
|
QUESTION_PROMPT,
|
||||||
)
|
)
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
|
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
|
||||||
question_prompt: BasePromptTemplate = QUESTION_PROMPT,
|
question_prompt: BasePromptTemplate = QUESTION_PROMPT,
|
||||||
combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
|
combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
|
||||||
|
@ -11,19 +11,19 @@ from langchain.chains.question_answering import (
|
|||||||
refine_prompts,
|
refine_prompts,
|
||||||
stuff_prompt,
|
stuff_prompt,
|
||||||
)
|
)
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
class LoadingCallable(Protocol):
|
class LoadingCallable(Protocol):
|
||||||
"""Interface for loading the combine documents chain."""
|
"""Interface for loading the combine documents chain."""
|
||||||
|
|
||||||
def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||||
"""Callable to load the combine documents chain."""
|
"""Callable to load the combine documents chain."""
|
||||||
|
|
||||||
|
|
||||||
def _load_stuff_chain(
|
def _load_stuff_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||||
document_variable_name: str = "context",
|
document_variable_name: str = "context",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -36,7 +36,7 @@ def _load_stuff_chain(
|
|||||||
|
|
||||||
|
|
||||||
def _load_map_reduce_chain(
|
def _load_map_reduce_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||||
combine_document_variable_name: str = "summaries",
|
combine_document_variable_name: str = "summaries",
|
||||||
@ -67,7 +67,7 @@ def _load_map_reduce_chain(
|
|||||||
|
|
||||||
|
|
||||||
def _load_refine_chain(
|
def _load_refine_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||||
document_variable_name: str = "context_str",
|
document_variable_name: str = "context_str",
|
||||||
@ -86,7 +86,7 @@ def _load_refine_chain(
|
|||||||
|
|
||||||
|
|
||||||
def load_qa_chain(
|
def load_qa_chain(
|
||||||
llm: LLM, chain_type: str = "stuff", **kwargs: Any
|
llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
|
||||||
) -> BaseCombineDocumentsChain:
|
) -> BaseCombineDocumentsChain:
|
||||||
"""Load question answering chain.
|
"""Load question answering chain.
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from langchain.chains.base import Chain
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.sql_database.prompt import PROMPT
|
from langchain.chains.sql_database.prompt import PROMPT
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
|||||||
db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
|
db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
llm: LLM
|
llm: BaseLLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
database: SQLDatabase
|
database: SQLDatabase
|
||||||
"""SQL Database to connect to."""
|
"""SQL Database to connect to."""
|
||||||
|
@ -7,19 +7,19 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
|||||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
class LoadingCallable(Protocol):
|
class LoadingCallable(Protocol):
|
||||||
"""Interface for loading the combine documents chain."""
|
"""Interface for loading the combine documents chain."""
|
||||||
|
|
||||||
def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||||
"""Callable to load the combine documents chain."""
|
"""Callable to load the combine documents chain."""
|
||||||
|
|
||||||
|
|
||||||
def _load_stuff_chain(
|
def _load_stuff_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||||
document_variable_name: str = "text",
|
document_variable_name: str = "text",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -32,7 +32,7 @@ def _load_stuff_chain(
|
|||||||
|
|
||||||
|
|
||||||
def _load_map_reduce_chain(
|
def _load_map_reduce_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||||
combine_document_variable_name: str = "text",
|
combine_document_variable_name: str = "text",
|
||||||
@ -63,7 +63,7 @@ def _load_map_reduce_chain(
|
|||||||
|
|
||||||
|
|
||||||
def _load_refine_chain(
|
def _load_refine_chain(
|
||||||
llm: LLM,
|
llm: BaseLLM,
|
||||||
question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
|
question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
|
||||||
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
|
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
|
||||||
document_variable_name: str = "text",
|
document_variable_name: str = "text",
|
||||||
@ -82,7 +82,7 @@ def _load_refine_chain(
|
|||||||
|
|
||||||
|
|
||||||
def load_summarize_chain(
|
def load_summarize_chain(
|
||||||
llm: LLM, chain_type: str = "stuff", **kwargs: Any
|
llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
|
||||||
) -> BaseCombineDocumentsChain:
|
) -> BaseCombineDocumentsChain:
|
||||||
"""Load summarizing chain.
|
"""Load summarizing chain.
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|||||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.vector_db_qa.prompt import PROMPT
|
from langchain.chains.vector_db_qa.prompt import PROMPT
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls, llm: LLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
|
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
|
||||||
) -> VectorDBQA:
|
) -> VectorDBQA:
|
||||||
"""Initialize from LLM."""
|
"""Initialize from LLM."""
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
|
|||||||
|
|
||||||
|
|
||||||
def generate_example(
|
def generate_example(
|
||||||
examples: List[dict], llm: LLM, prompt_template: PromptTemplate
|
examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return another example given a list of examples for a prompt."""
|
"""Return another example given a list of examples for a prompt."""
|
||||||
prompt = FewShotPromptTemplate(
|
prompt = FewShotPromptTemplate(
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from typing import Dict, Type
|
from typing import Dict, Type
|
||||||
|
|
||||||
from langchain.llms.ai21 import AI21
|
from langchain.llms.ai21 import AI21
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.llms.cohere import Cohere
|
from langchain.llms.cohere import Cohere
|
||||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||||
@ -18,7 +18,7 @@ __all__ = [
|
|||||||
"AI21",
|
"AI21",
|
||||||
]
|
]
|
||||||
|
|
||||||
type_to_cls_dict: Dict[str, Type[LLM]] = {
|
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||||
"ai21": AI21,
|
"ai21": AI21,
|
||||||
"cohere": Cohere,
|
"cohere": Cohere,
|
||||||
"huggingface_hub": HuggingFaceHub,
|
"huggingface_hub": HuggingFaceHub,
|
||||||
|
@ -21,7 +21,7 @@ class LLMResult(NamedTuple):
|
|||||||
"""For arbitrary LLM provider specific output."""
|
"""For arbitrary LLM provider specific output."""
|
||||||
|
|
||||||
|
|
||||||
class LLM(BaseModel, ABC):
|
class BaseLLM(BaseModel, ABC):
|
||||||
"""LLM wrapper should take in a prompt and return a string."""
|
"""LLM wrapper should take in a prompt and return a string."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -29,16 +29,11 @@ class LLM(BaseModel, ABC):
|
|||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def _generate(
|
def _generate(
|
||||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Run the LLM on the given prompt and input."""
|
"""Run the LLM on the given prompts."""
|
||||||
# TODO: add caching here.
|
|
||||||
generations = []
|
|
||||||
for prompt in prompts:
|
|
||||||
text = self(prompt, stop=stop)
|
|
||||||
generations.append([Generation(text=text)])
|
|
||||||
return LLMResult(generations=generations)
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||||
@ -88,28 +83,9 @@ class LLM(BaseModel, ABC):
|
|||||||
# calculate the number of tokens in the tokenized text
|
# calculate the number of tokens in the tokenized text
|
||||||
return len(tokenized_text)
|
return len(tokenized_text)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
||||||
"""Run the LLM on the given prompt and input."""
|
|
||||||
|
|
||||||
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
"""Check Cache and run the LLM on the given prompt and input."""
|
"""Check Cache and run the LLM on the given prompt and input."""
|
||||||
if langchain.llm_cache is None:
|
return self.generate([prompt], stop=stop).generations[0][0].text
|
||||||
return self._call(prompt, stop=stop)
|
|
||||||
params = self._llm_dict()
|
|
||||||
params["stop"] = stop
|
|
||||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
|
||||||
if langchain.cache is not None:
|
|
||||||
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
|
|
||||||
if cache_val is not None:
|
|
||||||
if isinstance(cache_val, str):
|
|
||||||
return cache_val
|
|
||||||
else:
|
|
||||||
return cache_val[0].text
|
|
||||||
return_val = self._call(prompt, stop=stop)
|
|
||||||
if langchain.cache is not None:
|
|
||||||
langchain.llm_cache.update(prompt, llm_string, return_val)
|
|
||||||
return return_val
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
@ -163,3 +139,26 @@ class LLM(BaseModel, ABC):
|
|||||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{save_path} must be json or yaml")
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
|
||||||
|
|
||||||
|
class LLM(BaseLLM):
|
||||||
|
"""LLM class that expect subclasses to implement a simpler call method.
|
||||||
|
|
||||||
|
The purpose of this class is to expose a simpler interface for working
|
||||||
|
with LLMs, rather than expect the user to implement the full _generate method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
|
"""Run the LLM on the given prompt and input."""
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||||
|
) -> LLMResult:
|
||||||
|
"""Run the LLM on the given prompt and input."""
|
||||||
|
# TODO: add caching here.
|
||||||
|
generations = []
|
||||||
|
for prompt in prompts:
|
||||||
|
text = self._call(prompt, stop=stop)
|
||||||
|
generations.append([Generation(text=text)])
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
@ -6,10 +6,10 @@ from typing import Union
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain.llms import type_to_cls_dict
|
from langchain.llms import type_to_cls_dict
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
def load_llm_from_config(config: dict) -> LLM:
|
def load_llm_from_config(config: dict) -> BaseLLM:
|
||||||
"""Load LLM from Config Dict."""
|
"""Load LLM from Config Dict."""
|
||||||
if "_type" not in config:
|
if "_type" not in config:
|
||||||
raise ValueError("Must specify an LLM Type in config")
|
raise ValueError("Must specify an LLM Type in config")
|
||||||
@ -22,7 +22,7 @@ def load_llm_from_config(config: dict) -> LLM:
|
|||||||
return llm_cls(**config)
|
return llm_cls(**config)
|
||||||
|
|
||||||
|
|
||||||
def load_llm(file: Union[str, Path]) -> LLM:
|
def load_llm(file: Union[str, Path]) -> BaseLLM:
|
||||||
"""Load LLM from file."""
|
"""Load LLM from file."""
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file, str):
|
if isinstance(file, str):
|
||||||
|
@ -4,12 +4,12 @@ from typing import Any, Dict, Generator, List, Mapping, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
|
|
||||||
from langchain.llms.base import LLM, LLMResult
|
from langchain.llms.base import BaseLLM, LLMResult
|
||||||
from langchain.schema import Generation
|
from langchain.schema import Generation
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(LLM, BaseModel):
|
class OpenAI(BaseLLM, BaseModel):
|
||||||
"""Wrapper around OpenAI large language models.
|
"""Wrapper around OpenAI large language models.
|
||||||
|
|
||||||
To use, you should have the ``openai`` python package installed, and the
|
To use, you should have the ``openai`` python package installed, and the
|
||||||
@ -197,23 +197,6 @@ class OpenAI(LLM, BaseModel):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "openai"
|
return "openai"
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
||||||
"""Call out to OpenAI's create endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: The prompt to pass into the model.
|
|
||||||
stop: Optional list of stop words to use when generating.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The string generated by the model.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
response = openai("Tell me a joke.")
|
|
||||||
"""
|
|
||||||
return self.generate([prompt], stop=stop).generations[0][0].text
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
"""Calculate num tokens with tiktoken package."""
|
"""Calculate num tokens with tiktoken package."""
|
||||||
# tiktoken NOT supported for Python 3.8 or below
|
# tiktoken NOT supported for Python 3.8 or below
|
||||||
|
@ -6,7 +6,7 @@ from typing import List, Optional, Sequence
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import get_color_mapping, print_text
|
from langchain.input import get_color_mapping, print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ class ModelLaboratory:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llms(
|
def from_llms(
|
||||||
cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None
|
cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
|
||||||
) -> ModelLaboratory:
|
) -> ModelLaboratory:
|
||||||
"""Initialize with LLMs to experiment with and optional prompt.
|
"""Initialize with LLMs to experiment with and optional prompt.
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
"""Utils for LLM Tests."""
|
"""Utils for LLM Tests."""
|
||||||
|
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
|
def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
|
||||||
"""Assert LLM Equality for tests."""
|
"""Assert LLM Equality for tests."""
|
||||||
# Check that they are the same type.
|
# Check that they are the same type.
|
||||||
assert type(llm) == type(loaded_llm)
|
assert type(llm) == type(loaded_llm)
|
||||||
|
Loading…
Reference in New Issue
Block a user