This commit is contained in:
Eugene Yurtsev 2023-08-30 09:35:00 -04:00
parent 880bf06290
commit cafce9ed23

View File

@ -14,12 +14,12 @@ allow it to work with a variety of SQL as a backend.
* Keys can be deleted.
"""
import contextlib
import decimal
import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
import decimal
from sqlalchemy import URL
from sqlalchemy import (
URL,
Column,
Engine,
Float,
@ -201,25 +201,38 @@ class SQLRecordManager(RecordManager):
for key, group_id in zip(keys, group_ids)
]
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert
elif self.dialect == "postgresql":
from sqlalchemy.dialects.sqlite import insert
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
with self._make_session() as session:
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
session.execute(stmt)
session.commit()