mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
community[patch]: restore compatibility with SQLAlchemy 1.x (#22546)
- **Description:** Restores compatibility with SQLAlchemy 1.4.x that was broken since #18992 and adds a test run for this version on CI (only for Python 3.11) - **Issue:** fixes #19681 - **Dependencies:** None - **Twitter handle:** `@krassowski_m` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
48d6ea427f
commit
710197e18c
@ -1,6 +1,7 @@
|
|||||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
from sqlalchemy.engine import RowMapping
|
||||||
|
from sqlalchemy.sql.expression import Select
|
||||||
|
|
||||||
from langchain_community.docstore.document import Document
|
from langchain_community.docstore.document import Document
|
||||||
from langchain_community.document_loaders.base import BaseLoader
|
from langchain_community.document_loaders.base import BaseLoader
|
||||||
@ -19,7 +20,7 @@ class SQLDatabaseLoader(BaseLoader):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: Union[str, sa.Select],
|
query: Union[str, Select],
|
||||||
db: SQLDatabase,
|
db: SQLDatabase,
|
||||||
*,
|
*,
|
||||||
parameters: Optional[Dict[str, Any]] = None,
|
parameters: Optional[Dict[str, Any]] = None,
|
||||||
@ -106,7 +107,7 @@ class SQLDatabaseLoader(BaseLoader):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def page_content_default_mapper(
|
def page_content_default_mapper(
|
||||||
row: sa.RowMapping, column_names: Optional[List[str]] = None
|
row: RowMapping, column_names: Optional[List[str]] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
A reasonable default function to convert a record into a "page content" string.
|
A reasonable default function to convert a record into a "page content" string.
|
||||||
@ -121,7 +122,7 @@ class SQLDatabaseLoader(BaseLoader):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def metadata_default_mapper(
|
def metadata_default_mapper(
|
||||||
row: sa.RowMapping, column_names: Optional[List[str]] = None
|
row: RowMapping, column_names: Optional[List[str]] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
A reasonable default function to convert a record into a "metadata" dictionary.
|
A reasonable default function to convert a record into a "metadata" dictionary.
|
||||||
|
@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
|
|||||||
* Keys can be listed based on the updated at field.
|
* Keys can be listed based on the updated at field.
|
||||||
* Keys can be deleted.
|
* Keys can be deleted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import decimal
|
import decimal
|
||||||
import uuid
|
import uuid
|
||||||
@ -29,9 +30,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
URL,
|
|
||||||
Column,
|
Column,
|
||||||
Engine,
|
|
||||||
Float,
|
Float,
|
||||||
Index,
|
Index,
|
||||||
String,
|
String,
|
||||||
@ -42,15 +41,21 @@ from sqlalchemy import (
|
|||||||
select,
|
select,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.engine import URL, Engine
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncEngine,
|
AsyncEngine,
|
||||||
AsyncSession,
|
AsyncSession,
|
||||||
async_sessionmaker,
|
|
||||||
create_async_engine,
|
create_async_engine,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
|
except ImportError:
|
||||||
|
# dummy for sqlalchemy < 2
|
||||||
|
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||||
|
|
||||||
from langchain_community.indexes.base import RecordManager
|
from langchain_community.indexes.base import RecordManager
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
@ -20,7 +20,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from langchain_core._api import deprecated, warn_deprecated
|
from langchain_core._api import deprecated, warn_deprecated
|
||||||
from sqlalchemy import SQLColumnExpression, delete, func
|
from sqlalchemy import delete, func
|
||||||
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
|
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
|
||||||
from sqlalchemy.orm import Session, relationship
|
from sqlalchemy.orm import Session, relationship
|
||||||
|
|
||||||
@ -29,6 +29,12 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy import SQLColumnExpression
|
||||||
|
except ImportError:
|
||||||
|
# for sqlalchemy < 2
|
||||||
|
SQLColumnExpression = Any # type: ignore
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.runnables.config import run_in_executor
|
from langchain_core.runnables.config import run_in_executor
|
||||||
|
@ -4,7 +4,17 @@ from typing import Any, AsyncGenerator, Generator, List, Tuple
|
|||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from sqlalchemy import Column, Integer, Text
|
from sqlalchemy import Column, Integer, Text
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
except ImportError:
|
||||||
|
# for sqlalchemy < 2
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
Base = declarative_base() # type:ignore
|
||||||
|
|
||||||
from langchain_community.chat_message_histories import SQLChatMessageHistory
|
from langchain_community.chat_message_histories import SQLChatMessageHistory
|
||||||
from langchain_community.chat_message_histories.sql import DefaultMessageConverter
|
from langchain_community.chat_message_histories.sql import DefaultMessageConverter
|
||||||
@ -198,9 +208,6 @@ async def test_async_clear_messages(
|
|||||||
|
|
||||||
|
|
||||||
def test_model_no_session_id_field_error(con_str: str) -> None:
|
def test_model_no_session_id_field_error(con_str: str) -> None:
|
||||||
class Base(DeclarativeBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Model(Base):
|
class Model(Base):
|
||||||
__tablename__ = "test_table"
|
__tablename__ = "test_table"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
|
@ -1,21 +1,25 @@
|
|||||||
# flake8: noqa: E501
|
# flake8: noqa: E501
|
||||||
"""Test SQL database wrapper."""
|
"""Test SQL database wrapper."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from packaging import version
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column,
|
Column,
|
||||||
Integer,
|
Integer,
|
||||||
MetaData,
|
MetaData,
|
||||||
Result,
|
|
||||||
String,
|
String,
|
||||||
Table,
|
Table,
|
||||||
Text,
|
Text,
|
||||||
insert,
|
insert,
|
||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.engine import Engine, Result
|
||||||
|
|
||||||
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
|
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
|
||||||
|
|
||||||
|
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
|
||||||
|
|
||||||
metadata_obj = MetaData()
|
metadata_obj = MetaData()
|
||||||
|
|
||||||
user = Table(
|
user = Table(
|
||||||
@ -35,18 +39,18 @@ company = Table(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def engine() -> sa.Engine:
|
def engine() -> Engine:
|
||||||
return sa.create_engine("sqlite:///:memory:")
|
return sa.create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db(engine: sa.Engine) -> SQLDatabase:
|
def db(engine: Engine) -> SQLDatabase:
|
||||||
metadata_obj.create_all(engine)
|
metadata_obj.create_all(engine)
|
||||||
return SQLDatabase(engine)
|
return SQLDatabase(engine)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase:
|
def db_lazy_reflection(engine: Engine) -> SQLDatabase:
|
||||||
metadata_obj.create_all(engine)
|
metadata_obj.create_all(engine)
|
||||||
return SQLDatabase(engine, lazy_table_reflection=True)
|
return SQLDatabase(engine, lazy_table_reflection=True)
|
||||||
|
|
||||||
@ -230,6 +234,7 @@ def test_sql_database_run_update(db: SQLDatabase) -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(is_sqlalchemy_v1, reason="Requires SQLAlchemy 2 or newer")
|
||||||
def test_sql_database_schema_translate_map() -> None:
|
def test_sql_database_schema_translate_map() -> None:
|
||||||
"""Verify using statement-specific execution options."""
|
"""Verify using statement-specific execution options."""
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
|
|||||||
* Keys can be listed based on the updated at field.
|
* Keys can be listed based on the updated at field.
|
||||||
* Keys can be deleted.
|
* Keys can be deleted.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import decimal
|
import decimal
|
||||||
import uuid
|
import uuid
|
||||||
@ -20,9 +21,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequenc
|
|||||||
|
|
||||||
from langchain_core.indexing import RecordManager
|
from langchain_core.indexing import RecordManager
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
URL,
|
|
||||||
Column,
|
Column,
|
||||||
Engine,
|
|
||||||
Float,
|
Float,
|
||||||
Index,
|
Index,
|
||||||
String,
|
String,
|
||||||
@ -33,15 +32,21 @@ from sqlalchemy import (
|
|||||||
select,
|
select,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.engine import URL, Engine
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncEngine,
|
AsyncEngine,
|
||||||
AsyncSession,
|
AsyncSession,
|
||||||
async_sessionmaker,
|
|
||||||
create_async_engine,
|
create_async_engine,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Query, Session, sessionmaker
|
from sqlalchemy.orm import Query, Session, sessionmaker
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
|
except ImportError:
|
||||||
|
# dummy for sqlalchemy < 2
|
||||||
|
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user