This commit is contained in:
Eugene Yurtsev 2023-08-30 09:36:27 -04:00
parent cafce9ed23
commit e8f29be350

View File

@ -14,12 +14,12 @@ allow it to work with a variety of SQL as a backend.
* Keys can be deleted. * Keys can be deleted.
""" """
import contextlib import contextlib
import decimal
import uuid import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence, Union from typing import Any, Dict, Generator, List, Optional, Sequence, Union
import decimal
from sqlalchemy import URL
from sqlalchemy import ( from sqlalchemy import (
URL,
Column, Column,
Engine, Engine,
Float, Float,
@ -203,11 +203,11 @@ class SQLRecordManager(RecordManager):
with self._make_session() as session: with self._make_session() as session:
if self.dialect == "sqlite": if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert) insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace], [UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict( set_=dict(
@ -217,11 +217,11 @@ class SQLRecordManager(RecordManager):
), ),
) )
elif self.dialect == "postgresql": elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert) insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
"uix_key_namespace", # Name of constraint "uix_key_namespace", # Name of constraint
set_=dict( set_=dict(