community[patch]: restore compatibility with SQLAlchemy 1.x (#22546)

- **Description:** Restores compatibility with SQLAlchemy 1.4.x that was
broken since #18992 and adds a test run for this version on CI (only for
Python 3.11)
- **Issue:** fixes #19681
- **Dependencies:** None
- **Twitter handle:** `@krassowski_m`

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Michał Krassowski 2024-06-19 18:58:57 +01:00 committed by GitHub
parent 48d6ea427f
commit 710197e18c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 48 additions and 19 deletions

View File

@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
import sqlalchemy as sa from sqlalchemy.engine import RowMapping
from sqlalchemy.sql.expression import Select
from langchain_community.docstore.document import Document from langchain_community.docstore.document import Document
from langchain_community.document_loaders.base import BaseLoader from langchain_community.document_loaders.base import BaseLoader
@ -19,7 +20,7 @@ class SQLDatabaseLoader(BaseLoader):
def __init__( def __init__(
self, self,
query: Union[str, sa.Select], query: Union[str, Select],
db: SQLDatabase, db: SQLDatabase,
*, *,
parameters: Optional[Dict[str, Any]] = None, parameters: Optional[Dict[str, Any]] = None,
@ -106,7 +107,7 @@ class SQLDatabaseLoader(BaseLoader):
@staticmethod @staticmethod
def page_content_default_mapper( def page_content_default_mapper(
row: sa.RowMapping, column_names: Optional[List[str]] = None row: RowMapping, column_names: Optional[List[str]] = None
) -> str: ) -> str:
""" """
A reasonable default function to convert a record into a "page content" string. A reasonable default function to convert a record into a "page content" string.
@ -121,7 +122,7 @@ class SQLDatabaseLoader(BaseLoader):
@staticmethod @staticmethod
def metadata_default_mapper( def metadata_default_mapper(
row: sa.RowMapping, column_names: Optional[List[str]] = None row: RowMapping, column_names: Optional[List[str]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
A reasonable default function to convert a record into a "metadata" dictionary. A reasonable default function to convert a record into a "metadata" dictionary.

View File

@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
* Keys can be listed based on the updated at field. * Keys can be listed based on the updated at field.
* Keys can be deleted. * Keys can be deleted.
""" """
import contextlib import contextlib
import decimal import decimal
import uuid import uuid
@ -29,9 +30,7 @@ from typing import (
) )
from sqlalchemy import ( from sqlalchemy import (
URL,
Column, Column,
Engine,
Float, Float,
Index, Index,
String, String,
@ -42,15 +41,21 @@ from sqlalchemy import (
select, select,
text, text,
) )
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncEngine, AsyncEngine,
AsyncSession, AsyncSession,
async_sessionmaker,
create_async_engine, create_async_engine,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
from langchain_community.indexes.base import RecordManager from langchain_community.indexes.base import RecordManager
Base = declarative_base() Base = declarative_base()

View File

@ -20,7 +20,7 @@ from typing import (
import numpy as np import numpy as np
import sqlalchemy import sqlalchemy
from langchain_core._api import deprecated, warn_deprecated from langchain_core._api import deprecated, warn_deprecated
from sqlalchemy import SQLColumnExpression, delete, func from sqlalchemy import delete, func
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
from sqlalchemy.orm import Session, relationship from sqlalchemy.orm import Session, relationship
@ -29,6 +29,12 @@ try:
except ImportError: except ImportError:
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
try:
from sqlalchemy import SQLColumnExpression
except ImportError:
# for sqlalchemy < 2
SQLColumnExpression = Any # type: ignore
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor

View File

@ -4,7 +4,17 @@ from typing import Any, AsyncGenerator, Generator, List, Tuple
import pytest import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from sqlalchemy import Column, Integer, Text from sqlalchemy import Column, Integer, Text
from sqlalchemy.orm import DeclarativeBase
try:
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
pass
except ImportError:
# for sqlalchemy < 2
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() # type:ignore
from langchain_community.chat_message_histories import SQLChatMessageHistory from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_community.chat_message_histories.sql import DefaultMessageConverter from langchain_community.chat_message_histories.sql import DefaultMessageConverter
@ -198,9 +208,6 @@ async def test_async_clear_messages(
def test_model_no_session_id_field_error(con_str: str) -> None: def test_model_no_session_id_field_error(con_str: str) -> None:
class Base(DeclarativeBase):
pass
class Model(Base): class Model(Base):
__tablename__ = "test_table" __tablename__ = "test_table"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)

View File

@ -1,21 +1,25 @@
# flake8: noqa: E501 # flake8: noqa: E501
"""Test SQL database wrapper.""" """Test SQL database wrapper."""
import pytest import pytest
import sqlalchemy as sa import sqlalchemy as sa
from packaging import version
from sqlalchemy import ( from sqlalchemy import (
Column, Column,
Integer, Integer,
MetaData, MetaData,
Result,
String, String,
Table, Table,
Text, Text,
insert, insert,
select, select,
) )
from sqlalchemy.engine import Engine, Result
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
metadata_obj = MetaData() metadata_obj = MetaData()
user = Table( user = Table(
@ -35,18 +39,18 @@ company = Table(
@pytest.fixture @pytest.fixture
def engine() -> sa.Engine: def engine() -> Engine:
return sa.create_engine("sqlite:///:memory:") return sa.create_engine("sqlite:///:memory:")
@pytest.fixture @pytest.fixture
def db(engine: sa.Engine) -> SQLDatabase: def db(engine: Engine) -> SQLDatabase:
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
return SQLDatabase(engine) return SQLDatabase(engine)
@pytest.fixture @pytest.fixture
def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase: def db_lazy_reflection(engine: Engine) -> SQLDatabase:
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
return SQLDatabase(engine, lazy_table_reflection=True) return SQLDatabase(engine, lazy_table_reflection=True)
@ -230,6 +234,7 @@ def test_sql_database_run_update(db: SQLDatabase) -> None:
assert output == expected_output assert output == expected_output
@pytest.mark.skipif(is_sqlalchemy_v1, reason="Requires SQLAlchemy 2 or newer")
def test_sql_database_schema_translate_map() -> None: def test_sql_database_schema_translate_map() -> None:
"""Verify using statement-specific execution options.""" """Verify using statement-specific execution options."""

View File

@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
* Keys can be listed based on the updated at field. * Keys can be listed based on the updated at field.
* Keys can be deleted. * Keys can be deleted.
""" """
import contextlib import contextlib
import decimal import decimal
import uuid import uuid
@ -20,9 +21,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequenc
from langchain_core.indexing import RecordManager from langchain_core.indexing import RecordManager
from sqlalchemy import ( from sqlalchemy import (
URL,
Column, Column,
Engine,
Float, Float,
Index, Index,
String, String,
@ -33,15 +32,21 @@ from sqlalchemy import (
select, select,
text, text,
) )
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncEngine, AsyncEngine,
AsyncSession, AsyncSession,
async_sessionmaker,
create_async_engine, create_async_engine,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Query, Session, sessionmaker from sqlalchemy.orm import Query, Session, sessionmaker
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
Base = declarative_base() Base = declarative_base()