diff --git a/libs/community/langchain_community/document_loaders/sql_database.py b/libs/community/langchain_community/document_loaders/sql_database.py index c3ab2e4cc91..c8d03a0db15 100644 --- a/libs/community/langchain_community/document_loaders/sql_database.py +++ b/libs/community/langchain_community/document_loaders/sql_database.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union -import sqlalchemy as sa +from sqlalchemy.engine import RowMapping +from sqlalchemy.sql.expression import Select from langchain_community.docstore.document import Document from langchain_community.document_loaders.base import BaseLoader @@ -19,7 +20,7 @@ class SQLDatabaseLoader(BaseLoader): def __init__( self, - query: Union[str, sa.Select], + query: Union[str, Select], db: SQLDatabase, *, parameters: Optional[Dict[str, Any]] = None, @@ -106,7 +107,7 @@ class SQLDatabaseLoader(BaseLoader): @staticmethod def page_content_default_mapper( - row: sa.RowMapping, column_names: Optional[List[str]] = None + row: RowMapping, column_names: Optional[List[str]] = None ) -> str: """ A reasonable default function to convert a record into a "page content" string. @@ -121,7 +122,7 @@ class SQLDatabaseLoader(BaseLoader): @staticmethod def metadata_default_mapper( - row: sa.RowMapping, column_names: Optional[List[str]] = None + row: RowMapping, column_names: Optional[List[str]] = None ) -> Dict[str, Any]: """ A reasonable default function to convert a record into a "metadata" dictionary. diff --git a/libs/community/langchain_community/indexes/_sql_record_manager.py b/libs/community/langchain_community/indexes/_sql_record_manager.py index 22e56885ae9..423df0d0ef1 100644 --- a/libs/community/langchain_community/indexes/_sql_record_manager.py +++ b/libs/community/langchain_community/indexes/_sql_record_manager.py @@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend. * Keys can be listed based on the updated at field. * Keys can be deleted. """ + import contextlib import decimal import uuid @@ -29,9 +30,7 @@ from typing import ( ) from sqlalchemy import ( - URL, Column, - Engine, Float, Index, String, @@ -42,15 +41,21 @@ from sqlalchemy import ( select, text, ) +from sqlalchemy.engine import URL, Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, - async_sessionmaker, create_async_engine, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker +try: + from sqlalchemy.ext.asyncio import async_sessionmaker +except ImportError: + # dummy for sqlalchemy < 2 + async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore + from langchain_community.indexes.base import RecordManager Base = declarative_base() diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py index ff4cbbd783b..83661b534c8 100644 --- a/libs/community/langchain_community/vectorstores/pgvector.py +++ b/libs/community/langchain_community/vectorstores/pgvector.py @@ -20,7 +20,7 @@ from typing import ( import numpy as np import sqlalchemy from langchain_core._api import deprecated, warn_deprecated -from sqlalchemy import SQLColumnExpression, delete, func +from sqlalchemy import delete, func from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID from sqlalchemy.orm import Session, relationship @@ -29,6 +29,12 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +try: + from sqlalchemy import SQLColumnExpression +except ImportError: + # for sqlalchemy < 2 + SQLColumnExpression = Any # type: ignore + from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.runnables.config import run_in_executor diff --git a/libs/community/tests/unit_tests/chat_message_histories/test_sql.py b/libs/community/tests/unit_tests/chat_message_histories/test_sql.py index 68e7e216cec..9e3eeac7b61 100644 --- a/libs/community/tests/unit_tests/chat_message_histories/test_sql.py +++ b/libs/community/tests/unit_tests/chat_message_histories/test_sql.py @@ -4,7 +4,17 @@ from typing import Any, AsyncGenerator, Generator, List, Tuple import pytest from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from sqlalchemy import Column, Integer, Text -from sqlalchemy.orm import DeclarativeBase + +try: + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass +except ImportError: + # for sqlalchemy < 2 + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() # type:ignore from langchain_community.chat_message_histories import SQLChatMessageHistory from langchain_community.chat_message_histories.sql import DefaultMessageConverter @@ -198,9 +208,6 @@ async def test_async_clear_messages( def test_model_no_session_id_field_error(con_str: str) -> None: - class Base(DeclarativeBase): - pass - class Model(Base): __tablename__ = "test_table" id = Column(Integer, primary_key=True) diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py index da4c1ddbea6..6bd37d4052a 100644 --- a/libs/community/tests/unit_tests/test_sql_database.py +++ b/libs/community/tests/unit_tests/test_sql_database.py @@ -1,21 +1,25 @@ # flake8: noqa: E501 """Test SQL database wrapper.""" + import pytest import sqlalchemy as sa +from packaging import version from sqlalchemy import ( Column, Integer, MetaData, - Result, String, Table, Text, insert, select, ) +from sqlalchemy.engine import Engine, Result from langchain_community.utilities.sql_database import SQLDatabase, truncate_word +is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1 + metadata_obj = MetaData() user = Table( @@ -35,18 +39,18 @@ company = Table( @pytest.fixture -def engine() -> sa.Engine: +def engine() -> Engine: return sa.create_engine("sqlite:///:memory:") @pytest.fixture -def db(engine: sa.Engine) -> SQLDatabase: +def db(engine: Engine) -> SQLDatabase: metadata_obj.create_all(engine) return SQLDatabase(engine) @pytest.fixture -def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase: +def db_lazy_reflection(engine: Engine) -> SQLDatabase: metadata_obj.create_all(engine) return SQLDatabase(engine, lazy_table_reflection=True) @@ -230,6 +234,7 @@ def test_sql_database_run_update(db: SQLDatabase) -> None: assert output == expected_output +@pytest.mark.skipif(is_sqlalchemy_v1, reason="Requires SQLAlchemy 2 or newer") def test_sql_database_schema_translate_map() -> None: """Verify using statement-specific execution options.""" diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index a61be7776c9..991fd044bf9 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend. * Keys can be listed based on the updated at field. * Keys can be deleted. """ + import contextlib import decimal import uuid @@ -20,9 +21,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequenc from langchain_core.indexing import RecordManager from sqlalchemy import ( - URL, Column, - Engine, Float, Index, String, @@ -33,15 +32,21 @@ from sqlalchemy import ( select, text, ) +from sqlalchemy.engine import URL, Engine from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, - async_sessionmaker, create_async_engine, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Query, Session, sessionmaker +try: + from sqlalchemy.ext.asyncio import async_sessionmaker +except ImportError: + # dummy for sqlalchemy < 2 + async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore + Base = declarative_base()