community[patch]: restore compatibility with SQLAlchemy 1.x (#22546)

- **Description:** Restores compatibility with SQLAlchemy 1.4.x that was
broken since #18992 and adds a test run for this version on CI (only for
Python 3.11)
- **Issue:** fixes #19681
- **Dependencies:** None
- **Twitter handle:** `@krassowski_m`

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Michał Krassowski 2024-06-19 18:58:57 +01:00 committed by GitHub
parent 48d6ea427f
commit 710197e18c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 48 additions and 19 deletions

View File

@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy.engine import RowMapping
from sqlalchemy.sql.expression import Select
from langchain_community.docstore.document import Document
from langchain_community.document_loaders.base import BaseLoader
@ -19,7 +20,7 @@ class SQLDatabaseLoader(BaseLoader):
def __init__(
self,
query: Union[str, sa.Select],
query: Union[str, Select],
db: SQLDatabase,
*,
parameters: Optional[Dict[str, Any]] = None,
@ -106,7 +107,7 @@ class SQLDatabaseLoader(BaseLoader):
@staticmethod
def page_content_default_mapper(
row: sa.RowMapping, column_names: Optional[List[str]] = None
row: RowMapping, column_names: Optional[List[str]] = None
) -> str:
"""
A reasonable default function to convert a record into a "page content" string.
@ -121,7 +122,7 @@ class SQLDatabaseLoader(BaseLoader):
@staticmethod
def metadata_default_mapper(
row: sa.RowMapping, column_names: Optional[List[str]] = None
row: RowMapping, column_names: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
A reasonable default function to convert a record into a "metadata" dictionary.

View File

@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
* Keys can be listed based on the updated at field.
* Keys can be deleted.
"""
import contextlib
import decimal
import uuid
@ -29,9 +30,7 @@ from typing import (
)
from sqlalchemy import (
URL,
Column,
Engine,
Float,
Index,
String,
@ -42,15 +41,21 @@ from sqlalchemy import (
select,
text,
)
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
from langchain_community.indexes.base import RecordManager
Base = declarative_base()

View File

@ -20,7 +20,7 @@ from typing import (
import numpy as np
import sqlalchemy
from langchain_core._api import deprecated, warn_deprecated
from sqlalchemy import SQLColumnExpression, delete, func
from sqlalchemy import delete, func
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
from sqlalchemy.orm import Session, relationship
@ -29,6 +29,12 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
try:
from sqlalchemy import SQLColumnExpression
except ImportError:
# for sqlalchemy < 2
SQLColumnExpression = Any # type: ignore
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor

View File

@ -4,7 +4,17 @@ from typing import Any, AsyncGenerator, Generator, List, Tuple
import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from sqlalchemy import Column, Integer, Text
from sqlalchemy.orm import DeclarativeBase
try:
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass
except ImportError:
# for sqlalchemy < 2
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() # type:ignore
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_community.chat_message_histories.sql import DefaultMessageConverter
@ -198,9 +208,6 @@ async def test_async_clear_messages(
def test_model_no_session_id_field_error(con_str: str) -> None:
class Base(DeclarativeBase):
pass
class Model(Base):
__tablename__ = "test_table"
id = Column(Integer, primary_key=True)

View File

@ -1,21 +1,25 @@
# flake8: noqa: E501
"""Test SQL database wrapper."""
import pytest
import sqlalchemy as sa
from packaging import version
from sqlalchemy import (
Column,
Integer,
MetaData,
Result,
String,
Table,
Text,
insert,
select,
)
from sqlalchemy.engine import Engine, Result
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
metadata_obj = MetaData()
user = Table(
@ -35,18 +39,18 @@ company = Table(
@pytest.fixture
def engine() -> sa.Engine:
def engine() -> Engine:
return sa.create_engine("sqlite:///:memory:")
@pytest.fixture
def db(engine: sa.Engine) -> SQLDatabase:
def db(engine: Engine) -> SQLDatabase:
metadata_obj.create_all(engine)
return SQLDatabase(engine)
@pytest.fixture
def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase:
def db_lazy_reflection(engine: Engine) -> SQLDatabase:
metadata_obj.create_all(engine)
return SQLDatabase(engine, lazy_table_reflection=True)
@ -230,6 +234,7 @@ def test_sql_database_run_update(db: SQLDatabase) -> None:
assert output == expected_output
@pytest.mark.skipif(is_sqlalchemy_v1, reason="Requires SQLAlchemy 2 or newer")
def test_sql_database_schema_translate_map() -> None:
"""Verify using statement-specific execution options."""

View File

@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
* Keys can be listed based on the updated at field.
* Keys can be deleted.
"""
import contextlib
import decimal
import uuid
@ -20,9 +21,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequenc
from langchain_core.indexing import RecordManager
from sqlalchemy import (
URL,
Column,
Engine,
Float,
Index,
String,
@ -33,15 +32,21 @@ from sqlalchemy import (
select,
text,
)
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Query, Session, sessionmaker
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
Base = declarative_base()