diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index 47b9d33ea8f..130a5c685db 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -332,9 +332,9 @@ def index( uids_to_delete = record_manager.list_keys(before=index_start_dt) if uids_to_delete: - # Then delete from vector store. - vector_store.delete(uids_to_delete) # First delete from record store. + vector_store.delete(uids_to_delete) + # Then delete from record manager. record_manager.delete_keys(uids_to_delete) num_deleted = len(uids_to_delete) diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index 9cad02ef93b..f47f4e92397 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -14,10 +14,12 @@ allow it to work with a variety of SQL as a backend. * Keys can be deleted. """ import contextlib +import decimal import uuid -from typing import Any, Dict, Generator, List, Optional, Sequence +from typing import Any, Dict, Generator, List, Optional, Sequence, Union from sqlalchemy import ( + URL, Column, Engine, Float, @@ -28,7 +30,6 @@ from sqlalchemy import ( create_engine, text, ) -from sqlalchemy.dialects.sqlite import insert from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker @@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager): namespace: str, *, engine: Optional[Engine] = None, - db_url: Optional[str] = None, + db_url: Union[None, str, URL] = None, engine_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the SQLRecordManager. @@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager): raise AssertionError("Something went wrong with configuration of engine.") self.engine = _engine + self.dialect = _engine.dialect.name self.session_factory = sessionmaker(bind=self.engine) def create_schema(self) -> None: @@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager): # 2440587.5 - constant represents the Julian day number for January 1, 1970 # 86400.0 - constant represents the number of seconds # in a day (24 hours * 60 minutes * 60 seconds) - query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") + if self.dialect == "sqlite": + query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") + elif self.dialect == "postgresql": + query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);") + else: + raise NotImplementedError(f"Not implemented for dialect {self.dialect}") + dt = session.execute(query).scalar() + if isinstance(dt, decimal.Decimal): + dt = float(dt) if not isinstance(dt, float): raise AssertionError(f"Unexpected type for datetime: {type(dt)}") return dt @@ -192,17 +202,37 @@ class SQLRecordManager(RecordManager): ] with self._make_session() as session: - # Note: uses SQLite insert to make on_conflict_do_update work. - # This code needs to be generalized a bit to work with more dialects. - insert_stmt = insert(UpsertionRecord).values(records_to_upsert) - stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] - [UpsertionRecord.key, UpsertionRecord.namespace], - set_=dict( - # attr-defined type ignore - updated_at=insert_stmt.excluded.updated_at, # type: ignore - group_id=insert_stmt.excluded.group_id, # type: ignore - ), - ) + if self.dialect == "sqlite": + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + # Note: uses SQLite insert to make on_conflict_do_update work. + # This code needs to be generalized a bit to work with more dialects. + insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert) + stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + [UpsertionRecord.key, UpsertionRecord.namespace], + set_=dict( + # attr-defined type ignore + updated_at=insert_stmt.excluded.updated_at, # type: ignore + group_id=insert_stmt.excluded.group_id, # type: ignore + ), + ) + elif self.dialect == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + + # Note: uses SQLite insert to make on_conflict_do_update work. + # This code needs to be generalized a bit to work with more dialects. + insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert) + stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + "uix_key_namespace", # Name of constraint + set_=dict( + # attr-defined type ignore + updated_at=insert_stmt.excluded.updated_at, # type: ignore + group_id=insert_stmt.excluded.group_id, # type: ignore + ), + ) + else: + raise NotImplementedError(f"Unsupported dialect {self.dialect}") + session.execute(stmt) session.commit()