From b7747017d72eaeabfa65edb7eec413d3fd006ddb Mon Sep 17 00:00:00 2001 From: Shahriar Tajbakhsh Date: Sat, 11 Feb 2023 02:33:47 +0000 Subject: [PATCH] Import of `declarative_base` when SQLAlchemy <1.4 (#883) In [pyproject.toml](https://github.com/hwchase17/langchain/blob/master/pyproject.toml), the expectation is `SQLAlchemy = "^1"`. But, the way `declarative_base` is imported in [cache.py](https://github.com/hwchase17/langchain/blob/master/langchain/cache.py) will only work with SQLAlchemy >=1.4. This PR makes sure Langchain can be run in environments with SQLAlchemy <1.4 --- langchain/cache.py | 7 ++++++- tests/unit_tests/llms/test_base.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/langchain/cache.py b/langchain/cache.py index c81142ccd1b..f7cecf45ff2 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -4,7 +4,12 @@ from typing import Any, Dict, List, Optional, Tuple from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy.engine.base import Engine -from sqlalchemy.orm import Session, declarative_base +from sqlalchemy.orm import Session + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base from langchain.schema import Generation diff --git a/tests/unit_tests/llms/test_base.py b/tests/unit_tests/llms/test_base.py index 8838997dc71..973c1a185b4 100644 --- a/tests/unit_tests/llms/test_base.py +++ b/tests/unit_tests/llms/test_base.py @@ -1,6 +1,10 @@ """Test base LLM functionality.""" from sqlalchemy import Column, Integer, Sequence, String, create_engine -from sqlalchemy.orm import declarative_base + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base import langchain from langchain.cache import InMemoryCache, SQLAlchemyCache