mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 08:56:27 +00:00
- **Description:** Add relevant type annotations for relevant session and query objects to resolve mypy errors when `# type: ignore` comments are removed. - **Issue:** #17048 - **Dependencies:** None, - **Twitter handle:** [clesiemo3](https://twitter.com/clesiemo3) I attempted to solve the `UpsertionRecord` ignore but it would require added a deprecated plugin or moving completely to sqlalchemy 2.0+ from my understanding. I'm assuming this is not something desired at this point in time.
This commit is contained in:
parent
3d5e988c55
commit
912210ac19
@ -39,7 +39,7 @@ from sqlalchemy.ext.asyncio import (
|
|||||||
create_async_engine,
|
create_async_engine,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Query, Session, sessionmaker
|
||||||
|
|
||||||
from langchain.indexes.base import RecordManager
|
from langchain.indexes.base import RecordManager
|
||||||
|
|
||||||
@ -284,31 +284,35 @@ class SQLRecordManager(RecordManager):
|
|||||||
|
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
if self.dialect == "sqlite":
|
if self.dialect == "sqlite":
|
||||||
|
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
|
||||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
|
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
|
||||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
UpsertionRecord
|
||||||
|
).values(records_to_upsert)
|
||||||
|
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
||||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||||
set_=dict(
|
set_=dict(
|
||||||
# attr-defined type ignore
|
updated_at=sqlite_insert_stmt.excluded.updated_at,
|
||||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
group_id=sqlite_insert_stmt.excluded.group_id,
|
||||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif self.dialect == "postgresql":
|
elif self.dialect == "postgresql":
|
||||||
|
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
# Note: uses postgresql insert to make on_conflict_do_update work.
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
|
||||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
records_to_upsert
|
||||||
|
)
|
||||||
|
stmt = pg_insert_stmt.on_conflict_do_update(
|
||||||
"uix_key_namespace", # Name of constraint
|
"uix_key_namespace", # Name of constraint
|
||||||
set_=dict(
|
set_=dict(
|
||||||
# attr-defined type ignore
|
updated_at=pg_insert_stmt.excluded.updated_at,
|
||||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
group_id=pg_insert_stmt.excluded.group_id,
|
||||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -359,31 +363,35 @@ class SQLRecordManager(RecordManager):
|
|||||||
|
|
||||||
async with self._amake_session() as session:
|
async with self._amake_session() as session:
|
||||||
if self.dialect == "sqlite":
|
if self.dialect == "sqlite":
|
||||||
|
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
|
||||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
|
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
|
||||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
UpsertionRecord
|
||||||
|
).values(records_to_upsert)
|
||||||
|
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
||||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||||
set_=dict(
|
set_=dict(
|
||||||
# attr-defined type ignore
|
updated_at=sqlite_insert_stmt.excluded.updated_at,
|
||||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
group_id=sqlite_insert_stmt.excluded.group_id,
|
||||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif self.dialect == "postgresql":
|
elif self.dialect == "postgresql":
|
||||||
|
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||||
# This code needs to be generalized a bit to work with more dialects.
|
# This code needs to be generalized a bit to work with more dialects.
|
||||||
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
|
||||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
records_to_upsert
|
||||||
|
)
|
||||||
|
stmt = pg_insert_stmt.on_conflict_do_update(
|
||||||
"uix_key_namespace", # Name of constraint
|
"uix_key_namespace", # Name of constraint
|
||||||
set_=dict(
|
set_=dict(
|
||||||
# attr-defined type ignore
|
updated_at=pg_insert_stmt.excluded.updated_at,
|
||||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
group_id=pg_insert_stmt.excluded.group_id,
|
||||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -394,18 +402,15 @@ class SQLRecordManager(RecordManager):
|
|||||||
|
|
||||||
def exists(self, keys: Sequence[str]) -> List[bool]:
|
def exists(self, keys: Sequence[str]) -> List[bool]:
|
||||||
"""Check if the given keys exist in the SQLite database."""
|
"""Check if the given keys exist in the SQLite database."""
|
||||||
|
session: Session
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
records = (
|
filtered_query: Query = session.query(UpsertionRecord.key).filter(
|
||||||
# mypy does not recognize .all()
|
and_(
|
||||||
session.query(UpsertionRecord.key) # type: ignore[attr-defined]
|
UpsertionRecord.key.in_(keys),
|
||||||
.filter(
|
UpsertionRecord.namespace == self.namespace,
|
||||||
and_(
|
|
||||||
UpsertionRecord.key.in_(keys),
|
|
||||||
UpsertionRecord.namespace == self.namespace,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
records = filtered_query.all()
|
||||||
found_keys = set(r.key for r in records)
|
found_keys = set(r.key for r in records)
|
||||||
return [k in found_keys for k in keys]
|
return [k in found_keys for k in keys]
|
||||||
|
|
||||||
@ -438,28 +443,22 @@ class SQLRecordManager(RecordManager):
|
|||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""List records in the SQLite database based on the provided date range."""
|
"""List records in the SQLite database based on the provided date range."""
|
||||||
|
session: Session
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
query = session.query(UpsertionRecord).filter(
|
query: Query = session.query(UpsertionRecord).filter(
|
||||||
UpsertionRecord.namespace == self.namespace
|
UpsertionRecord.namespace == self.namespace
|
||||||
)
|
)
|
||||||
|
|
||||||
# mypy does not recognize .all() or .filter()
|
|
||||||
if after:
|
if after:
|
||||||
query = query.filter( # type: ignore[attr-defined]
|
query = query.filter(UpsertionRecord.updated_at > after)
|
||||||
UpsertionRecord.updated_at > after
|
|
||||||
)
|
|
||||||
if before:
|
if before:
|
||||||
query = query.filter( # type: ignore[attr-defined]
|
query = query.filter(UpsertionRecord.updated_at < before)
|
||||||
UpsertionRecord.updated_at < before
|
|
||||||
)
|
|
||||||
if group_ids:
|
if group_ids:
|
||||||
query = query.filter( # type: ignore[attr-defined]
|
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
|
||||||
UpsertionRecord.group_id.in_(group_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit:
|
if limit:
|
||||||
query = query.limit(limit) # type: ignore[attr-defined]
|
query = query.limit(limit)
|
||||||
records = query.all() # type: ignore[attr-defined]
|
records = query.all()
|
||||||
return [r.key for r in records]
|
return [r.key for r in records]
|
||||||
|
|
||||||
async def alist_keys(
|
async def alist_keys(
|
||||||
@ -471,40 +470,37 @@ class SQLRecordManager(RecordManager):
|
|||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""List records in the SQLite database based on the provided date range."""
|
"""List records in the SQLite database based on the provided date range."""
|
||||||
|
session: AsyncSession
|
||||||
async with self._amake_session() as session:
|
async with self._amake_session() as session:
|
||||||
query = select(UpsertionRecord.key).filter(
|
query: Query = select(UpsertionRecord.key).filter(
|
||||||
UpsertionRecord.namespace == self.namespace
|
UpsertionRecord.namespace == self.namespace
|
||||||
)
|
)
|
||||||
|
|
||||||
# mypy does not recognize .all() or .filter()
|
# mypy does not recognize .all() or .filter()
|
||||||
if after:
|
if after:
|
||||||
query = query.filter( # type: ignore[attr-defined]
|
query = query.filter(UpsertionRecord.updated_at > after)
|
||||||
UpsertionRecord.updated_at > after
|
|
||||||
)
|
|
||||||
if before:
|
if before:
|
||||||
query = query.filter( # type: ignore[attr-defined]
|
query = query.filter(UpsertionRecord.updated_at < before)
|
||||||
UpsertionRecord.updated_at < before
|
|
||||||
)
|
|
||||||
if group_ids:
|
if group_ids:
|
||||||
query = query.filter( # type: ignore[attr-defined]
|
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
|
||||||
UpsertionRecord.group_id.in_(group_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit:
|
if limit:
|
||||||
query = query.limit(limit) # type: ignore[attr-defined]
|
query = query.limit(limit)
|
||||||
records = (await session.execute(query)).scalars().all()
|
records = (await session.execute(query)).scalars().all()
|
||||||
return list(records)
|
return list(records)
|
||||||
|
|
||||||
def delete_keys(self, keys: Sequence[str]) -> None:
|
def delete_keys(self, keys: Sequence[str]) -> None:
|
||||||
"""Delete records from the SQLite database."""
|
"""Delete records from the SQLite database."""
|
||||||
|
session: Session
|
||||||
with self._make_session() as session:
|
with self._make_session() as session:
|
||||||
# mypy does not recognize .delete()
|
filtered_query: Query = session.query(UpsertionRecord).filter(
|
||||||
session.query(UpsertionRecord).filter(
|
|
||||||
and_(
|
and_(
|
||||||
UpsertionRecord.key.in_(keys),
|
UpsertionRecord.key.in_(keys),
|
||||||
UpsertionRecord.namespace == self.namespace,
|
UpsertionRecord.namespace == self.namespace,
|
||||||
)
|
)
|
||||||
).delete() # type: ignore[attr-defined]
|
)
|
||||||
|
|
||||||
|
filtered_query.delete()
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user