Fix sqlalchemy warnings when running tests (#733)

This has been bugging me when running my own tests that call langchain
methods :P
This commit is contained in:
Amos Ng 2023-01-25 22:14:07 +07:00 committed by GitHub
parent bd0bf4e0a9
commit fa6826e417
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 6 additions and 7 deletions

View File

@ -4,8 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm import Session
from langchain.schema import Generation from langchain.schema import Generation

View File

@ -86,7 +86,7 @@ class SQLDatabase:
If the statement returns rows, a string of the results is returned. If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned. If the statement returns no rows, an empty string is returned.
""" """
with self._engine.connect() as connection: with self._engine.begin() as connection:
if self._schema is not None: if self._schema is not None:
connection.exec_driver_sql(f"SET search_path TO {self._schema}") connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.exec_driver_sql(command) cursor = connection.exec_driver_sql(command)

View File

@ -1,6 +1,6 @@
"""Test base LLM functionality.""" """Test base LLM functionality."""
from sqlalchemy import Column, Integer, Sequence, String, create_engine from sqlalchemy import Column, Integer, Sequence, String, create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base
import langchain import langchain
from langchain.cache import InMemoryCache, SQLAlchemyCache from langchain.cache import InMemoryCache, SQLAlchemyCache

View File

@ -40,7 +40,7 @@ def test_sql_database_run() -> None:
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison") stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn: with engine.begin() as conn:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine) db = SQLDatabase(engine)
command = "select user_name from user where user_id = 13" command = "select user_name from user where user_id = 13"
@ -54,7 +54,7 @@ def test_sql_database_run_update() -> None:
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison") stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn: with engine.begin() as conn:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine) db = SQLDatabase(engine)
command = "update user set user_name='Updated' where user_id = 13" command = "update user set user_name='Updated' where user_id = 13"

View File

@ -57,7 +57,7 @@ def test_sql_database_run() -> None:
engine = create_engine("duckdb:///:memory:") engine = create_engine("duckdb:///:memory:")
metadata_obj.create_all(engine) metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=13, user_name="Harrison") stmt = insert(user).values(user_id=13, user_name="Harrison")
with engine.connect() as conn: with engine.begin() as conn:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine, schema="schema_a") db = SQLDatabase(engine, schema="schema_a")
command = 'select user_name from "user" where user_id = 13' command = 'select user_name from "user" where user_id = 13'