mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
community[patch]: restore compatibility with SQLAlchemy 1.x (#22546)
- **Description:** Restores compatibility with SQLAlchemy 1.4.x that was broken since #18992 and adds a test run for this version on CI (only for Python 3.11) - **Issue:** fixes #19681 - **Dependencies:** None - **Twitter handle:** `@krassowski_m` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
48d6ea427f
commit
710197e18c
@ -1,6 +1,7 @@
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine import RowMapping
|
||||
from sqlalchemy.sql.expression import Select
|
||||
|
||||
from langchain_community.docstore.document import Document
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
@ -19,7 +20,7 @@ class SQLDatabaseLoader(BaseLoader):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: Union[str, sa.Select],
|
||||
query: Union[str, Select],
|
||||
db: SQLDatabase,
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
@ -106,7 +107,7 @@ class SQLDatabaseLoader(BaseLoader):
|
||||
|
||||
@staticmethod
|
||||
def page_content_default_mapper(
|
||||
row: sa.RowMapping, column_names: Optional[List[str]] = None
|
||||
row: RowMapping, column_names: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
A reasonable default function to convert a record into a "page content" string.
|
||||
@ -121,7 +122,7 @@ class SQLDatabaseLoader(BaseLoader):
|
||||
|
||||
@staticmethod
|
||||
def metadata_default_mapper(
|
||||
row: sa.RowMapping, column_names: Optional[List[str]] = None
|
||||
row: RowMapping, column_names: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
A reasonable default function to convert a record into a "metadata" dictionary.
|
||||
|
@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
|
||||
* Keys can be listed based on the updated at field.
|
||||
* Keys can be deleted.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import decimal
|
||||
import uuid
|
||||
@ -29,9 +30,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from sqlalchemy import (
|
||||
URL,
|
||||
Column,
|
||||
Engine,
|
||||
Float,
|
||||
Index,
|
||||
String,
|
||||
@ -42,15 +41,21 @@ from sqlalchemy import (
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import URL, Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
try:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||
|
||||
from langchain_community.indexes.base import RecordManager
|
||||
|
||||
Base = declarative_base()
|
||||
|
@ -20,7 +20,7 @@ from typing import (
|
||||
import numpy as np
|
||||
import sqlalchemy
|
||||
from langchain_core._api import deprecated, warn_deprecated
|
||||
from sqlalchemy import SQLColumnExpression, delete, func
|
||||
from sqlalchemy import delete, func
|
||||
from sqlalchemy.dialects.postgresql import JSON, JSONB, UUID
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
@ -29,6 +29,12 @@ try:
|
||||
except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
try:
|
||||
from sqlalchemy import SQLColumnExpression
|
||||
except ImportError:
|
||||
# for sqlalchemy < 2
|
||||
SQLColumnExpression = Any # type: ignore
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
@ -4,7 +4,17 @@ from typing import Any, AsyncGenerator, Generator, List, Tuple
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from sqlalchemy import Column, Integer, Text
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
except ImportError:
|
||||
# for sqlalchemy < 2
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base() # type:ignore
|
||||
|
||||
from langchain_community.chat_message_histories import SQLChatMessageHistory
|
||||
from langchain_community.chat_message_histories.sql import DefaultMessageConverter
|
||||
@ -198,9 +208,6 @@ async def test_async_clear_messages(
|
||||
|
||||
|
||||
def test_model_no_session_id_field_error(con_str: str) -> None:
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "test_table"
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
@ -1,21 +1,25 @@
|
||||
# flake8: noqa: E501
|
||||
"""Test SQL database wrapper."""
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from packaging import version
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
Result,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
insert,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.engine import Engine, Result
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
|
||||
|
||||
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
user = Table(
|
||||
@ -35,18 +39,18 @@ company = Table(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine() -> sa.Engine:
|
||||
def engine() -> Engine:
|
||||
return sa.create_engine("sqlite:///:memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(engine: sa.Engine) -> SQLDatabase:
|
||||
def db(engine: Engine) -> SQLDatabase:
|
||||
metadata_obj.create_all(engine)
|
||||
return SQLDatabase(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase:
|
||||
def db_lazy_reflection(engine: Engine) -> SQLDatabase:
|
||||
metadata_obj.create_all(engine)
|
||||
return SQLDatabase(engine, lazy_table_reflection=True)
|
||||
|
||||
@ -230,6 +234,7 @@ def test_sql_database_run_update(db: SQLDatabase) -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_sqlalchemy_v1, reason="Requires SQLAlchemy 2 or newer")
|
||||
def test_sql_database_schema_translate_map() -> None:
|
||||
"""Verify using statement-specific execution options."""
|
||||
|
||||
|
@ -13,6 +13,7 @@ allow it to work with a variety of SQL as a backend.
|
||||
* Keys can be listed based on the updated at field.
|
||||
* Keys can be deleted.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import decimal
|
||||
import uuid
|
||||
@ -20,9 +21,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequenc
|
||||
|
||||
from langchain_core.indexing import RecordManager
|
||||
from sqlalchemy import (
|
||||
URL,
|
||||
Column,
|
||||
Engine,
|
||||
Float,
|
||||
Index,
|
||||
String,
|
||||
@ -33,15 +32,21 @@ from sqlalchemy import (
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import URL, Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Query, Session, sessionmaker
|
||||
|
||||
try:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user