mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
Fix sqlalchemy warnings when running tests (#733)
This has been bugging me when running my own tests that call langchain methods :P
This commit is contained in:
parent
bd0bf4e0a9
commit
fa6826e417
@ -4,8 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import Column, Integer, String, create_engine, select
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from langchain.schema import Generation
|
||||
|
||||
|
@ -86,7 +86,7 @@ class SQLDatabase:
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
"""
|
||||
with self._engine.connect() as connection:
|
||||
with self._engine.begin() as connection:
|
||||
if self._schema is not None:
|
||||
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
|
||||
cursor = connection.exec_driver_sql(command)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Test base LLM functionality."""
|
||||
from sqlalchemy import Column, Integer, Sequence, String, create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
import langchain
|
||||
from langchain.cache import InMemoryCache, SQLAlchemyCache
|
||||
|
@ -40,7 +40,7 @@ def test_sql_database_run() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
stmt = insert(user).values(user_id=13, user_name="Harrison")
|
||||
with engine.connect() as conn:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
command = "select user_name from user where user_id = 13"
|
||||
@ -54,7 +54,7 @@ def test_sql_database_run_update() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
stmt = insert(user).values(user_id=13, user_name="Harrison")
|
||||
with engine.connect() as conn:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
command = "update user set user_name='Updated' where user_id = 13"
|
||||
|
@ -57,7 +57,7 @@ def test_sql_database_run() -> None:
|
||||
engine = create_engine("duckdb:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
stmt = insert(user).values(user_id=13, user_name="Harrison")
|
||||
with engine.connect() as conn:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine, schema="schema_a")
|
||||
command = 'select user_name from "user" where user_id = 13'
|
||||
|
Loading…
Reference in New Issue
Block a user