mv base cache to schema (#9953)

if you remove all other imports from langchain.init it exposes a
circular dep
This commit is contained in:
Bagatur 2023-08-30 08:10:51 -07:00 committed by GitHub
parent 9870bfb9cd
commit 9828701de1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 23 deletions

View File

@ -4,7 +4,6 @@ from importlib import metadata
from typing import Optional from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.chains import ( from langchain.chains import (
ConversationChain, ConversationChain,
LLMBashChain, LLMBashChain,
@ -40,6 +39,7 @@ from langchain.prompts import (
Prompt, Prompt,
PromptTemplate, PromptTemplate,
) )
from langchain.schema.cache import BaseCache
from langchain.schema.prompt_template import BasePromptTemplate from langchain.schema.prompt_template import BasePromptTemplate
from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.golden_query import GoldenQueryAPIWrapper from langchain.utilities.golden_query import GoldenQueryAPIWrapper

View File

@ -26,7 +26,6 @@ import inspect
import json import json
import logging import logging
import warnings import warnings
from abc import ABC, abstractmethod
from datetime import timedelta from datetime import timedelta
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -35,7 +34,6 @@ from typing import (
Dict, Dict,
List, List,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
Union, Union,
@ -46,17 +44,18 @@ from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from langchain.utils import get_from_env
try: try:
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
except ImportError: except ImportError:
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.load.dump import dumps from langchain.load.dump import dumps
from langchain.load.load import loads from langchain.load.load import loads
from langchain.schema import ChatGeneration, Generation from langchain.schema import ChatGeneration, Generation
from langchain.schema.cache import RETURN_VAL_TYPE, BaseCache
from langchain.utils import get_from_env
from langchain.vectorstores.redis import Redis as RedisVectorstore from langchain.vectorstores.redis import Redis as RedisVectorstore
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
@ -64,8 +63,6 @@ logger = logging.getLogger(__file__)
if TYPE_CHECKING: if TYPE_CHECKING:
import momento import momento
RETURN_VAL_TYPE = Sequence[Generation]
def _hash(_input: str) -> str: def _hash(_input: str) -> str:
"""Use a deterministic hashing approach.""" """Use a deterministic hashing approach."""
@ -105,22 +102,6 @@ def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE:
) )
class BaseCache(ABC):
"""Base interface for cache."""
@abstractmethod
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
@abstractmethod
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""
class InMemoryCache(BaseCache): class InMemoryCache(BaseCache):
"""Cache that stores things in memory.""" """Cache that stores things in memory."""

View File

@ -1,5 +1,6 @@
"""**Schemas** are the LangChain Base Classes and Interfaces.""" """**Schemas** are the LangChain Base Classes and Interfaces."""
from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.cache import BaseCache
from langchain.schema.chat_history import BaseChatMessageHistory from langchain.schema.chat_history import BaseChatMessageHistory
from langchain.schema.document import BaseDocumentTransformer, Document from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.exceptions import LangChainException from langchain.schema.exceptions import LangChainException
@ -39,6 +40,7 @@ RUN_KEY = "__run"
Memory = BaseMemory Memory = BaseMemory
__all__ = [ __all__ = [
"BaseCache",
"BaseMemory", "BaseMemory",
"BaseStore", "BaseStore",
"AgentFinish", "AgentFinish",

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence
from langchain.schema.output import Generation
RETURN_VAL_TYPE = Sequence[Generation]
class BaseCache(ABC):
"""Base interface for cache."""
@abstractmethod
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
@abstractmethod
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""