Add indexing support for postgresql (#9933)

Add support to postgresql for the SQL Manager Record

This code was tested locally. I'm looking at how to add testing with
postgres in a separate PR.
This commit is contained in:
Bagatur 2023-08-31 07:27:09 -07:00 committed by GitHub
commit 4b15328767
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 17 deletions

View File

@ -332,9 +332,9 @@ def index(
uids_to_delete = record_manager.list_keys(before=index_start_dt) uids_to_delete = record_manager.list_keys(before=index_start_dt)
if uids_to_delete: if uids_to_delete:
# Then delete from vector store.
vector_store.delete(uids_to_delete)
# First delete from record store. # First delete from record store.
vector_store.delete(uids_to_delete)
# Then delete from record manager.
record_manager.delete_keys(uids_to_delete) record_manager.delete_keys(uids_to_delete)
num_deleted = len(uids_to_delete) num_deleted = len(uids_to_delete)

View File

@ -14,10 +14,12 @@ allow it to work with a variety of SQL as a backend.
* Keys can be deleted. * Keys can be deleted.
""" """
import contextlib import contextlib
import decimal
import uuid import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from sqlalchemy import ( from sqlalchemy import (
URL,
Column, Column,
Engine, Engine,
Float, Float,
@ -28,7 +30,6 @@ from sqlalchemy import (
create_engine, create_engine,
text, text,
) )
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager):
namespace: str, namespace: str,
*, *,
engine: Optional[Engine] = None, engine: Optional[Engine] = None,
db_url: Optional[str] = None, db_url: Union[None, str, URL] = None,
engine_kwargs: Optional[Dict[str, Any]] = None, engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Initialize the SQLRecordManager. """Initialize the SQLRecordManager.
@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager):
raise AssertionError("Something went wrong with configuration of engine.") raise AssertionError("Something went wrong with configuration of engine.")
self.engine = _engine self.engine = _engine
self.dialect = _engine.dialect.name
self.session_factory = sessionmaker(bind=self.engine) self.session_factory = sessionmaker(bind=self.engine)
def create_schema(self) -> None: def create_schema(self) -> None:
@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager):
# 2440587.5 - constant represents the Julian day number for January 1, 1970 # 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds # 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds) # in a day (24 hours * 60 minutes * 60 seconds)
if self.dialect == "sqlite":
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
elif self.dialect == "postgresql":
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
else:
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
dt = session.execute(query).scalar() dt = session.execute(query).scalar()
if isinstance(dt, decimal.Decimal):
dt = float(dt)
if not isinstance(dt, float): if not isinstance(dt, float):
raise AssertionError(f"Unexpected type for datetime: {type(dt)}") raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
return dt return dt
@ -192,9 +202,12 @@ class SQLRecordManager(RecordManager):
] ]
with self._make_session() as session: with self._make_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert) insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace], [UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict( set_=dict(
@ -203,6 +216,23 @@ class SQLRecordManager(RecordManager):
group_id=insert_stmt.excluded.group_id, # type: ignore group_id=insert_stmt.excluded.group_id, # type: ignore
), ),
) )
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
session.execute(stmt) session.execute(stmt)
session.commit() session.commit()