This commit is contained in:
Eugene Yurtsev 2023-08-30 09:35:00 -04:00
parent 880bf06290
commit cafce9ed23

View File

@ -14,12 +14,12 @@ allow it to work with a variety of SQL as a backend.
* Keys can be deleted. * Keys can be deleted.
""" """
import contextlib import contextlib
import decimal
import uuid import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence, Union from typing import Any, Dict, Generator, List, Optional, Sequence, Union
import decimal
from sqlalchemy import URL
from sqlalchemy import ( from sqlalchemy import (
URL,
Column, Column,
Engine, Engine,
Float, Float,
@ -201,14 +201,10 @@ class SQLRecordManager(RecordManager):
for key, group_id in zip(keys, group_ids) for key, group_id in zip(keys, group_ids)
] ]
with self._make_session() as session:
if self.dialect == "sqlite": if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert from sqlalchemy.dialects.sqlite import insert
elif self.dialect == "postgresql":
from sqlalchemy.dialects.sqlite import insert
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
with self._make_session() as session:
# Note: uses SQLite insert to make on_conflict_do_update work. # Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects. # This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert) insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
@ -220,6 +216,23 @@ class SQLRecordManager(RecordManager):
group_id=insert_stmt.excluded.group_id, # type: ignore group_id=insert_stmt.excluded.group_id, # type: ignore
), ),
) )
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
"uix_key_namespace", # Name of constraint
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
else:
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
session.execute(stmt) session.execute(stmt)
session.commit() session.commit()