diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index 14e2355af22..76dfd967234 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import ( create_async_engine, ) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Query, Session, sessionmaker from langchain.indexes.base import RecordManager @@ -284,31 +284,35 @@ class SQLRecordManager(RecordManager): with self._make_session() as session: if self.dialect == "sqlite": + from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType from sqlalchemy.dialects.sqlite import insert as sqlite_insert # Note: uses SQLite insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. - insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert) - stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + sqlite_insert_stmt: SqliteInsertType = sqlite_insert( + UpsertionRecord + ).values(records_to_upsert) + stmt = sqlite_insert_stmt.on_conflict_do_update( [UpsertionRecord.key, UpsertionRecord.namespace], set_=dict( - # attr-defined type ignore - updated_at=insert_stmt.excluded.updated_at, # type: ignore - group_id=insert_stmt.excluded.group_id, # type: ignore + updated_at=sqlite_insert_stmt.excluded.updated_at, + group_id=sqlite_insert_stmt.excluded.group_id, ), ) elif self.dialect == "postgresql": + from sqlalchemy.dialects.postgresql import Insert as PgInsertType from sqlalchemy.dialects.postgresql import insert as pg_insert - # Note: uses SQLite insert to make on_conflict_do_update work. + # Note: uses postgresql insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. 
- insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert) - stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values( + records_to_upsert + ) + stmt = pg_insert_stmt.on_conflict_do_update( "uix_key_namespace", # Name of constraint set_=dict( - # attr-defined type ignore - updated_at=insert_stmt.excluded.updated_at, # type: ignore - group_id=insert_stmt.excluded.group_id, # type: ignore + updated_at=pg_insert_stmt.excluded.updated_at, + group_id=pg_insert_stmt.excluded.group_id, ), ) else: @@ -359,31 +363,35 @@ class SQLRecordManager(RecordManager): async with self._amake_session() as session: if self.dialect == "sqlite": + from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType from sqlalchemy.dialects.sqlite import insert as sqlite_insert # Note: uses SQLite insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. - insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert) - stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + sqlite_insert_stmt: SqliteInsertType = sqlite_insert( + UpsertionRecord + ).values(records_to_upsert) + stmt = sqlite_insert_stmt.on_conflict_do_update( [UpsertionRecord.key, UpsertionRecord.namespace], set_=dict( - # attr-defined type ignore - updated_at=insert_stmt.excluded.updated_at, # type: ignore - group_id=insert_stmt.excluded.group_id, # type: ignore + updated_at=sqlite_insert_stmt.excluded.updated_at, + group_id=sqlite_insert_stmt.excluded.group_id, ), ) elif self.dialect == "postgresql": + from sqlalchemy.dialects.postgresql import Insert as PgInsertType from sqlalchemy.dialects.postgresql import insert as pg_insert - # Note: uses SQLite insert to make on_conflict_do_update work. + # Note: uses postgresql insert to make on_conflict_do_update work. # This code needs to be generalized a bit to work with more dialects. 
- insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert) - stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values( + records_to_upsert + ) + stmt = pg_insert_stmt.on_conflict_do_update( "uix_key_namespace", # Name of constraint set_=dict( - # attr-defined type ignore - updated_at=insert_stmt.excluded.updated_at, # type: ignore - group_id=insert_stmt.excluded.group_id, # type: ignore + updated_at=pg_insert_stmt.excluded.updated_at, + group_id=pg_insert_stmt.excluded.group_id, ), ) else: @@ -394,18 +402,15 @@ class SQLRecordManager(RecordManager): def exists(self, keys: Sequence[str]) -> List[bool]: """Check if the given keys exist in the SQLite database.""" + session: Session with self._make_session() as session: - records = ( - # mypy does not recognize .all() - session.query(UpsertionRecord.key) # type: ignore[attr-defined] - .filter( - and_( - UpsertionRecord.key.in_(keys), - UpsertionRecord.namespace == self.namespace, - ) + filtered_query: Query = session.query(UpsertionRecord.key).filter( + and_( + UpsertionRecord.key.in_(keys), + UpsertionRecord.namespace == self.namespace, ) - .all() ) + records = filtered_query.all() found_keys = set(r.key for r in records) return [k in found_keys for k in keys] @@ -438,28 +443,22 @@ class SQLRecordManager(RecordManager): limit: Optional[int] = None, ) -> List[str]: """List records in the SQLite database based on the provided date range.""" + session: Session with self._make_session() as session: - query = session.query(UpsertionRecord).filter( + query: Query = session.query(UpsertionRecord).filter( UpsertionRecord.namespace == self.namespace ) - # mypy does not recognize .all() or .filter() if after: - query = query.filter( # type: ignore[attr-defined] - UpsertionRecord.updated_at > after - ) + query = query.filter(UpsertionRecord.updated_at > after) if before: - query = query.filter( # type: ignore[attr-defined] - 
UpsertionRecord.updated_at < before - ) + query = query.filter(UpsertionRecord.updated_at < before) if group_ids: - query = query.filter( # type: ignore[attr-defined] - UpsertionRecord.group_id.in_(group_ids) - ) + query = query.filter(UpsertionRecord.group_id.in_(group_ids)) if limit: - query = query.limit(limit) # type: ignore[attr-defined] - records = query.all() # type: ignore[attr-defined] + query = query.limit(limit) + records = query.all() return [r.key for r in records] async def alist_keys( @@ -471,40 +470,36 @@ class SQLRecordManager(RecordManager): limit: Optional[int] = None, ) -> List[str]: """List records in the SQLite database based on the provided date range.""" + session: AsyncSession async with self._amake_session() as session: - query = select(UpsertionRecord.key).filter( + query: Query = select(UpsertionRecord.key).filter( UpsertionRecord.namespace == self.namespace ) - # mypy does not recognize .all() or .filter() if after: - query = query.filter( # type: ignore[attr-defined] - UpsertionRecord.updated_at > after - ) + query = query.filter(UpsertionRecord.updated_at > after) if before: - query = query.filter( # type: ignore[attr-defined] - UpsertionRecord.updated_at < before - ) + query = query.filter(UpsertionRecord.updated_at < before) if group_ids: - query = query.filter( # type: ignore[attr-defined] - UpsertionRecord.group_id.in_(group_ids) - ) + query = query.filter(UpsertionRecord.group_id.in_(group_ids)) if limit: - query = query.limit(limit) # type: ignore[attr-defined] + query = query.limit(limit) records = (await session.execute(query)).scalars().all() return list(records) def delete_keys(self, keys: Sequence[str]) -> None: """Delete records from the SQLite database.""" + session: Session with self._make_session() as session: - # mypy does not recognize .delete() - session.query(UpsertionRecord).filter( + filtered_query: Query = session.query(UpsertionRecord).filter( and_( UpsertionRecord.key.in_(keys), UpsertionRecord.namespace == 
self.namespace, ) - ).delete() # type: ignore[attr-defined] + ) + + filtered_query.delete() session.commit() async def adelete_keys(self, keys: Sequence[str]) -> None: