mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
- **Description:** Add relevant type annotations for relevant session and query objects to resolve mypy errors when `# type: ignore` comments are removed. - **Issue:** #17048 - **Dependencies:** None, - **Twitter handle:** [clesiemo3](https://twitter.com/clesiemo3) I attempted to solve the `UpsertionRecord` ignore but it would require added a deprecated plugin or moving completely to sqlalchemy 2.0+ from my understanding. I'm assuming this is not something desired at this point in time.
This commit is contained in:
parent
3d5e988c55
commit
912210ac19
@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Query, Session, sessionmaker
|
||||
|
||||
from langchain.indexes.base import RecordManager
|
||||
|
||||
@ -284,31 +284,35 @@ class SQLRecordManager(RecordManager):
|
||||
|
||||
with self._make_session() as session:
|
||||
if self.dialect == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
|
||||
UpsertionRecord
|
||||
).values(records_to_upsert)
|
||||
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=sqlite_insert_stmt.excluded.updated_at,
|
||||
group_id=sqlite_insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
elif self.dialect == "postgresql":
|
||||
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# Note: uses postgresql insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
|
||||
records_to_upsert
|
||||
)
|
||||
stmt = pg_insert_stmt.on_conflict_do_update(
|
||||
"uix_key_namespace", # Name of constraint
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=pg_insert_stmt.excluded.updated_at,
|
||||
group_id=pg_insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -359,31 +363,35 @@ class SQLRecordManager(RecordManager):
|
||||
|
||||
async with self._amake_session() as session:
|
||||
if self.dialect == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
|
||||
UpsertionRecord
|
||||
).values(records_to_upsert)
|
||||
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=sqlite_insert_stmt.excluded.updated_at,
|
||||
group_id=sqlite_insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
elif self.dialect == "postgresql":
|
||||
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
|
||||
records_to_upsert
|
||||
)
|
||||
stmt = pg_insert_stmt.on_conflict_do_update(
|
||||
"uix_key_namespace", # Name of constraint
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
updated_at=pg_insert_stmt.excluded.updated_at,
|
||||
group_id=pg_insert_stmt.excluded.group_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -394,18 +402,15 @@ class SQLRecordManager(RecordManager):
|
||||
|
||||
def exists(self, keys: Sequence[str]) -> List[bool]:
|
||||
"""Check if the given keys exist in the SQLite database."""
|
||||
session: Session
|
||||
with self._make_session() as session:
|
||||
records = (
|
||||
# mypy does not recognize .all()
|
||||
session.query(UpsertionRecord.key) # type: ignore[attr-defined]
|
||||
.filter(
|
||||
and_(
|
||||
UpsertionRecord.key.in_(keys),
|
||||
UpsertionRecord.namespace == self.namespace,
|
||||
)
|
||||
filtered_query: Query = session.query(UpsertionRecord.key).filter(
|
||||
and_(
|
||||
UpsertionRecord.key.in_(keys),
|
||||
UpsertionRecord.namespace == self.namespace,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
records = filtered_query.all()
|
||||
found_keys = set(r.key for r in records)
|
||||
return [k in found_keys for k in keys]
|
||||
|
||||
@ -438,28 +443,22 @@ class SQLRecordManager(RecordManager):
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""List records in the SQLite database based on the provided date range."""
|
||||
session: Session
|
||||
with self._make_session() as session:
|
||||
query = session.query(UpsertionRecord).filter(
|
||||
query: Query = session.query(UpsertionRecord).filter(
|
||||
UpsertionRecord.namespace == self.namespace
|
||||
)
|
||||
|
||||
# mypy does not recognize .all() or .filter()
|
||||
if after:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.updated_at > after
|
||||
)
|
||||
query = query.filter(UpsertionRecord.updated_at > after)
|
||||
if before:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.updated_at < before
|
||||
)
|
||||
query = query.filter(UpsertionRecord.updated_at < before)
|
||||
if group_ids:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.group_id.in_(group_ids)
|
||||
)
|
||||
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit) # type: ignore[attr-defined]
|
||||
records = query.all() # type: ignore[attr-defined]
|
||||
query = query.limit(limit)
|
||||
records = query.all()
|
||||
return [r.key for r in records]
|
||||
|
||||
async def alist_keys(
|
||||
@ -471,40 +470,37 @@ class SQLRecordManager(RecordManager):
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""List records in the SQLite database based on the provided date range."""
|
||||
session: AsyncSession
|
||||
async with self._amake_session() as session:
|
||||
query = select(UpsertionRecord.key).filter(
|
||||
query: Query = select(UpsertionRecord.key).filter(
|
||||
UpsertionRecord.namespace == self.namespace
|
||||
)
|
||||
|
||||
# mypy does not recognize .all() or .filter()
|
||||
if after:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.updated_at > after
|
||||
)
|
||||
query = query.filter(UpsertionRecord.updated_at > after)
|
||||
if before:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.updated_at < before
|
||||
)
|
||||
query = query.filter(UpsertionRecord.updated_at < before)
|
||||
if group_ids:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.group_id.in_(group_ids)
|
||||
)
|
||||
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit) # type: ignore[attr-defined]
|
||||
query = query.limit(limit)
|
||||
records = (await session.execute(query)).scalars().all()
|
||||
return list(records)
|
||||
|
||||
def delete_keys(self, keys: Sequence[str]) -> None:
|
||||
"""Delete records from the SQLite database."""
|
||||
session: Session
|
||||
with self._make_session() as session:
|
||||
# mypy does not recognize .delete()
|
||||
session.query(UpsertionRecord).filter(
|
||||
filtered_query: Query = session.query(UpsertionRecord).filter(
|
||||
and_(
|
||||
UpsertionRecord.key.in_(keys),
|
||||
UpsertionRecord.namespace == self.namespace,
|
||||
)
|
||||
).delete() # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
filtered_query.delete()
|
||||
session.commit()
|
||||
|
||||
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user