mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
x
This commit is contained in:
parent
880bf06290
commit
cafce9ed23
@ -14,12 +14,12 @@ allow it to work with a variety of SQL as a backend.
|
||||
* Keys can be deleted.
|
||||
"""
|
||||
import contextlib
|
||||
import decimal
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||
import decimal
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy import (
|
||||
URL,
|
||||
Column,
|
||||
Engine,
|
||||
Float,
|
||||
@ -201,25 +201,38 @@ class SQLRecordManager(RecordManager):
|
||||
for key, group_id in zip(keys, group_ids)
|
||||
]
|
||||
|
||||
if self.dialect == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
elif self.dialect == "postgresql":
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
||||
|
||||
with self._make_session() as session:
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
),
|
||||
)
|
||||
if self.dialect == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
),
|
||||
)
|
||||
elif self.dialect == "postgresql":
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
"uix_key_namespace", # Name of constraint
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
||||
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user