mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 13:18:12 +00:00
x
This commit is contained in:
parent
e80834d783
commit
9efc29e3d1
@ -332,9 +332,9 @@ def index(
|
|||||||
uids_to_delete = record_manager.list_keys(before=index_start_dt)
|
uids_to_delete = record_manager.list_keys(before=index_start_dt)
|
||||||
|
|
||||||
if uids_to_delete:
|
if uids_to_delete:
|
||||||
# Then delete from vector store.
|
|
||||||
vector_store.delete(uids_to_delete)
|
|
||||||
# First delete from record store.
|
# First delete from record store.
|
||||||
|
vector_store.delete(uids_to_delete)
|
||||||
|
# Then delete from record manager.
|
||||||
record_manager.delete_keys(uids_to_delete)
|
record_manager.delete_keys(uids_to_delete)
|
||||||
num_deleted = len(uids_to_delete)
|
num_deleted = len(uids_to_delete)
|
||||||
|
|
||||||
|
@ -15,8 +15,10 @@ allow it to work with a variety of SQL as a backend.
|
|||||||
"""
|
"""
|
||||||
import contextlib
|
import contextlib
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Generator, List, Optional, Sequence
|
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from sqlalchemy import URL
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column,
|
Column,
|
||||||
Engine,
|
Engine,
|
||||||
@ -28,7 +30,6 @@ from sqlalchemy import (
|
|||||||
create_engine,
|
create_engine,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
namespace: str,
|
namespace: str,
|
||||||
*,
|
*,
|
||||||
engine: Optional[Engine] = None,
|
engine: Optional[Engine] = None,
|
||||||
db_url: Optional[str] = None,
|
db_url: Union[None, str, URL] = None,
|
||||||
engine_kwargs: Optional[Dict[str, Any]] = None,
|
engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the SQLRecordManager.
|
"""Initialize the SQLRecordManager.
|
||||||
@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
raise AssertionError("Something went wrong with configuration of engine.")
|
raise AssertionError("Something went wrong with configuration of engine.")
|
||||||
|
|
||||||
self.engine = _engine
|
self.engine = _engine
|
||||||
|
self.dialect = _engine.dialect.name
|
||||||
self.session_factory = sessionmaker(bind=self.engine)
|
self.session_factory = sessionmaker(bind=self.engine)
|
||||||
|
|
||||||
def create_schema(self) -> None:
|
def create_schema(self) -> None:
|
||||||
@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager):
|
|||||||
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
||||||
# 86400.0 - constant represents the number of seconds
|
# 86400.0 - constant represents the number of seconds
|
||||||
# in a day (24 hours * 60 minutes * 60 seconds)
|
# in a day (24 hours * 60 minutes * 60 seconds)
|
||||||
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
if self.dialect == "sqlite":
|
||||||
|
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
||||||
|
elif self.dialect == "postgresql":
|
||||||
|
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
|
||||||
|
|
||||||
dt = session.execute(query).scalar()
|
dt = session.execute(query).scalar()
|
||||||
|
if isinstance(dt, decimal.Decimal):
|
||||||
|
dt = float(dt)
|
||||||
if not isinstance(dt, float):
|
if not isinstance(dt, float):
|
||||||
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
||||||
return dt
|
return dt
|
||||||
@ -191,6 +201,13 @@ class SQLRecordManager(RecordManager):
|
|||||||
for key, group_id in zip(keys, group_ids)
|
for key, group_id in zip(keys, group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if self.dialect == "sqlite":
|
||||||
|
from sqlalchemy.dialects.sqlite import insert
|
||||||
|
elif self.dialect == "postgresql":
|
||||||
|
from sqlalchemy.dialects.sqlite import insert
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
||||||
|
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
|
Loading…
Reference in New Issue
Block a user