mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
community[patch]: restore compatibility with SQLAlchemy 1.x (#22546)
- **Description:** Restores compatibility with SQLAlchemy 1.4.x that was broken since #18992 and adds a test run for this version on CI (only for Python 3.11) - **Issue:** fixes #19681 - **Dependencies:** None - **Twitter handle:** `@krassowski_m` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
48d6ea427f
commit
710197e18c
@@ -4,7 +4,17 @@ from typing import Any, AsyncGenerator, Generator, List, Tuple
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from sqlalchemy import Column, Integer, Text
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
except ImportError:
|
||||
# for sqlalchemy < 2
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base() # type:ignore
|
||||
|
||||
from langchain_community.chat_message_histories import SQLChatMessageHistory
|
||||
from langchain_community.chat_message_histories.sql import DefaultMessageConverter
|
||||
@@ -198,9 +208,6 @@ async def test_async_clear_messages(
|
||||
|
||||
|
||||
def test_model_no_session_id_field_error(con_str: str) -> None:
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "test_table"
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
@@ -1,21 +1,25 @@
|
||||
# flake8: noqa: E501
|
||||
"""Test SQL database wrapper."""
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from packaging import version
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
Result,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
insert,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.engine import Engine, Result
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
|
||||
|
||||
is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
user = Table(
|
||||
@@ -35,18 +39,18 @@ company = Table(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine() -> sa.Engine:
|
||||
def engine() -> Engine:
|
||||
return sa.create_engine("sqlite:///:memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(engine: sa.Engine) -> SQLDatabase:
|
||||
def db(engine: Engine) -> SQLDatabase:
|
||||
metadata_obj.create_all(engine)
|
||||
return SQLDatabase(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase:
|
||||
def db_lazy_reflection(engine: Engine) -> SQLDatabase:
|
||||
metadata_obj.create_all(engine)
|
||||
return SQLDatabase(engine, lazy_table_reflection=True)
|
||||
|
||||
@@ -230,6 +234,7 @@ def test_sql_database_run_update(db: SQLDatabase) -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_sqlalchemy_v1, reason="Requires SQLAlchemy 2 or newer")
|
||||
def test_sql_database_schema_translate_map() -> None:
|
||||
"""Verify using statement-specific execution options."""
|
||||
|
||||
|
Reference in New Issue
Block a user