From cafce9ed23fa1b7d2efbf14331e8722c277c6973 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 30 Aug 2023 09:35:00 -0400 Subject: [PATCH] x --- .../langchain/indexes/_sql_record_manager.py | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index ac7cc6a7645..ab0b487eef8 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -14,12 +14,12 @@ allow it to work with a variety of SQL as a backend. * Keys can be deleted. """ import contextlib -import decimal import uuid from typing import Any, Dict, Generator, List, Optional, Sequence, Union +import decimal +from sqlalchemy import URL from sqlalchemy import ( - URL, Column, Engine, Float, @@ -201,25 +201,38 @@ class SQLRecordManager(RecordManager): for key, group_id in zip(keys, group_ids) ] - if self.dialect == "sqlite": - from sqlalchemy.dialects.sqlite import insert - elif self.dialect == "postgresql": - from sqlalchemy.dialects.sqlite import insert - else: - raise NotImplementedError(f"Unsupported dialect {self.dialect}") - with self._make_session() as session: - # Note: uses SQLite insert to make on_conflict_do_update work. - # This code needs to be generalized a bit to work with more dialects. - insert_stmt = insert(UpsertionRecord).values(records_to_upsert) - stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] - [UpsertionRecord.key, UpsertionRecord.namespace], - set_=dict( - # attr-defined type ignore - updated_at=insert_stmt.excluded.updated_at, # type: ignore - group_id=insert_stmt.excluded.group_id, # type: ignore - ), - ) + if self.dialect == "sqlite": + from sqlalchemy.dialects.sqlite import insert + + # Note: uses SQLite insert to make on_conflict_do_update work. + # This code needs to be generalized a bit to work with more dialects. + insert_stmt = insert(UpsertionRecord).values(records_to_upsert) + stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + [UpsertionRecord.key, UpsertionRecord.namespace], + set_=dict( + # attr-defined type ignore + updated_at=insert_stmt.excluded.updated_at, # type: ignore + group_id=insert_stmt.excluded.group_id, # type: ignore + ), + ) + elif self.dialect == "postgresql": + from sqlalchemy.dialects.postgresql import insert + + # Note: uses SQLite insert to make on_conflict_do_update work. + # This code needs to be generalized a bit to work with more dialects. + insert_stmt = insert(UpsertionRecord).values(records_to_upsert) + stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + "uix_key_namespace", # Name of constraint + set_=dict( + # attr-defined type ignore + updated_at=insert_stmt.excluded.updated_at, # type: ignore + group_id=insert_stmt.excluded.group_id, # type: ignore + ), + ) + else: + raise NotImplementedError(f"Unsupported dialect {self.dialect}") + session.execute(stmt) session.commit()