mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
x
This commit is contained in:
parent
880bf06290
commit
cafce9ed23
@ -14,12 +14,12 @@ allow it to work with a variety of SQL as a backend.
|
|||||||
* Keys can be deleted.
|
* Keys can be deleted.
|
||||||
"""
|
"""
|
||||||
import contextlib
|
import contextlib
|
||||||
import decimal
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from sqlalchemy import URL
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
URL,
|
|
||||||
Column,
|
Column,
|
||||||
Engine,
|
Engine,
|
||||||
Float,
|
Float,
|
||||||
@ -201,25 +201,38 @@ class SQLRecordManager(RecordManager):
|
|||||||
for key, group_id in zip(keys, group_ids)
|
for key, group_id in zip(keys, group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.dialect == "sqlite":
|
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
|
||||||
elif self.dialect == "postgresql":
|
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
|
||||||
|
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
if self.dialect == "sqlite":
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
from sqlalchemy.dialects.sqlite import insert
|
||||||
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
|
||||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
set_=dict(
|
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
||||||
# attr-defined type ignore
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
set_=dict(
|
||||||
),
|
# attr-defined type ignore
|
||||||
)
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||||
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif self.dialect == "postgresql":
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
|
||||||
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
|
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
||||||
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||||
|
"uix_key_namespace", # Name of constraint
|
||||||
|
set_=dict(
|
||||||
|
# attr-defined type ignore
|
||||||
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||||
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
||||||
|
|
||||||
session.execute(stmt)
|
session.execute(stmt)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user