Add indexing support for postgresql (#9933)

Add PostgreSQL support to the SQL record manager (SQLRecordManager).

This code was tested locally. I'm looking at how to add testing with
postgres in a separate PR.
This commit is contained in:
Bagatur 2023-08-31 07:27:09 -07:00 committed by GitHub
commit 4b15328767
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 17 deletions

View File

@ -332,9 +332,9 @@ def index(
uids_to_delete = record_manager.list_keys(before=index_start_dt)
if uids_to_delete:
# Then delete from vector store.
vector_store.delete(uids_to_delete)
# First delete from vector store.
vector_store.delete(uids_to_delete)
# Then delete from record manager.
record_manager.delete_keys(uids_to_delete)
num_deleted = len(uids_to_delete)

View File

@ -14,10 +14,12 @@ allow it to work with a variety of SQL as a backend.
* Keys can be deleted.
"""
import contextlib
import decimal
import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from sqlalchemy import (
URL,
Column,
Engine,
Float,
@ -28,7 +30,6 @@ from sqlalchemy import (
create_engine,
text,
)
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager):
namespace: str,
*,
engine: Optional[Engine] = None,
db_url: Optional[str] = None,
db_url: Union[None, str, URL] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the SQLRecordManager.
@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager):
raise AssertionError("Something went wrong with configuration of engine.")
self.engine = _engine
self.dialect = _engine.dialect.name
self.session_factory = sessionmaker(bind=self.engine)
def create_schema(self) -> None:
@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager):
# 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds)
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
if self.dialect == "sqlite":
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
elif self.dialect == "postgresql":
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
else:
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
dt = session.execute(query).scalar()
if isinstance(dt, decimal.Decimal):
dt = float(dt)
if not isinstance(dt, float):
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
return dt
@ -192,17 +202,37 @@ class SQLRecordManager(RecordManager):
]
with self._make_session() as session:
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses PostgreSQL insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
session.execute(stmt)
session.commit()