mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Add indexing support for postgresql (#9933)
Add support to postgresql for the SQL Manager Record This code was tested locally. I'm looking at how to add testing with postgres in a separate PR.
This commit is contained in:
commit
4b15328767
@ -332,9 +332,9 @@ def index(
|
|||||||
uids_to_delete = record_manager.list_keys(before=index_start_dt)
|
uids_to_delete = record_manager.list_keys(before=index_start_dt)
|
||||||
|
|
||||||
if uids_to_delete:
|
if uids_to_delete:
|
||||||
# Then delete from vector store.
|
|
||||||
vector_store.delete(uids_to_delete)
|
|
||||||
# First delete from record store.
|
# First delete from record store.
|
||||||
|
vector_store.delete(uids_to_delete)
|
||||||
|
# Then delete from record manager.
|
||||||
record_manager.delete_keys(uids_to_delete)
|
record_manager.delete_keys(uids_to_delete)
|
||||||
num_deleted = len(uids_to_delete)
|
num_deleted = len(uids_to_delete)
|
||||||
|
|
||||||
|
@ -14,10 +14,12 @@ allow it to work with a variety of SQL as a backend.
|
|||||||
* Keys can be deleted.
|
* Keys can be deleted.
|
||||||
"""
|
"""
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import decimal
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Generator, List, Optional, Sequence
|
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
URL,
|
||||||
Column,
|
Column,
|
||||||
Engine,
|
Engine,
|
||||||
Float,
|
Float,
|
||||||
@ -28,7 +30,6 @@ from sqlalchemy import (
|
|||||||
create_engine,
|
create_engine,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.sqlite import insert
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
namespace: str,
|
namespace: str,
|
||||||
*,
|
*,
|
||||||
engine: Optional[Engine] = None,
|
engine: Optional[Engine] = None,
|
||||||
db_url: Optional[str] = None,
|
db_url: Union[None, str, URL] = None,
|
||||||
engine_kwargs: Optional[Dict[str, Any]] = None,
|
engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the SQLRecordManager.
|
"""Initialize the SQLRecordManager.
|
||||||
@ -114,6 +115,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
raise AssertionError("Something went wrong with configuration of engine.")
|
raise AssertionError("Something went wrong with configuration of engine.")
|
||||||
|
|
||||||
self.engine = _engine
|
self.engine = _engine
|
||||||
|
self.dialect = _engine.dialect.name
|
||||||
self.session_factory = sessionmaker(bind=self.engine)
|
self.session_factory = sessionmaker(bind=self.engine)
|
||||||
|
|
||||||
def create_schema(self) -> None:
|
def create_schema(self) -> None:
|
||||||
@ -145,8 +147,16 @@ class SQLRecordManager(RecordManager):
|
|||||||
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
||||||
# 86400.0 - constant represents the number of seconds
|
# 86400.0 - constant represents the number of seconds
|
||||||
# in a day (24 hours * 60 minutes * 60 seconds)
|
# in a day (24 hours * 60 minutes * 60 seconds)
|
||||||
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
if self.dialect == "sqlite":
|
||||||
|
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
||||||
|
elif self.dialect == "postgresql":
|
||||||
|
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
|
||||||
|
|
||||||
dt = session.execute(query).scalar()
|
dt = session.execute(query).scalar()
|
||||||
|
if isinstance(dt, decimal.Decimal):
|
||||||
|
dt = float(dt)
|
||||||
if not isinstance(dt, float):
|
if not isinstance(dt, float):
|
||||||
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
||||||
return dt
|
return dt
|
||||||
@ -192,17 +202,37 @@ class SQLRecordManager(RecordManager):
|
|||||||
]
|
]
|
||||||
|
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
if self.dialect == "sqlite":
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
|
|
||||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
set_=dict(
|
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
||||||
# attr-defined type ignore
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
set_=dict(
|
||||||
),
|
# attr-defined type ignore
|
||||||
)
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||||
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif self.dialect == "postgresql":
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
|
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
||||||
|
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||||
|
"uix_key_namespace", # Name of constraint
|
||||||
|
set_=dict(
|
||||||
|
# attr-defined type ignore
|
||||||
|
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||||
|
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
||||||
|
|
||||||
session.execute(stmt)
|
session.execute(stmt)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user