mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 06:53:59 +00:00
community: Factorize AstraDB components constructors (#16779)
* Adds `AstraDBEnvironment` class and use it in `AstraDBLoader`, `AstraDBCache`, `AstraDBSemanticCache`, `AstraDBBaseStore` and `AstraDBChatMessageHistory` * Create an `AsyncAstraDB` if we only have an `AstraDB` and vice-versa so: * we always have an instance of `AstraDB` * we always have an instance of `AsyncAstraDB` for recent versions of astrapy * Create collection if not exists in `AstraDBBaseStore` * Some typing improvements Note: `AstraDB` `VectorStore` not using `AstraDBEnvironment` at the moment. This will be done after the `langchain-astradb` package is out.
This commit is contained in:
committed by
GitHub
parent
93366861c7
commit
9d458d089a
@@ -60,12 +60,14 @@ from langchain_core.load.load import loads
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
from langchain_community.utilities.astradb import AstraDBEnvironment
|
||||
from langchain_community.vectorstores.redis import Redis as RedisVectorstore
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import momento
|
||||
from astrapy.db import AstraDB
|
||||
from cassandra.cluster import Session as CassandraSession
|
||||
|
||||
|
||||
@@ -1262,7 +1264,7 @@ class AstraDBCache(BaseCache):
|
||||
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@@ -1278,39 +1280,17 @@ class AstraDBCache(BaseCache):
|
||||
namespace (Optional[str]): namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
"""
|
||||
try:
|
||||
from astrapy.db import (
|
||||
AstraDB as LibAstraDB,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
self.namespace = namespace
|
||||
|
||||
if astra_db_client is not None:
|
||||
self.astra_db = astra_db_client
|
||||
else:
|
||||
self.astra_db = LibAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=collection_name,
|
||||
)
|
||||
self.collection_name = collection_name
|
||||
|
||||
@staticmethod
|
||||
def _make_id(prompt: str, llm_string: str) -> str:
|
||||
@@ -1364,7 +1344,7 @@ class AstraDBCache(BaseCache):
|
||||
def delete(self, prompt: str, llm_string: str) -> None:
|
||||
"""Evict from cache if there's an entry."""
|
||||
doc_id = self._make_id(prompt, llm_string)
|
||||
return self.collection.delete_one(doc_id)
|
||||
self.collection.delete_one(doc_id)
|
||||
|
||||
def clear(self, **kwargs: Any) -> None:
|
||||
"""Clear cache. This is for all LLMs at once."""
|
||||
@@ -1395,7 +1375,7 @@ class AstraDBSemanticCache(BaseCache):
|
||||
collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[Any] = None, # 'astrapy.db.AstraDB' if passed
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
embedding: Embeddings,
|
||||
metric: Optional[str] = None,
|
||||
@@ -1423,22 +1403,13 @@ class AstraDBSemanticCache(BaseCache):
|
||||
The default score threshold is tuned to the default metric.
|
||||
Tune it carefully yourself if switching to another distance metric.
|
||||
"""
|
||||
try:
|
||||
from astrapy.db import (
|
||||
AstraDB as LibAstraDB,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent astrapy python package. "
|
||||
"Please install it with `pip install --upgrade astrapy`."
|
||||
)
|
||||
# Conflicting-arg checks:
|
||||
if astra_db_client is not None:
|
||||
if token is not None or api_endpoint is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass 'astra_db_client' to AstraDB if passing "
|
||||
"'token' and 'api_endpoint'."
|
||||
)
|
||||
astra_env = AstraDBEnvironment(
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
namespace=namespace,
|
||||
)
|
||||
self.astra_db = astra_env.astra_db
|
||||
|
||||
self.embedding = embedding
|
||||
self.metric = metric
|
||||
@@ -1457,18 +1428,7 @@ class AstraDBSemanticCache(BaseCache):
|
||||
self.embedding_dimension = self._get_embedding_dimension()
|
||||
|
||||
self.collection_name = collection_name
|
||||
self.token = token
|
||||
self.api_endpoint = api_endpoint
|
||||
self.namespace = namespace
|
||||
|
||||
if astra_db_client is not None:
|
||||
self.astra_db = astra_db_client
|
||||
else:
|
||||
self.astra_db = LibAstraDB(
|
||||
token=self.token,
|
||||
api_endpoint=self.api_endpoint,
|
||||
namespace=self.namespace,
|
||||
)
|
||||
self.collection = self.astra_db.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
dimension=self.embedding_dimension,
|
||||
|
Reference in New Issue
Block a user