Move integration tests

Eugene Yurtsev
2024-05-03 10:00:06 -04:00
parent a51c5937c3
commit 33dac2e00d
47 changed files with 0 additions and 4567 deletions

View File

@@ -1,175 +0,0 @@
import asyncio
import os
import time
import urllib.request
import uuid
from enum import Enum
from typing import Any
from urllib.error import HTTPError
import pytest
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools.ainetwork.utils import authenticate
from langchain.agents import AgentType, initialize_agent
class Match(Enum):
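    # __test__ = False keeps pytest from collecting this helper class as a test.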
__test__ = False
ListWildcard = 1
StrWildcard = 2
DictWildcard = 3
IntWildcard = 4
FloatWildcard = 5
ObjectWildcard = 6
@classmethod
def match(cls, value: Any, template: Any) -> bool:
if template is cls.ListWildcard:
return isinstance(value, list)
elif template is cls.StrWildcard:
return isinstance(value, str)
elif template is cls.DictWildcard:
return isinstance(value, dict)
elif template is cls.IntWildcard:
return isinstance(value, int)
elif template is cls.FloatWildcard:
return isinstance(value, float)
elif template is cls.ObjectWildcard:
return True
elif type(value) != type(template):
return False
elif isinstance(value, dict):
if len(value) != len(template):
return False
for k, v in value.items():
if k not in template or not cls.match(v, template[k]):
return False
return True
elif isinstance(value, list):
if len(value) != len(template):
return False
for i in range(len(value)):
if not cls.match(value[i], template[i]):
return False
return True
else:
return value == template
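# Illustration (not in the original file) of the wildcard semantics above:
assert Match.match({"admin": {"0xabc": True}}, {"admin": Match.DictWildcard})
assert Match.match([1, "a"], [Match.IntWildcard, Match.StrWildcard])
assert not Match.match({"a": 1}, {"a": 2})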
@pytest.mark.requires("ain")
def test_ainetwork_toolkit() -> None:
def get(path: str, type: str = "value", default: Any = None) -> Any:
ref = ain.db.ref(path)
value = asyncio.run(
{
"value": ref.getValue,
"rule": ref.getRule,
"owner": ref.getOwner,
}[type]()
)
return default if value is None else value
def validate(path: str, template: Any, type: str = "value") -> bool:
value = get(path, type)
return Match.match(value, template)
if not os.environ.get("AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY", None):
from ain.account import Account
account = Account.create()
os.environ["AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY"] = account.private_key
interface = authenticate(network="testnet")
toolkit = AINetworkToolkit(network="testnet", interface=interface)
llm = ChatOpenAI(model="gpt-4", temperature=0)
agent = initialize_agent(
tools=toolkit.get_tools(),
llm=llm,
verbose=True,
agent=AgentType.OPENAI_FUNCTIONS,
)
ain = interface
self_address = ain.wallet.defaultAccount.address
co_address = "0x6813Eb9362372EEF6200f3b1dbC3f819671cBA69"
# Test creating an app
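    # Millisecond timestamp in the high 64 bits, random bits in the low 64:
    # unique across runs and roughly time-ordered.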
UUID = uuid.UUID(
int=(int(time.time() * 1000) << 64) | (uuid.uuid4().int & ((1 << 64) - 1))
)
app_name = f"_langchain_test__{str(UUID).replace('-', '_')}"
agent.run(f"""Create app {app_name}""")
validate(f"/manage_app/{app_name}/config", {"admin": {self_address: True}})
validate(f"/apps/{app_name}/DB", None, "owner")
# Test reading owner config
agent.run(f"""Read owner config of /apps/{app_name}/DB .""")
assert ...
# Test granting owner config
agent.run(
f"""Grant owner authority to {co_address} for edit write rule permission of /apps/{app_name}/DB_co .""" # noqa: E501
)
validate(
f"/apps/{app_name}/DB_co",
{
".owner": {
"owners": {
co_address: {
"branch_owner": False,
"write_function": False,
"write_owner": False,
"write_rule": True,
}
}
}
},
"owner",
)
# Test reading owner config
agent.run(f"""Read owner config of /apps/{app_name}/DB_co .""")
assert ...
# Test reading owner config
agent.run(f"""Read owner config of /apps/{app_name}/DB .""")
assert ... # Check if owner {self_address} exists
# Test reading a value
agent.run(f"""Read value in /apps/{app_name}/DB""")
assert ... # empty
# Test writing a value
agent.run(f"""Write value {{1: 1904, 2: 43}} in /apps/{app_name}/DB""")
validate(f"/apps/{app_name}/DB", {1: 1904, 2: 43})
# Test reading a value
agent.run(f"""Read value in /apps/{app_name}/DB""")
assert ... # check value
# Test reading a rule
agent.run(f"""Read write rule of app {app_name} .""")
assert ... # check rule that self_address exists
# Test sending AIN
self_balance = get(f"/accounts/{self_address}/balance", default=0)
transaction_history = get(f"/transfer/{self_address}/{co_address}", default={})
if self_balance < 1:
try:
with urllib.request.urlopen(
f"http://faucet.ainetwork.ai/api/test/{self_address}/"
) as response:
try_test = response.getcode()
except HTTPError as e:
try_test = e.getcode()
else:
try_test = 200
if try_test == 200:
agent.run(f"""Send 1 AIN to {co_address}""")
transaction_update = get(f"/transfer/{self_address}/{co_address}", default={})
assert any(
transaction_update[key]["value"] == 1
for key in transaction_update.keys() - transaction_history.keys()
)

View File

@@ -1,46 +0,0 @@
import pytest
from langchain_community.agent_toolkits import PowerBIToolkit, create_pbi_agent
from langchain_community.chat_models import ChatOpenAI
from langchain_community.utilities.powerbi import PowerBIDataset
from langchain_core.utils import get_from_env
def azure_installed() -> bool:
try:
from azure.core.credentials import TokenCredential # noqa: F401
from azure.identity import DefaultAzureCredential # noqa: F401
return True
except Exception as e:
print(f"azure not installed, skipping test {e}") # noqa: T201
return False
@pytest.mark.skipif(not azure_installed(), reason="requires azure package")
def test_daxquery() -> None:
from azure.identity import DefaultAzureCredential
DATASET_ID = get_from_env("", "POWERBI_DATASET_ID")
TABLE_NAME = get_from_env("", "POWERBI_TABLE_NAME")
NUM_ROWS = get_from_env("", "POWERBI_NUMROWS")
fast_llm = ChatOpenAI(
temperature=0.5, max_tokens=1000, model_name="gpt-3.5-turbo", verbose=True
) # type: ignore[call-arg]
smart_llm = ChatOpenAI(
temperature=0, max_tokens=100, model_name="gpt-4", verbose=True
) # type: ignore[call-arg]
toolkit = PowerBIToolkit(
powerbi=PowerBIDataset(
dataset_id=DATASET_ID,
table_names=[TABLE_NAME],
credential=DefaultAzureCredential(),
),
llm=smart_llm,
)
agent_executor = create_pbi_agent(llm=fast_llm, toolkit=toolkit, verbose=True)
output = agent_executor.run(f"How many rows are in the table, {TABLE_NAME}")
assert NUM_ROWS in output

View File

@@ -1,159 +0,0 @@
"""
Test AstraDB caches. Requires an Astra DB vector instance.
Required to run this test:
- a recent `astrapy` Python package available;
- an Astra DB instance;
- the two environment variables set:
export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
- optionally this as well (otherwise defaults are used):
export ASTRA_DB_KEYSPACE="my_keyspace"
"""
import os
from typing import AsyncIterator, Iterator
import pytest
from langchain_community.cache import AstraDBCache, AstraDBSemanticCache
from langchain_community.utilities.astradb import SetupMode
from langchain_core.caches import BaseCache
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture(scope="module")
def astradb_cache() -> Iterator[AstraDBCache]:
cache = AstraDBCache(
collection_name="lc_integration_test_cache",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield cache
cache.collection.astra_db.delete_collection("lc_integration_test_cache")
@pytest.fixture
async def async_astradb_cache() -> AsyncIterator[AstraDBCache]:
cache = AstraDBCache(
collection_name="lc_integration_test_cache_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield cache
await cache.async_collection.astra_db.delete_collection(
"lc_integration_test_cache_async"
)
@pytest.fixture(scope="module")
def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
fake_embe = FakeEmbeddings()
sem_cache = AstraDBSemanticCache(
collection_name="lc_integration_test_sem_cache",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
embedding=fake_embe,
)
yield sem_cache
sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache")
@pytest.fixture
async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]:
fake_embe = FakeEmbeddings()
sem_cache = AstraDBSemanticCache(
collection_name="lc_integration_test_sem_cache_async",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
embedding=fake_embe,
setup_mode=SetupMode.ASYNC,
)
yield sem_cache
sem_cache.collection.astra_db.delete_collection(
"lc_integration_test_sem_cache_async"
)
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches:
def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
self.do_cache_test(FakeLLM(), astradb_cache, "foo")
async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None:
await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo")
def test_astradb_semantic_cache(
self, astradb_semantic_cache: AstraDBSemanticCache
) -> None:
llm = FakeLLM()
self.do_cache_test(llm, astradb_semantic_cache, "bar")
        output = llm.generate(["bar"])  # 'fizz' has been cleared from the cache now
assert output != LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
astradb_semantic_cache.clear()
async def test_astradb_semantic_cache_async(
self, async_astradb_semantic_cache: AstraDBSemanticCache
) -> None:
llm = FakeLLM()
await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar")
        output = await llm.agenerate(["bar"])  # 'fizz' has been cleared from the cache now
assert output != LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
await async_astradb_semantic_cache.aclear()
@staticmethod
def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
set_llm_cache(cache)
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
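        # llm_string serializes the LLM's sorted parameters; together with the
        # prompt it forms the key the cache is queried with.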
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate([prompt])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
cache.clear()
@staticmethod
async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
set_llm_cache(cache)
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate([prompt])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
await cache.aclear()

View File

@@ -1,359 +0,0 @@
"""Test Azure CosmosDB cache functionality.
Required to run this test:
- a recent 'pymongo' Python package available
- an Azure CosmosDB Mongo vCore instance
- one environment variable set:
export MONGODB_VCORE_URI="connection string for azure cosmos db mongo vCore"
"""
import os
import uuid
import pytest
from langchain_community.cache import AzureCosmosDBSemanticCache
from langchain_community.vectorstores.azure_cosmos_db import (
CosmosDBSimilarityType,
CosmosDBVectorSearchType,
)
from langchain_core.outputs import Generation
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
FakeEmbeddings,
)
from tests.unit_tests.llms.fake_llm import FakeLLM
INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING: str = os.environ.get("MONGODB_VCORE_URI", "")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
num_lists = 3
dimensions = 10
similarity_algorithm = CosmosDBSimilarityType.COS
kind = CosmosDBVectorSearchType.VECTOR_IVF
m = 16
ef_construction = 64
ef_search = 40
score_threshold = 0.1
application_name = "LANGCHAIN_CACHING_PYTHON"
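# num_lists tunes the IVF index; m and ef_construction tune the HNSW variants
# exercised in the *_hnsw tests below; ef_search and score_threshold apply at
# query time.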
def _has_env_vars() -> bool:
    return "MONGODB_VCORE_URI" in os.environ
def random_string() -> str:
return str(uuid.uuid4())
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=similarity_algorithm,
kind=kind,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_inner_product() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=CosmosDBSimilarityType.IP,
kind=kind,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=similarity_algorithm,
kind=kind,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi_inner_product() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=CosmosDBSimilarityType.IP,
kind=kind,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_hnsw() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=similarity_algorithm,
kind=CosmosDBVectorSearchType.VECTOR_HNSW,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_inner_product_hnsw() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=CosmosDBSimilarityType.IP,
kind=CosmosDBVectorSearchType.VECTOR_HNSW,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi_hnsw() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=similarity_algorithm,
kind=CosmosDBVectorSearchType.VECTOR_HNSW,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi_inner_product_hnsw() -> None:
set_llm_cache(
AzureCosmosDBSemanticCache(
cosmosdb_connection_string=CONNECTION_STRING,
cosmosdb_client=None,
embedding=FakeEmbeddings(),
database_name=DB_NAME,
collection_name=COLLECTION_NAME,
num_lists=num_lists,
similarity=CosmosDBSimilarityType.IP,
kind=CosmosDBVectorSearchType.VECTOR_HNSW,
dimensions=dimensions,
m=m,
ef_construction=ef_construction,
ef_search=ef_search,
score_threshold=score_threshold,
application_name=application_name,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)

View File

@@ -1,177 +0,0 @@
"""Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""
import asyncio
import os
import time
from typing import Any, Iterator, Tuple
import pytest
from langchain_community.cache import CassandraCache, CassandraSemanticCache
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.outputs import Generation, LLMResult
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture(scope="module")
def cassandra_connection() -> Iterator[Tuple[Any, str]]:
from cassandra.cluster import Cluster
keyspace = "langchain_cache_test_keyspace"
# get db connection
if "CASSANDRA_CONTACT_POINTS" in os.environ:
        contact_points = os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
cluster = Cluster(contact_points)
else:
cluster = Cluster()
#
session = cluster.connect()
# ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
)
)
yield (session, keyspace)
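    # No teardown here: the keyspace (and its tables) is left in place after
    # the tests.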
def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session=session, keyspace=keyspace)
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
cache.clear()
async def test_cassandra_cache_async(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(
session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
)
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
await cache.aclear()
def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
output = llm.generate(["foo"])
assert output == expected_output
time.sleep(2.5)
    # the entry has expired by now.
output = llm.generate(["foo"])
assert output != expected_output
cache.clear()
async def test_cassandra_cache_ttl_async(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
cache = CassandraCache(
session=session, keyspace=keyspace, ttl_seconds=2, setup_mode=SetupMode.ASYNC
)
set_llm_cache(cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
output = await llm.agenerate(["foo"])
assert output == expected_output
await asyncio.sleep(2.5)
    # the entry has expired by now.
output = await llm.agenerate(["foo"])
assert output != expected_output
await cache.aclear()
def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
session, keyspace = cassandra_connection
sem_cache = CassandraSemanticCache(
session=session,
keyspace=keyspace,
embedding=FakeEmbeddings(),
)
set_llm_cache(sem_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["bar"]) # same embedding as 'foo'
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
sem_cache.clear()
    output = llm.generate(["bar"])  # 'fizz' has been cleared from the cache now
assert output != expected_output
sem_cache.clear()
async def test_cassandra_semantic_cache_async(
cassandra_connection: Tuple[Any, str],
) -> None:
session, keyspace = cassandra_connection
sem_cache = CassandraSemanticCache(
session=session,
keyspace=keyspace,
embedding=FakeEmbeddings(),
setup_mode=SetupMode.ASYNC,
)
set_llm_cache(sem_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
output = await llm.agenerate(["bar"]) # same embedding as 'foo'
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
await sem_cache.aclear()
    output = await llm.agenerate(["bar"])  # 'fizz' has been cleared from the cache now
assert output != expected_output
await sem_cache.aclear()

View File

@@ -1,62 +0,0 @@
import os
from typing import Any, Callable, Union
import pytest
from langchain_community.cache import GPTCache
from langchain_core.outputs import Generation
from langchain.globals import get_llm_cache, set_llm_cache
from tests.unit_tests.llms.fake_llm import FakeLLM
try:
from gptcache import Cache # noqa: F401
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
gptcache_installed = True
except ImportError:
gptcache_installed = False
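# init_gptcache_map uses a function attribute as a call counter so that each
# initialization gets a fresh on-disk data map file.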
def init_gptcache_map(cache_obj: Any) -> None:
i = getattr(init_gptcache_map, "_i", 0)
cache_path = f"data_map_{i}.txt"
if os.path.isfile(cache_path):
os.remove(cache_path)
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
init_gptcache_map._i = i + 1 # type: ignore
def init_gptcache_map_with_llm(cache_obj: Any, llm: str) -> None:
cache_path = f"data_map_{llm}.txt"
if os.path.isfile(cache_path):
os.remove(cache_path)
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
@pytest.mark.parametrize(
"init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
)
def test_gptcache_caching(
init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None],
) -> None:
"""Test gptcache default caching behavior."""
set_llm_cache(GPTCache(init_func))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
_ = llm.generate(["foo", "bar", "foo"])
cache_output = get_llm_cache().lookup("foo", llm_string)
assert cache_output == [Generation(text="fizz")]
get_llm_cache().clear()
assert get_llm_cache().lookup("bar", llm_string) is None

View File

@@ -1,97 +0,0 @@
"""Test Momento cache functionality.
To run tests, set the environment variable MOMENTO_API_KEY to a valid
Momento API key. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
from __future__ import annotations
import uuid
from datetime import timedelta
from typing import Iterator
import pytest
from langchain_community.cache import MomentoCache
from langchain_core.outputs import Generation, LLMResult
from langchain.globals import set_llm_cache
from tests.unit_tests.llms.fake_llm import FakeLLM
def random_string() -> str:
return str(uuid.uuid4())
@pytest.fixture(scope="module")
def momento_cache() -> Iterator[MomentoCache]:
from momento import CacheClient, Configurations, CredentialProvider
cache_name = f"langchain-test-cache-{random_string()}"
client = CacheClient(
Configurations.Laptop.v1(),
CredentialProvider.from_environment_variable("MOMENTO_API_KEY"),
default_ttl=timedelta(seconds=30),
)
try:
llm_cache = MomentoCache(client, cache_name)
set_llm_cache(llm_cache)
yield llm_cache
finally:
client.delete_cache(cache_name)
def test_invalid_ttl() -> None:
from momento import CacheClient, Configurations, CredentialProvider
client = CacheClient(
Configurations.Laptop.v1(),
CredentialProvider.from_environment_variable("MOMENTO_API_KEY"),
default_ttl=timedelta(seconds=30),
)
with pytest.raises(ValueError):
MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1))
def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
llm = FakeLLM()
stub_llm_output = LLMResult(generations=[[Generation(text="foo")]])
assert llm.generate([random_string()]) == stub_llm_output
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
)
def test_momento_cache_hit(
momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
) -> None:
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
momento_cache.update(prompt_i, llm_string, llm_generations_i)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)

View File

@@ -1,59 +0,0 @@
from langchain_community.cache import OpenSearchSemanticCache
from langchain_core.outputs import Generation
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
FakeEmbeddings,
)
from tests.unit_tests.llms.fake_llm import FakeLLM
DEFAULT_OPENSEARCH_URL = "http://localhost:9200"
def test_opensearch_semantic_cache() -> None:
"""Test opensearch semantic cache functionality."""
set_llm_cache(
OpenSearchSemanticCache(
embedding=FakeEmbeddings(),
opensearch_url=DEFAULT_OPENSEARCH_URL,
score_threshold=0.0,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz")]
get_llm_cache().clear(llm_string=llm_string)
output = get_llm_cache().lookup("bar", llm_string)
assert output != [Generation(text="fizz")]
def test_opensearch_semantic_cache_multi() -> None:
set_llm_cache(
OpenSearchSemanticCache(
embedding=FakeEmbeddings(),
opensearch_url=DEFAULT_OPENSEARCH_URL,
score_threshold=0.0,
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
# foo and bar will have the same embedding produced by FakeEmbeddings
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]
# clear the cache
get_llm_cache().clear(llm_string=llm_string)
output = get_llm_cache().lookup("bar", llm_string)
assert output != [Generation(text="fizz"), Generation(text="Buzz")]

View File

@@ -1,319 +0,0 @@
"""Test Redis cache functionality."""
import uuid
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator, List, Optional, cast
import pytest
from langchain_community.cache import AsyncRedisCache, RedisCache, RedisSemanticCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
# Using a non-standard port to avoid conflicts with any locally running
# Redis instances.
# You can spin up a local redis using docker compose
# cd [repository-root]/docker
# docker-compose up redis
REDIS_TEST_URL = "redis://localhost:6020"
def random_string() -> str:
return str(uuid.uuid4())
@contextmanager
def get_sync_redis(*, ttl: Optional[int] = 1) -> Generator[RedisCache, None, None]:
"""Get a sync RedisCache instance."""
import redis
cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=ttl)
try:
yield cache
finally:
cache.clear()
@asynccontextmanager
async def get_async_redis(
*, ttl: Optional[int] = 1
) -> AsyncGenerator[AsyncRedisCache, None]:
"""Get an async RedisCache instance."""
from redis.asyncio import Redis
cache = AsyncRedisCache(redis_=Redis.from_url(REDIS_TEST_URL), ttl=ttl)
try:
yield cache
finally:
await cache.aclear()
def test_redis_cache_ttl() -> None:
from redis import Redis
with get_sync_redis() as llm_cache:
set_llm_cache(llm_cache)
llm_cache.update("foo", "bar", [Generation(text="fizz")])
key = llm_cache._key("foo", "bar")
assert isinstance(llm_cache.redis, Redis)
assert llm_cache.redis.pttl(key) > 0
async def test_async_redis_cache_ttl() -> None:
from redis.asyncio import Redis as AsyncRedis
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm_cache = cast(RedisCache, get_llm_cache())
await llm_cache.aupdate("foo", "bar", [Generation(text="fizz")])
key = llm_cache._key("foo", "bar")
assert isinstance(llm_cache.redis, AsyncRedis)
assert await llm_cache.redis.pttl(key) > 0
def test_sync_redis_cache() -> None:
with get_sync_redis() as llm_cache:
set_llm_cache(llm_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache.update("prompt", llm_string, [Generation(text="fizz0")])
output = llm.generate(["prompt"])
expected_output = LLMResult(
generations=[[Generation(text="fizz0")]],
llm_output={},
)
assert output == expected_output
async def test_sync_in_async_redis_cache() -> None:
"""Test the sync RedisCache invoked with async methods"""
with get_sync_redis() as llm_cache:
set_llm_cache(llm_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
# llm_cache.update("meow", llm_string, [Generation(text="meow")])
await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz1")])
output = await llm.agenerate(["prompt"])
expected_output = LLMResult(
generations=[[Generation(text="fizz1")]],
llm_output={},
)
assert output == expected_output
async def test_async_redis_cache() -> None:
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache = cast(RedisCache, get_llm_cache())
await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz2")])
output = await llm.agenerate(["prompt"])
expected_output = LLMResult(
generations=[[Generation(text="fizz2")]],
llm_output={},
)
assert output == expected_output
async def test_async_in_sync_redis_cache() -> None:
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache = cast(RedisCache, get_llm_cache())
with pytest.raises(NotImplementedError):
llm_cache.update("foo", llm_string, [Generation(text="fizz")])
def test_redis_cache_chat() -> None:
with get_sync_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
llm_cache = cast(RedisCache, get_llm_cache())
llm_cache.update(
dumps(prompt),
llm_string,
[ChatGeneration(message=AIMessage(content="fizz"))],
)
output = llm.generate([prompt])
expected_output = LLMResult(
generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
llm_output={},
)
assert output == expected_output
async def test_async_redis_cache_chat() -> None:
async with get_async_redis() as redis_cache:
set_llm_cache(redis_cache)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
llm_cache = cast(RedisCache, get_llm_cache())
await llm_cache.aupdate(
dumps(prompt),
llm_string,
[ChatGeneration(message=AIMessage(content="fizz"))],
)
output = await llm.agenerate([prompt])
expected_output = LLMResult(
generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
llm_output={},
)
assert output == expected_output
def test_redis_semantic_cache() -> None:
"""Test redis semantic cache functionality."""
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache = cast(RedisSemanticCache, get_llm_cache())
llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
llm_cache.clear(llm_string=llm_string)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
# expect different output now without cached result
assert output != expected_output
llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_multi() -> None:
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_cache = cast(RedisSemanticCache, get_llm_cache())
llm_cache.update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_chat() -> None:
set_llm_cache(
RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
prompt: List[BaseMessage] = [HumanMessage(content="foo")]
llm_cache = cast(RedisSemanticCache, get_llm_cache())
llm_cache.update(
dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
)
output = llm.generate([prompt])
expected_output = LLMResult(
generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
llm_output={},
)
assert output == expected_output
llm_cache.clear(llm_string=llm_string)
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
ids=[
"single_prompt_single_generation",
"single_prompt_multiple_generations",
"single_prompt_multiple_generations",
"multiple_prompts_multiple_generations",
],
)
def test_redis_semantic_cache_hit(
embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
set_llm_cache(RedisSemanticCache(embedding=embedding, redis_url=REDIS_TEST_URL))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
llm_cache = cast(RedisSemanticCache, get_llm_cache())
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
print(prompt_i) # noqa: T201
print(llm_generations_i) # noqa: T201
llm_cache.update(prompt_i, llm_string, llm_generations_i)
llm.generate(prompts)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)

View File

@@ -1,90 +0,0 @@
"""Test Upstash Redis cache functionality."""
import uuid
import pytest
from langchain_community.cache import UpstashRedisCache
from langchain_core.outputs import Generation, LLMResult
import langchain
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
URL = "<UPSTASH_REDIS_REST_URL>"
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"
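# NOTE: replace the placeholder URL and TOKEN above with real Upstash REST
# credentials before running these tests.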
def random_string() -> str:
return str(uuid.uuid4())
@pytest.mark.requires("upstash_redis")
def test_redis_cache_ttl() -> None:
from upstash_redis import Redis
langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
key = langchain.llm_cache._key("foo", "bar")
assert langchain.llm_cache.redis.pttl(key) > 0
@pytest.mark.requires("upstash_redis")
def test_redis_cache() -> None:
from upstash_redis import Redis
langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
lookup_output = langchain.llm_cache.lookup("foo", llm_string)
if lookup_output and len(lookup_output) > 0:
assert lookup_output == expected_output.generations[0]
langchain.llm_cache.clear()
output = llm.generate(["foo"])
assert output != expected_output
langchain.llm_cache.redis.flushall()
@pytest.mark.requires("upstash_redis")
def test_redis_cache_multi() -> None:
from upstash_redis import Redis
langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
    output = llm.generate(["foo"])
expected_output = LLMResult(
generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear()
@pytest.mark.requires("upstash_redis")
def test_redis_cache_chat() -> None:
from upstash_redis import Redis
langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm.invoke("foo")
langchain.llm_cache.redis.flushall()

View File

@@ -1,18 +0,0 @@
"""Integration test for Dall-E image generator agent."""
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_community.llms import OpenAI
from langchain.agents import AgentType, initialize_agent
def test_call() -> None:
"""Test that the agent runs and returns output."""
llm = OpenAI(temperature=0.9)
tools = load_tools(["dalle-image-generator"])
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
output = agent.run("Create an image of a volcano island")
assert output is not None

View File

@@ -1,337 +0,0 @@
"""Test Graph Database Chain."""
import os
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_community.llms.openai import OpenAI
from langchain.chains.loading import load_chain
def test_connect_neo4j() -> None:
"""Test that Neo4j database is correctly instantiated and connected."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
output = graph.query(
"""
RETURN "test" AS output
"""
)
expected_output = [{"output": "test"}]
assert output == expected_output
def test_connect_neo4j_env() -> None:
"""Test that Neo4j database environment variables."""
graph = Neo4jGraph()
output = graph.query(
"""
RETURN "test" AS output
"""
)
expected_output = [{"output": "test"}]
assert output == expected_output
def test_cypher_generating_run() -> None:
"""Test that Cypher statement is correctly generated and executed."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(OpenAI(temperature=0), graph=graph)
output = chain.run("Who played in Pulp Fiction?")
expected_output = " Bruce Willis played in Pulp Fiction."
assert output == expected_output
def test_cypher_top_k() -> None:
"""Test top_k parameter correctly limits the number of results in the context."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
TOP_K = 1
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
"<-[:ACTED_IN]-(:Actor {name:'Foo'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_direct=True, top_k=TOP_K
)
output = chain.run("Who played in Pulp Fiction?")
assert len(output) == TOP_K
def test_cypher_intermediate_steps() -> None:
"""Test the returning of the intermediate steps."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_intermediate_steps=True
)
output = chain("Who played in Pulp Fiction?")
expected_output = " Bruce Willis played in Pulp Fiction."
assert output["result"] == expected_output
query = output["intermediate_steps"][0]["query"]
# LLM can return variations of the same query
expected_queries = [
(
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
"(m:Movie {title: 'Pulp Fiction'}) RETURN a.name"
),
(
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
"(m:Movie {title: 'Pulp Fiction'}) RETURN a.name;"
),
(
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
"(m:Movie) WHERE m.title = 'Pulp Fiction' RETURN a.name"
),
]
assert query in expected_queries
context = output["intermediate_steps"][1]["context"]
expected_context = [{"a.name": "Bruce Willis"}]
assert context == expected_context
def test_cypher_return_direct() -> None:
"""Test that chain returns direct results."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_direct=True
)
output = chain.run("Who played in Pulp Fiction?")
expected_output = [{"a.name": "Bruce Willis"}]
assert output == expected_output
def test_cypher_save_load() -> None:
"""Test saving and loading."""
FILE_PATH = "cypher.yaml"
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_direct=True
)
chain.save(file_path=FILE_PATH)
qa_loaded = load_chain(FILE_PATH, graph=graph)
assert qa_loaded == chain
def test_exclude_types() -> None:
"""Test exclude types from schema."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
"<-[:DIRECTED]-(p:Person {name:'John'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, exclude_types=["Person", "DIRECTED"]
)
expected_schema = (
"Node properties are the following:\n"
"Movie {title: STRING},Actor {name: STRING}\n"
"Relationship properties are the following:\n\n"
"The relationships are the following:\n"
"(:Actor)-[:ACTED_IN]->(:Movie)"
)
assert chain.graph_schema == expected_schema
def test_include_types() -> None:
"""Test include types from schema."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
"<-[:DIRECTED]-(p:Person {name:'John'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, include_types=["Movie", "Actor", "ACTED_IN"]
)
expected_schema = (
"Node properties are the following:\n"
"Movie {title: STRING},Actor {name: STRING}\n"
"Relationship properties are the following:\n\n"
"The relationships are the following:\n"
"(:Actor)-[:ACTED_IN]->(:Movie)"
)
assert chain.graph_schema == expected_schema
def test_include_types2() -> None:
"""Test include types from schema."""
url = os.environ.get("NEO4J_URI")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
# Delete all nodes in the graph
graph.query("MATCH (n) DETACH DELETE n")
# Create two nodes and a relationship
graph.query(
"CREATE (a:Actor {name:'Bruce Willis'})"
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
"<-[:DIRECTED]-(p:Person {name:'John'})"
)
# Refresh schema information
graph.refresh_schema()
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, include_types=["Movie", "ACTED_IN"]
)
expected_schema = (
"Node properties are the following:\n"
"Movie {title: STRING}\n"
"Relationship properties are the following:\n\n"
"The relationships are the following:\n"
)
assert chain.graph_schema == expected_schema
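# Hypothetical helper (not in the original file): the credential lookup and
# connection boilerplate repeated by every test above could be factored out:
def _connected_graph() -> Neo4jGraph:
    url = os.environ.get("NEO4J_URI")
    username = os.environ.get("NEO4J_USERNAME")
    password = os.environ.get("NEO4J_PASSWORD")
    assert url is not None and username is not None and password is not None
    return Neo4jGraph(url=url, username=username, password=password)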

View File

@@ -1,93 +0,0 @@
"""Test Graph Database Chain."""
from typing import Any
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_community.graphs import ArangoGraph
from langchain_community.graphs.arangodb_graph import get_arangodb_client
from langchain_community.llms.openai import OpenAI
def populate_arangodb_database(db: Any) -> None:
if db.has_graph("GameOfThrones"):
return
db.create_graph(
"GameOfThrones",
edge_definitions=[
{
"edge_collection": "ChildOf",
"from_vertex_collections": ["Characters"],
"to_vertex_collections": ["Characters"],
},
],
)
documents = [
{
"_key": "NedStark",
"name": "Ned",
"surname": "Stark",
"alive": True,
"age": 41,
"gender": "male",
},
{
"_key": "AryaStark",
"name": "Arya",
"surname": "Stark",
"alive": True,
"age": 11,
"gender": "female",
},
]
edges = [{"_to": "Characters/NedStark", "_from": "Characters/AryaStark"}]
db.collection("Characters").import_bulk(documents)
db.collection("ChildOf").import_bulk(edges)
def test_connect_arangodb() -> None:
"""Test that the ArangoDB database is correctly instantiated and connected."""
graph = ArangoGraph(get_arangodb_client())
sample_aql_result = graph.query("RETURN 'hello world'")
assert ["hello_world"] == sample_aql_result
def test_empty_schema_on_no_data() -> None:
"""Test that the schema is empty for an empty ArangoDB Database"""
db = get_arangodb_client()
db.delete_graph("GameOfThrones", drop_collections=True, ignore_missing=True)
db.delete_collection("empty_collection", ignore_missing=True)
db.create_collection("empty_collection")
graph = ArangoGraph(db)
assert graph.schema == {
"Graph Schema": [],
"Collection Schema": [],
}
def test_aql_generation() -> None:
"""Test that AQL statement is correctly generated and executed."""
db = get_arangodb_client()
populate_arangodb_database(db)
graph = ArangoGraph(db)
chain = ArangoGraphQAChain.from_llm(OpenAI(temperature=0), graph=graph)
chain.return_aql_result = True
output = chain("Is Ned Stark alive?")
assert output["aql_result"] == [True]
assert "Yes" in output["result"]
output = chain("How old is Arya Stark?")
assert output["aql_result"] == [11]
assert "11" in output["result"]
output = chain("What is the relationship between Arya Stark and Ned Stark?")
assert len(output["aql_result"]) == 1
assert "child of" in output["result"]

View File

@@ -1,268 +0,0 @@
"""Test RDF/ SPARQL Graph Database Chain."""
import pathlib
import re
from unittest.mock import MagicMock, Mock
from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain
from langchain_community.graphs import RdfGraph
from langchain.chains import LLMChain
"""
cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb
./start.sh
"""
def test_connect_file_rdf() -> None:
"""
Test loading an online resource.
"""
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
graph = RdfGraph(
source_file=berners_lee_card,
standard="rdf",
)
query = """SELECT ?s ?p ?o\n""" """WHERE { ?s ?p ?o }"""
output = graph.query(query)
assert len(output) == 86
def test_sparql_select() -> None:
"""
Test generating and executing a simple SPARQL SELECT query.
"""
from langchain_openai import ChatOpenAI
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
graph = RdfGraph(
source_file=berners_lee_card,
standard="rdf",
)
question = "What is Tim Berners-Lee's work homepage?"
answer = "Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/."
chain = GraphSparqlQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
)
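# Mock every sub-chain so the test exercises the chain's routing without a real LLM.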
chain.sparql_intent_chain = Mock(LLMChain)
chain.sparql_generation_select_chain = Mock(LLMChain)
chain.sparql_generation_update_chain = Mock(LLMChain)
chain.sparql_intent_chain.run = Mock(return_value="SELECT")
chain.sparql_generation_select_chain.run = Mock(
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?workHomepage
WHERE {
?person rdfs:label "Tim Berners-Lee" .
?person foaf:workplaceHomepage ?workHomepage .
}"""
)
chain.qa_chain = MagicMock(
return_value={
"text": answer,
"prompt": question,
"context": [],
}
)
chain.qa_chain.output_key = "text"
output = chain.invoke({chain.input_key: question})[chain.output_key]
assert output == answer
assert chain.sparql_intent_chain.run.call_count == 1
assert chain.sparql_generation_select_chain.run.call_count == 1
assert chain.sparql_generation_update_chain.run.call_count == 0
assert chain.qa_chain.call_count == 1
def test_sparql_insert(tmp_path: pathlib.Path) -> None:
"""
Test generating and executing a simple SPARQL INSERT query.
"""
from langchain_openai import ChatOpenAI
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
local_copy = tmp_path / "test.ttl"
graph = RdfGraph(
source_file=berners_lee_card,
standard="rdf",
local_copy=str(local_copy),
)
query = (
"Save that the person with the name 'Timothy Berners-Lee' "
"has a work homepage at 'http://www.w3.org/foo/bar/'"
)
chain = GraphSparqlQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
)
chain.sparql_intent_chain = Mock(LLMChain)
chain.sparql_generation_select_chain = Mock(LLMChain)
chain.sparql_generation_update_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_intent_chain.run = Mock(return_value="UPDATE")
chain.sparql_generation_update_chain.run = Mock(
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
INSERT {
?p foaf:workplaceHomepage <http://www.w3.org/foo/bar/> .
}
WHERE {
?p foaf:name "Timothy Berners-Lee" .
}"""
)
output = chain.invoke({chain.input_key: query})[chain.output_key]
assert output == "Successfully inserted triples into the graph."
assert chain.sparql_intent_chain.run.call_count == 1
assert chain.sparql_generation_select_chain.run.call_count == 0
assert chain.sparql_generation_update_chain.run.call_count == 1
assert chain.qa_chain.call_count == 0
query = (
"""PREFIX foaf: <http://xmlns.com/foaf/0.1/>\n"""
"""SELECT ?hp\n"""
"""WHERE {\n"""
""" ?person foaf:name "Timothy Berners-Lee" . \n"""
""" ?person foaf:workplaceHomepage ?hp .\n"""
"""}"""
)
output = graph.query(query)
assert len(output) == 2
def test_sparql_select_return_query() -> None:
"""
Test generating and executing a simple SPARQL SELECT query
and returning the generated SPARQL query.
"""
from langchain_openai import ChatOpenAI
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
graph = RdfGraph(
source_file=berners_lee_card,
standard="rdf",
)
question = "What is Tim Berners-Lee's work homepage?"
answer = "Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/."
chain = GraphSparqlQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
return_sparql_query=True,
)
chain.sparql_intent_chain = Mock(LLMChain)
chain.sparql_generation_select_chain = Mock(LLMChain)
chain.sparql_generation_update_chain = Mock(LLMChain)
chain.sparql_intent_chain.run = Mock(return_value="SELECT")
chain.sparql_generation_select_chain.run = Mock(
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?workHomepage
WHERE {
?person rdfs:label "Tim Berners-Lee" .
?person foaf:workplaceHomepage ?workHomepage .
}"""
)
chain.qa_chain = MagicMock(
return_value={
"text": answer,
"prompt": question,
"context": [],
}
)
chain.qa_chain.output_key = "text"
output = chain.invoke({chain.input_key: question})
assert output[chain.output_key] == answer
assert "sparql_query" in output
assert chain.sparql_intent_chain.run.call_count == 1
assert chain.sparql_generation_select_chain.run.call_count == 1
assert chain.sparql_generation_update_chain.run.call_count == 0
assert chain.qa_chain.call_count == 1
def test_loading_schema_from_ontotext_graphdb() -> None:
graph = RdfGraph(
query_endpoint="http://localhost:7200/repositories/langchain",
graph_kwargs={"bind_namespaces": "none"},
)
schema = graph.get_schema
prefix = (
"In the following, each IRI is followed by the local name and "
"optionally its description in parentheses. \n"
"The RDF graph supports the following node types:"
)
assert schema.startswith(prefix)
infix = "The RDF graph supports the following relationships:"
assert infix in schema
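# Schema entries are formatted as "<IRI> (local name)"; count the classes and relationships.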
classes = schema[len(prefix) : schema.index(infix)]
assert len(re.findall("<[^>]+> \\([^)]+\\)", classes)) == 5
relationships = schema[schema.index(infix) + len(infix) :]
assert len(re.findall("<[^>]+> \\([^)]+\\)", relationships)) == 58
def test_graph_qa_chain_with_ontotext_graphdb() -> None:
from langchain_openai import ChatOpenAI
question = "What is Tim Berners-Lee's work homepage?"
answer = "Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/."
graph = RdfGraph(
query_endpoint="http://localhost:7200/repositories/langchain",
graph_kwargs={"bind_namespaces": "none"},
)
chain = GraphSparqlQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
)
chain.sparql_intent_chain = Mock(LLMChain)
chain.sparql_generation_select_chain = Mock(LLMChain)
chain.sparql_generation_update_chain = Mock(LLMChain)
chain.sparql_intent_chain.run = Mock(return_value="SELECT")
chain.sparql_generation_select_chain.run = Mock(
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?workHomepage
WHERE {
?person rdfs:label "Tim Berners-Lee" .
?person foaf:workplaceHomepage ?workHomepage .
}"""
)
chain.qa_chain = MagicMock(
return_value={
"text": answer,
"prompt": question,
"context": [],
}
)
chain.qa_chain.output_key = "text"
output = chain.invoke({chain.input_key: question})[chain.output_key]
assert output == answer
assert chain.sparql_intent_chain.run.call_count == 1
assert chain.sparql_generation_select_chain.run.call_count == 1
assert chain.sparql_generation_update_chain.run.call_count == 0
assert chain.qa_chain.call_count == 1

View File

@@ -1,385 +0,0 @@
from unittest.mock import MagicMock, Mock
import pytest
from langchain_community.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain
from langchain_community.graphs import OntotextGraphDBGraph
from langchain.chains import LLMChain
"""
cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb
./start.sh
"""
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [-2, -1, 0, 1, 2])
def test_valid_sparql(max_fix_retries: int) -> None:
from langchain_openai import ChatOpenAI
question = "What is Luke Skywalker's home planet?"
answer = "Tatooine"
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
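# Stub out the generation, fix, and QA sub-chains to isolate the retry logic.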
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": "SELECT * {?s ?p ?o} LIMIT 1",
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = MagicMock()
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock(
return_value={
"text": answer,
"prompt": question,
"context": [],
}
)
result = chain.invoke({chain.input_key: question})
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == 0
assert chain.qa_chain.invoke.call_count == 1
assert result == {chain.output_key: answer, chain.input_key: question}
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [-2, -1, 0])
def test_invalid_sparql_non_positive_max_fix_retries(
max_fix_retries: int,
) -> None:
from langchain_openai import ChatOpenAI
question = "What is Luke Skywalker's home planet?"
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": "```sparql SELECT * {?s ?p ?o} LIMIT 1```",
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = MagicMock()
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock()
with pytest.raises(ValueError) as e:
chain.invoke({chain.input_key: question})
assert str(e.value) == "The generated SPARQL query is invalid."
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == 0
assert chain.qa_chain.invoke.call_count == 0
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
def test_valid_sparql_after_first_retry(max_fix_retries: int) -> None:
from langchain_openai import ChatOpenAI
question = "What is Luke Skywalker's home planet?"
answer = "Tatooine"
generated_invalid_sparql = "```sparql SELECT * {?s ?p ?o} LIMIT 1```"
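# The markdown code fence makes the query unparsable, forcing at least one fix attempt.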
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": generated_invalid_sparql,
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = MagicMock(
return_value={
"text": "SELECT * {?s ?p ?o} LIMIT 1",
"error_message": "pyparsing.exceptions.ParseException: "
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
"found '`' (at char 0), (line:1, col:1)",
"generated_sparql": generated_invalid_sparql,
"schema": "",
}
)
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock(
return_value={
"text": answer,
"prompt": question,
"context": [],
}
)
result = chain.invoke({chain.input_key: question})
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == 1
assert chain.qa_chain.invoke.call_count == 1
assert result == {chain.output_key: answer, chain.input_key: question}
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
def test_invalid_sparql_server_response_400(max_fix_retries: int) -> None:
from langchain_openai import ChatOpenAI
question = "Who is the oldest character?"
generated_invalid_sparql = (
"PREFIX : <https://swapi.co/vocabulary/> "
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
"WHERE {"
" ?species a :Species ;"
" :character ?character ;"
" :averageLifespan ?lifespan ."
" FILTER(xsd:integer(?lifespan))"
"} "
"ORDER BY DESC(?maxLifespan) "
"LIMIT 1"
)
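# Parses as SPARQL but projects a non-aggregated variable without GROUP BY, so the endpoint rejects it with HTTP 400.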
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": generated_invalid_sparql,
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = MagicMock()
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock()
with pytest.raises(ValueError) as e:
chain.invoke({chain.input_key: question})
assert str(e.value) == "Failed to execute the generated SPARQL query."
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == 0
assert chain.qa_chain.invoke.call_count == 0
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
def test_invalid_sparql_after_all_retries(max_fix_retries: int) -> None:
from langchain_openai import ChatOpenAI
question = "What is Luke Skywalker's home planet?"
generated_invalid_sparql = "```sparql SELECT * {?s ?p ?o} LIMIT 1```"
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": generated_invalid_sparql,
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = MagicMock(
return_value={
"text": generated_invalid_sparql,
"error_message": "pyparsing.exceptions.ParseException: "
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
"found '`' (at char 0), (line:1, col:1)",
"generated_sparql": generated_invalid_sparql,
"schema": "",
}
)
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock()
with pytest.raises(ValueError) as e:
chain.invoke({chain.input_key: question})
assert str(e.value) == "The generated SPARQL query is invalid."
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == max_fix_retries
assert chain.qa_chain.invoke.call_count == 0
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize(
"max_fix_retries,number_of_invalid_responses",
[(1, 0), (2, 0), (2, 1), (10, 6)],
)
def test_valid_sparql_after_some_retries(
max_fix_retries: int, number_of_invalid_responses: int
) -> None:
from langchain_openai import ChatOpenAI
question = "What is Luke Skywalker's home planet?"
answer = "Tatooine"
generated_invalid_sparql = "```sparql SELECT * {?s ?p ?o} LIMIT 1```"
generated_valid_sparql_query = "SELECT * {?s ?p ?o} LIMIT 1"
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": generated_invalid_sparql,
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = Mock()
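# The fix chain returns the invalid query number_of_invalid_responses times, then a valid one.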
chain.sparql_fix_chain.invoke.side_effect = [
{
"text": generated_invalid_sparql,
"error_message": "pyparsing.exceptions.ParseException: "
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
"found '`' (at char 0), (line:1, col:1)",
"generated_sparql": generated_invalid_sparql,
"schema": "",
}
] * number_of_invalid_responses + [
{
"text": generated_valid_sparql_query,
"error_message": "pyparsing.exceptions.ParseException: "
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
"found '`' (at char 0), (line:1, col:1)",
"generated_sparql": generated_invalid_sparql,
"schema": "",
}
]
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock(
return_value={
"text": answer,
"prompt": question,
"context": [],
}
)
result = chain.invoke({chain.input_key: question})
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == number_of_invalid_responses + 1
assert chain.qa_chain.invoke.call_count == 1
assert result == {chain.output_key: answer, chain.input_key: question}
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize(
"model_name,question",
[
("gpt-3.5-turbo-1106", "What is the average height of the Wookiees?"),
("gpt-3.5-turbo-1106", "What is the climate on Tatooine?"),
("gpt-3.5-turbo-1106", "What is Luke Skywalker's home planet?"),
("gpt-4-1106-preview", "What is the average height of the Wookiees?"),
("gpt-4-1106-preview", "What is the climate on Tatooine?"),
("gpt-4-1106-preview", "What is Luke Skywalker's home planet?"),
],
)
def test_chain(model_name: str, question: str) -> None:
from langchain_openai import ChatOpenAI
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
ChatOpenAI(temperature=0, model_name=model_name),
graph=graph,
verbose=True, # type: ignore[call-arg]
)
try:
chain.invoke({chain.input_key: question})
except ValueError:
pass

View File

@@ -1,19 +0,0 @@
"""Integration test for self ask with search."""
from langchain_community.docstore import Wikipedia
from langchain_community.llms.openai import OpenAI
from langchain.agents.react.base import ReActChain
def test_react() -> None:
"""Test functionality on a prompt."""
llm = OpenAI(temperature=0, model_name="gpt-3.5-turbo-instruct") # type: ignore[call-arg]
react = ReActChain(llm=llm, docstore=Wikipedia())
question = (
"Author David Chanoff has collaborated with a U.S. Navy admiral "
"who served as the ambassador to the United Kingdom under "
"which President?"
)
output = react.run(question)
assert output == "Bill Clinton"

View File

@@ -1,29 +0,0 @@
"""Test RetrievalQA functionality."""
from pathlib import Path
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.llms import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_text_splitters.character import CharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.chains.loading import load_chain
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
"""Test saving and loading."""
loader = TextLoader("docs/extras/modules/state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
docsearch = FAISS.from_documents(texts, embeddings)
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
qa.run("What did the president say about Ketanji Brown Jackson?")
file_path = tmp_path / "RetrievalQA_chain.yaml"
qa.save(file_path=file_path)
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
assert qa_loaded == qa

View File

@@ -1,39 +0,0 @@
"""Test RetrievalQA functionality."""
from pathlib import Path
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.llms import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_text_splitters.character import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.loading import load_chain
def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: Path) -> None:
"""Test saving and loading."""
loader = DirectoryLoader("docs/extras/modules/", glob="*.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
docsearch = FAISS.from_documents(texts, embeddings)
qa = RetrievalQAWithSourcesChain.from_llm(
llm=OpenAI(), retriever=docsearch.as_retriever()
)
result = qa("What did the president say about Ketanji Brown Jackson?")
assert "question" in result.keys()
assert "answer" in result.keys()
assert "sources" in result.keys()
file_path = str(tmp_path / "RetrievalQAWithSourcesChain.yaml")
qa.save(file_path=file_path)
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
assert qa_loaded == qa
qa2 = RetrievalQAWithSourcesChain.from_chain_type(
llm=OpenAI(), retriever=docsearch.as_retriever(), chain_type="stuff"
)
result2 = qa2("What did the president say about Ketanji Brown Jackson?")
assert "question" in result2.keys()
assert "answer" in result2.keys()
assert "sources" in result2.keys()

View File

@@ -1,19 +0,0 @@
"""Integration test for self ask with search."""
from langchain_community.llms.openai import OpenAI
from langchain_community.utilities.searchapi import SearchApiAPIWrapper
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
def test_self_ask_with_search() -> None:
"""Test functionality on a prompt."""
question = "What is the hometown of the reigning men's U.S. Open champion?"
chain = SelfAskWithSearchChain(
llm=OpenAI(temperature=0),
search_chain=SearchApiAPIWrapper(),
input_key="q",
output_key="a",
)
answer = chain.run(question)
final_answer = answer.split("\n")[-1]
assert final_answer == "Belgrade, Serbia"

View File

@@ -1,202 +0,0 @@
import os
from typing import AsyncIterable, Iterable
import pytest
from langchain_community.chat_message_histories.astradb import (
AstraDBChatMessageHistory,
)
from langchain_community.utilities.astradb import SetupMode
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture(scope="function")
def history1() -> Iterable[AstraDBChatMessageHistory]:
history1 = AstraDBChatMessageHistory(
session_id="session-test-1",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield history1
history1.collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
def history2() -> Iterable[AstraDBChatMessageHistory]:
history2 = AstraDBChatMessageHistory(
session_id="session-test-2",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield history2
history2.collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
history1 = AstraDBChatMessageHistory(
session_id="async-session-test-1",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield history1
await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
history2 = AstraDBChatMessageHistory(
session_id="async-session-test-2",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield history2
await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
"""Test the memory with a message store."""
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=history1,
return_messages=True,
)
assert memory.chat_memory.messages == []
# add some messages
memory.chat_memory.add_messages(
[
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
)
messages = memory.chat_memory.messages
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
# clear the store
memory.chat_memory.clear()
assert memory.chat_memory.messages == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_with_message_store_async(
async_history1: AstraDBChatMessageHistory,
) -> None:
"""Test the memory with a message store."""
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=async_history1,
return_messages=True,
)
assert await memory.chat_memory.aget_messages() == []
# add some messages
await memory.chat_memory.aadd_messages(
[
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
)
messages = await memory.chat_memory.aget_messages()
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
# clear the store
await memory.chat_memory.aclear()
assert await memory.chat_memory.aget_messages() == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_separate_session_ids(
history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
) -> None:
"""Test that separate session IDs do not share entries."""
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=history1,
return_messages=True,
)
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=history2,
return_messages=True,
)
memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])
assert memory2.chat_memory.messages == []
memory2.chat_memory.clear()
assert memory1.chat_memory.messages != []
memory1.chat_memory.clear()
assert memory1.chat_memory.messages == []
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_separate_session_ids_async(
async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
) -> None:
"""Test that separate session IDs do not share entries."""
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=async_history1,
return_messages=True,
)
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=async_history2,
return_messages=True,
)
await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])
assert await memory2.chat_memory.aget_messages() == []
await memory2.chat_memory.aclear()
assert await memory1.chat_memory.aget_messages() != []
await memory1.chat_memory.aclear()
assert await memory1.chat_memory.aget_messages() == []

View File

@@ -1,116 +0,0 @@
import os
import time
from typing import Optional
from langchain_community.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
def _chat_message_history(
session_id: str = "test-session",
drop: bool = True,
ttl_seconds: Optional[int] = None,
) -> CassandraChatMessageHistory:
from cassandra.cluster import Cluster
keyspace = "cmh_test_keyspace"
table_name = "cmh_test_table"
# get db connection
if "CASSANDRA_CONTACT_POINTS" in os.environ:
contact_points = os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
cluster = Cluster(contact_points)
else:
cluster = Cluster()
#
session = cluster.connect()
# ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
)
)
# drop table if required
if drop:
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}")
#
return CassandraChatMessageHistory(
session_id=session_id,
session=session,
keyspace=keyspace,
table_name=table_name,
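# Pass ttl_seconds only when provided, so the class default applies otherwise.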
**({} if ttl_seconds is None else {"ttl_seconds": ttl_seconds}),
)
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup cassandra as a message store
message_history = _chat_message_history()
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=message_history,
return_messages=True,
)
assert memory.chat_memory.messages == []
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
messages = memory.chat_memory.messages
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
# clear the store
memory.chat_memory.clear()
assert memory.chat_memory.messages == []
def test_memory_separate_session_ids() -> None:
"""Test that separate session IDs do not share entries."""
message_history1 = _chat_message_history(session_id="test-session1")
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=message_history1,
return_messages=True,
)
message_history2 = _chat_message_history(session_id="test-session2")
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=message_history2,
return_messages=True,
)
memory1.chat_memory.add_ai_message("Just saying.")
assert memory2.chat_memory.messages == []
memory1.chat_memory.clear()
memory2.chat_memory.clear()
def test_memory_ttl() -> None:
"""Test time-to-live feature of the memory."""
message_history = _chat_message_history(ttl_seconds=5)
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=message_history,
return_messages=True,
)
#
assert memory.chat_memory.messages == []
memory.chat_memory.add_ai_message("Nothing special here.")
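# Still present within the 5-second TTL; expired after ~7 seconds total.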
time.sleep(2)
assert memory.chat_memory.messages != []
time.sleep(5)
assert memory.chat_memory.messages == []

View File

@@ -1,45 +0,0 @@
import json
import os
from langchain_community.chat_message_histories import CosmosDBChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
# Replace these with your Azure Cosmos DB endpoint and key
endpoint = os.environ.get("COSMOS_DB_ENDPOINT", "")
credential = os.environ.get("COSMOS_DB_KEY", "")
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup Azure Cosmos DB as a message store
message_history = CosmosDBChatMessageHistory(
cosmos_endpoint=endpoint,
cosmos_database="chat_history",
cosmos_container="messages",
credential=credential,
session_id="my-test-session",
user_id="my-test-user",
ttl=10,
)
message_history.prepare_cosmos()
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Azure Cosmos DB, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,91 +0,0 @@
import json
import os
import uuid
from typing import Generator, Union
import pytest
from langchain_community.chat_message_histories import ElasticsearchChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
"""
cd tests/integration_tests/memory/docker-compose
docker-compose -f elasticsearch.yml up
By default, tests run against a local Docker instance of Elasticsearch.
To run against Elastic Cloud, set the following environment variables:
- ES_CLOUD_ID
- ES_USERNAME
- ES_PASSWORD
"""
class TestElasticsearch:
@pytest.fixture(scope="class", autouse=True)
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]: # type: ignore[return]
# Run this integration test against Elasticsearch on localhost,
# or an Elastic Cloud instance
from elasticsearch import Elasticsearch
es_url = os.environ.get("ES_URL", "http://localhost:9200")
es_cloud_id = os.environ.get("ES_CLOUD_ID")
es_username = os.environ.get("ES_USERNAME", "elastic")
es_password = os.environ.get("ES_PASSWORD", "changeme")
if es_cloud_id:
es = Elasticsearch(
cloud_id=es_cloud_id,
basic_auth=(es_username, es_password),
)
yield {
"es_cloud_id": es_cloud_id,
"es_user": es_username,
"es_password": es_password,
}
else:
# Running this integration test with local docker instance
es = Elasticsearch(hosts=es_url)
yield {"es_url": es_url}
# Clear all indexes
index_names = es.indices.get(index="_all").keys()
for index_name in index_names:
if index_name.startswith("test_"):
es.indices.delete(index=index_name)
es.indices.refresh(index="_all")
@pytest.fixture(scope="function")
def index_name(self) -> str:
"""Return the index name."""
return f"test_{uuid.uuid4().hex}"
def test_memory_with_message_store(
self, elasticsearch_connection: dict, index_name: str
) -> None:
"""Test the memory with a message store."""
# setup Elasticsearch as a message store
message_history = ElasticsearchChatMessageHistory(
**elasticsearch_connection, index=index_name, session_id="test-session"
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Elasticsearch, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,44 +0,0 @@
import json
from langchain_community.chat_message_histories import FirestoreChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
message_history = FirestoreChatMessageHistory(
collection_name="chat_history",
session_id="my-test-session",
user_id="my-test-user",
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store
# and check if the messages are there as expected
message_history = FirestoreChatMessageHistory(
collection_name="chat_history",
session_id="my-test-session",
user_id="my-test-user",
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Firestore, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,71 +0,0 @@
"""Test Momento chat message history functionality.
To run tests, set the environment variable MOMENTO_API_KEY to a valid
Momento API key. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
import json
import uuid
from datetime import timedelta
from typing import Iterator
import pytest
from langchain_community.chat_message_histories import MomentoChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
def random_string() -> str:
return str(uuid.uuid4())
@pytest.fixture(scope="function")
def message_history() -> Iterator[MomentoChatMessageHistory]:
from momento import CacheClient, Configurations, CredentialProvider
cache_name = f"langchain-test-cache-{random_string()}"
client = CacheClient(
Configurations.Laptop.v1(),
CredentialProvider.from_environment_variable("MOMENTO_API_KEY"),
default_ttl=timedelta(seconds=30),
)
try:
chat_message_history = MomentoChatMessageHistory(
session_id="my-test-session",
cache_client=client,
cache_name=cache_name,
)
yield chat_message_history
finally:
client.delete_cache(cache_name)
def test_memory_empty_on_new_session(
message_history: MomentoChatMessageHistory,
) -> None:
memory = ConversationBufferMemory(
memory_key="foo", chat_memory=message_history, return_messages=True
)
assert memory.chat_memory.messages == []
def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None:
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# Add some messages to the memory store
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# Verify that the messages are in the store
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# Verify clearing the store
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,37 +0,0 @@
import json
import os
from langchain_community.chat_message_histories import MongoDBChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
# Replace this with your MongoDB connection string
connection_string = os.environ.get("MONGODB_CONNECTION_STRING", "")
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup MongoDB as a message store
message_history = MongoDBChatMessageHistory(
connection_string=connection_string, session_id="test-session"
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from MongoDB, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,31 +0,0 @@
import json
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup Neo4j as a message store
message_history = Neo4jChatMessageHistory(session_id="test-session")
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Neo4j, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,31 +0,0 @@
import json
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup Redis as a message store
message_history = RedisChatMessageHistory(
url="redis://localhost:6379/0", ttl=10, session_id="my-test-session"
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Redis, so the next test run won't pick it up
memory.chat_memory.clear()

View File

@@ -1,64 +0,0 @@
"""Tests RocksetChatMessageHistory by creating a collection
for message history, adding to it, and clearing it.
To run these tests, make sure you have the ROCKSET_API_KEY
and ROCKSET_REGION environment variables set.
"""
import json
import os
from langchain_community.chat_message_histories import RocksetChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
collection_name = "langchain_demo"
session_id = "MySession"
class TestRockset:
memory: RocksetChatMessageHistory
@classmethod
def setup_class(cls) -> None:
from rockset import DevRegions, Regions, RocksetClient
assert os.environ.get("ROCKSET_API_KEY") is not None
assert os.environ.get("ROCKSET_REGION") is not None
api_key = os.environ.get("ROCKSET_API_KEY")
region = os.environ.get("ROCKSET_REGION")
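# Map the region name to a Rockset API host; unrecognized values are used as a literal host.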
if region == "use1a1":
host = Regions.use1a1
elif region == "usw2a1" or not region:
host = Regions.usw2a1
elif region == "euc1a1":
host = Regions.euc1a1
elif region == "dev":
host = DevRegions.usw2a1
else:
host = region
client = RocksetClient(host, api_key)
cls.memory = RocksetChatMessageHistory(
session_id, client, collection_name, sync=True
)
def test_memory_with_message_store(self) -> None:
memory = ConversationBufferMemory(
memory_key="messages", chat_memory=self.memory, return_messages=True
)
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,37 +0,0 @@
import json
from langchain_community.chat_message_histories import SingleStoreDBChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
# Replace this with your SingleStoreDB connection URL
TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup SingleStoreDB as a message store
message_history = SingleStoreDBChatMessageHistory(
session_id="test-session",
host=TEST_SINGLESTOREDB_URL,
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from SingleStoreDB, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,38 +0,0 @@
import json
import pytest
from langchain_community.chat_message_histories.upstash_redis import (
UpstashRedisChatMessageHistory,
)
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
URL = "<UPSTASH_REDIS_REST_URL>"
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"
@pytest.mark.requires("upstash_redis")
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup Upstash Redis as a message store
message_history = UpstashRedisChatMessageHistory(
url=URL, token=TOKEN, ttl=10, session_id="my-test-session"
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Redis, so the next test run won't pick it up
memory.chat_memory.clear()

View File

@@ -1,42 +0,0 @@
"""Test Xata chat memory store functionality.
Before running this test, please create a Xata database.
"""
import json
import os
from langchain_community.chat_message_histories import XataChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain.memory import ConversationBufferMemory
class TestXata:
@classmethod
def setup_class(cls) -> None:
assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"
def test_xata_chat_memory(self) -> None:
message_history = XataChatMessageHistory(
api_key=os.getenv("XATA_API_KEY", ""),
db_url=os.getenv("XATA_DB_URL", ""),
session_id="integration-test-session",
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from Xata, so the next test run won't pick it up
memory.chat_memory.clear()

View File

@@ -1,72 +0,0 @@
"""Test functionality related to ngram overlap based selector."""
import pytest
from langchain_community.example_selectors import (
NGramOverlapExampleSelector,
ngram_overlap_score,
)
from langchain_core.prompts import PromptTemplate
EXAMPLES = [
{"input": "See Spot run.", "output": "foo1"},
{"input": "My dog barks.", "output": "foo2"},
{"input": "Spot can run.", "output": "foo3"},
]
@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
"""Get ngram overlap based selector to use in tests."""
prompts = PromptTemplate(
input_variables=["input", "output"], template="Input: {input}\nOutput: {output}"
)
selector = NGramOverlapExampleSelector(
examples=EXAMPLES,
example_prompt=prompts,
)
return selector
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
"""Test NGramOverlapExampleSelector can select examples."""
sentence = "Spot can run."
output = selector.select_examples({"input": sentence})
assert output == [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
"""Test NGramOverlapExampleSelector can add an example."""
new_example = {"input": "Spot plays fetch.", "output": "foo4"}
selector.add_example(new_example)
sentence = "Spot can run."
output = selector.select_examples({"input": sentence})
assert output == [EXAMPLES[2], EXAMPLES[0]] + [new_example] + [EXAMPLES[1]]
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
"""Tests NGramOverlapExampleSelector threshold set to 0.0."""
selector.threshold = 0.0
sentence = "Spot can run."
output = selector.select_examples({"input": sentence})
assert output == [EXAMPLES[2], EXAMPLES[0]]
def test_selector_threshold_more_than_one(
selector: NGramOverlapExampleSelector,
) -> None:
"""Tests NGramOverlapExampleSelector threshold greater than 1.0."""
selector.threshold = 1.0 + 1e-9
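# Overlap scores are capped at 1.0, so a threshold above 1.0 excludes every example.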
sentence = "Spot can run."
output = selector.select_examples({"input": sentence})
assert output == []
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
"""Tests that ngram_overlap_score returns correct values."""
selector.threshold = 1.0 + 1e-9
none = ngram_overlap_score(["Spot can run."], ["My dog barks."])
some = ngram_overlap_score(["Spot can run."], ["See Spot run."])
complete = ngram_overlap_score(["Spot can run."], ["Spot can run."])
check = [abs(none - 0.0) < 1e-9, 0.0 < some < 1.0, abs(complete - 1.0) < 1e-9]
assert check == [True, True, True]

View File

@@ -1,29 +0,0 @@
"""Integration test for compression pipelines."""
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters.character import CharacterTextSplitter
from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline,
EmbeddingsFilter,
)
def test_document_compressor_pipeline() -> None:
embeddings = OpenAIEmbeddings()
splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
pipeline_filter = DocumentCompressorPipeline(
transformers=[splitter, redundant_filter, relevant_filter]
)
texts = [
"This sentence is about cows",
"This sentence was about cows",
"foo bar baz",
]
docs = [Document(page_content=". ".join(texts))]
actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals")
assert len(actual) == 1
assert actual[0].page_content in texts[:2]

View File

@@ -1,45 +0,0 @@
"""Integration test for LLMChainExtractor."""
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import LLMChainExtractor
def test_llm_construction_with_kwargs() -> None:
llm_chain_kwargs = {"verbose": True}
compressor = LLMChainExtractor.from_llm(
ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
)
assert compressor.llm_chain.verbose is True
def test_llm_chain_extractor() -> None:
texts = [
"The Roman Empire followed the Roman Republic.",
"I love chocolate chip cookies—my mother makes great cookies.",
"The first Roman emperor was Caesar Augustus.",
"Don't you just love Caesar salad?",
"The Roman Empire collapsed in 476 AD after the fall of Rome.",
"Let's go to Olive Garden!",
]
doc = Document(page_content=" ".join(texts))
compressor = LLMChainExtractor.from_llm(ChatOpenAI())
actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")[
0
].page_content
expected_returned = [0, 2, 4]
expected_not_returned = [1, 3, 5]
assert all([texts[i] in actual for i in expected_returned])
assert all([texts[i] not in actual for i in expected_not_returned])
def test_llm_chain_extractor_empty() -> None:
texts = [
"I love chocolate chip cookies—my mother makes great cookies.",
"Don't you just love Caesar salad?",
"Let's go to Olive Garden!",
]
doc = Document(page_content=" ".join(texts))
compressor = LLMChainExtractor.from_llm(ChatOpenAI())
actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")
assert len(actual) == 0

View File

@@ -1,18 +0,0 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import LLMChainFilter
def test_llm_chain_filter() -> None:
texts = [
"What happened to all of my cookies?",
"I wish there were better Italian restaurants in my neighborhood.",
"My favorite color is green",
]
docs = [Document(page_content=t) for t in texts]
relevant_filter = LLMChainFilter.from_llm(llm=ChatOpenAI())
actual = relevant_filter.compress_documents(docs, "Things I said related to food")
assert len(actual) == 2
assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2

View File

@@ -1,43 +0,0 @@
"""Integration test for embedding-based relevant doc filtering."""
import numpy as np
from langchain_community.document_transformers.embeddings_redundant_filter import (
_DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import EmbeddingsFilter
def test_embeddings_filter() -> None:
texts = [
"What happened to all of my cookies?",
"I wish there were better Italian restaurants in my neighborhood.",
"My favorite color is green",
]
docs = [Document(page_content=t) for t in texts]
embeddings = OpenAIEmbeddings()
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
actual = relevant_filter.compress_documents(docs, "What did I say about food?")
assert len(actual) == 2
assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2
def test_embeddings_filter_with_state() -> None:
texts = [
"What happened to all of my cookies?",
"I wish there were better Italian restaurants in my neighborhood.",
"My favorite color is green",
]
query = "What did I say about food?"
embeddings = OpenAIEmbeddings()
embedded_query = embeddings.embed_query(query)
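# All docs get a zero embedding; only the last doc reuses the query embedding, so it alone clears the similarity threshold.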
state = {"embedded_doc": np.zeros(len(embedded_query))}
docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
docs[-1].state = {"embedded_doc": embedded_query}
relevant_filter = EmbeddingsFilter( # type: ignore[call-arg]
embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True
)
actual = relevant_filter.compress_documents(docs, query)
assert len(actual) == 1
assert texts[-1] == actual[0].page_content

View File

@@ -1,26 +0,0 @@
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
def test_contextual_compression_retriever_get_relevant_docs() -> None:
"""Test get_relevant_docs."""
texts = [
"This is a document about the Boston Celtics",
"The Boston Celtics won the game by 20 points",
"I simply love going to the movies",
]
embeddings = OpenAIEmbeddings()
base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
search_kwargs={"k": len(texts)}
)
retriever = ContextualCompressionRetriever(
base_compressor=base_compressor, base_retriever=base_retriever
)
actual = retriever.invoke("Tell me about the Celtics")
assert len(actual) == 2
assert texts[-1] not in [d.page_content for d in actual]

View File

@@ -1,33 +0,0 @@
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.retrievers.merger_retriever import MergerRetriever
def test_merger_retriever_get_relevant_docs() -> None:
"""Test get_relevant_docs."""
texts_group_a = [
"This is a document about the Boston Celtics",
"Fly me to the moon is one of my favourite songs."
"I simply love going to the movies",
]
texts_group_b = [
"This is a document about the Poenix Suns",
"The Boston Celtics won the game by 20 points",
"Real stupidity beats artificial intelligence every time. TP",
]
embeddings = OpenAIEmbeddings()
retriever_a = Chroma.from_texts(texts_group_a, embedding=embeddings).as_retriever(
search_kwargs={"k": 1}
)
retriever_b = Chroma.from_texts(texts_group_b, embedding=embeddings).as_retriever(
search_kwargs={"k": 1}
)
# The Lord of the Retrievers.
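    # MergerRetriever fans the query out to each child retriever and merges their results.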
lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])
actual = lotr.invoke("Tell me about the Celtics")
assert len(actual) == 2
assert texts_group_a[0] in [d.page_content for d in actual]
assert texts_group_b[1] in [d.page_content for d in actual]

View File

@@ -1,503 +0,0 @@
from typing import Iterator, List, Optional
from uuid import uuid4
import pytest
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.openai import OpenAI
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langsmith import Client as Client
from langsmith.evaluation import run_evaluator
from langsmith.schemas import DataType, Example, Run
from langchain.chains.llm import LLMChain
from langchain.evaluation import EvaluatorType
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith.evaluation import InputFormatError
from langchain.smith.evaluation.runner_utils import arun_on_dataset
def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
# Assert that all runs completed, all feedback completed, and that the
# chain or llm passes for the feedback provided.
runs = list(client.list_runs(project_name=_project_name, execution_order=1))
if not runs:
        # Queue delays; we are mainly just smoke-checking for now.
return
feedback = list(client.list_feedback(run_ids=[run.id for run in runs]))
if not feedback:
return
    assert all(bool(f.score) for f in feedback)
@run_evaluator
def not_empty(run: Run, example: Optional[Example] = None) -> dict:
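    # Score the run by the truthiness of its first output value.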
return {
"score": run.outputs and next(iter(run.outputs.values())),
"key": "not_empty",
}
@pytest.fixture
def eval_project_name() -> str:
return f"lcp integration tests - {str(uuid4())[-8:]}"
@pytest.fixture(scope="module")
def client() -> Client:
return Client()
@pytest.fixture(
scope="module",
)
def kv_dataset_name() -> Iterator[str]:
import pandas as pd
client = Client()
df = pd.DataFrame(
{
"some_input": [
"What's the capital of California?",
"What's the capital of Nevada?",
"What's the capital of Oregon?",
"What's the capital of Washington?",
],
"other_input": [
"a",
"b",
"c",
"d",
],
"some_output": ["Sacramento", "Carson City", "Salem", "Olympia"],
"other_output": ["e", "f", "g", "h"],
}
)
uid = str(uuid4())[-8:]
_dataset_name = f"lcp kv dataset integration tests - {uid}"
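    # Upload as a key-value dataset: multiple named input and output columns per example.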
client.upload_dataframe(
df,
name=_dataset_name,
input_keys=["some_input", "other_input"],
output_keys=["some_output", "other_output"],
description="Integration test dataset",
)
yield _dataset_name
def test_chat_model(
kv_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = ChatOpenAI(temperature=0)
eval_config = RunEvalConfig(
evaluators=[EvaluatorType.QA], custom_evaluators=[not_empty]
)
with pytest.raises(ValueError, match="Must specify reference_key"):
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
client=client,
)
eval_config = RunEvalConfig(
evaluators=[EvaluatorType.QA],
reference_key="some_output",
)
with pytest.raises(
InputFormatError, match="Example inputs do not match language model"
):
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
client=client,
)
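    # Chat models expect a list of messages, so map each example's raw inputs to a HumanMessage first.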
def input_mapper(d: dict) -> List[BaseMessage]:
return [HumanMessage(content=d["some_input"])]
run_on_dataset(
client=client,
dataset_name=kv_dataset_name,
llm_or_chain_factory=input_mapper | llm,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_llm(kv_dataset_name: str, eval_project_name: str, client: Client) -> None:
llm = OpenAI(temperature=0)
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA])
with pytest.raises(ValueError, match="Must specify reference_key"):
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
client=client,
)
eval_config = RunEvalConfig(
evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA],
reference_key="some_output",
)
with pytest.raises(InputFormatError, match="Example inputs"):
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
client=client,
)
def input_mapper(d: dict) -> str:
return d["some_input"]
run_on_dataset(
client=client,
dataset_name=kv_dataset_name,
llm_or_chain_factory=input_mapper | llm,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_chain(kv_dataset_name: str, eval_project_name: str, client: Client) -> None:
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
with pytest.raises(ValueError, match="Must specify reference_key"):
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=lambda: chain,
evaluation=eval_config,
client=client,
)
eval_config = RunEvalConfig(
evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA],
reference_key="some_output",
)
with pytest.raises(InputFormatError, match="Example inputs"):
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=lambda: chain,
evaluation=eval_config,
client=client,
)
eval_config = RunEvalConfig(
custom_evaluators=[not_empty],
)
def right_input_mapper(d: dict) -> dict:
return {"question": d["some_input"]}
run_on_dataset(
dataset_name=kv_dataset_name,
llm_or_chain_factory=lambda: right_input_mapper | chain,
client=client,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
### Testing Chat Datasets
@pytest.fixture(
scope="module",
)
def chat_dataset_name() -> Iterator[str]:
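    # Chat datasets store each message as a serialized dict with a role type and a data payload.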
def _create_message(txt: str, role: str = "human") -> List[dict]:
return [{"type": role, "data": {"content": txt}}]
import pandas as pd
client = Client()
df = pd.DataFrame(
{
"input": [
_create_message(txt)
for txt in (
"What's the capital of California?",
"What's the capital of Nevada?",
"What's the capital of Oregon?",
"What's the capital of Washington?",
)
],
"output": [
_create_message(txt, role="ai")[0]
for txt in ("Sacramento", "Carson City", "Salem", "Olympia")
],
}
)
uid = str(uuid4())[-8:]
_dataset_name = f"lcp chat dataset integration tests - {uid}"
ds = client.create_dataset(
_dataset_name, description="Integration test dataset", data_type=DataType.chat
)
for row in df.itertuples():
client.create_example(
dataset_id=ds.id,
inputs={"input": row.input},
outputs={"output": row.output},
)
yield _dataset_name
def test_chat_model_on_chat_dataset(
chat_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = ChatOpenAI(temperature=0)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
run_on_dataset(
dataset_name=chat_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
client=client,
project_name=eval_project_name,
)
_check_all_feedback_passed(eval_project_name, client)
def test_llm_on_chat_dataset(
chat_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = OpenAI(temperature=0)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
run_on_dataset(
dataset_name=chat_dataset_name,
llm_or_chain_factory=llm,
client=client,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_chain_on_chat_dataset(chat_dataset_name: str, client: Client) -> None:
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
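    # Chains take dict inputs, so they cannot be evaluated directly on chat-format examples.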
with pytest.raises(
ValueError, match="Cannot evaluate a chain on dataset with data_type=chat"
):
run_on_dataset(
dataset_name=chat_dataset_name,
client=client,
llm_or_chain_factory=lambda: chain,
evaluation=eval_config,
)
@pytest.fixture(
scope="module",
)
def llm_dataset_name() -> Iterator[str]:
import pandas as pd
client = Client()
df = pd.DataFrame(
{
"input": [
"What's the capital of California?",
"What's the capital of Nevada?",
"What's the capital of Oregon?",
"What's the capital of Washington?",
],
"output": ["Sacramento", "Carson City", "Salem", "Olympia"],
}
)
uid = str(uuid4())[-8:]
_dataset_name = f"lcp llm dataset integration tests - {uid}"
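    # data_type=DataType.llm marks this as a plain string-in/string-out dataset.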
client.upload_dataframe(
df,
name=_dataset_name,
input_keys=["input"],
output_keys=["output"],
description="Integration test dataset",
data_type=DataType.llm,
)
yield _dataset_name
def test_chat_model_on_llm_dataset(
llm_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = ChatOpenAI(temperature=0)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
run_on_dataset(
client=client,
dataset_name=llm_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_llm_on_llm_dataset(
llm_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = OpenAI(temperature=0)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
run_on_dataset(
client=client,
dataset_name=llm_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_chain_on_llm_dataset(llm_dataset_name: str, client: Client) -> None:
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
with pytest.raises(
ValueError, match="Cannot evaluate a chain on dataset with data_type=llm"
):
run_on_dataset(
client=client,
dataset_name=llm_dataset_name,
llm_or_chain_factory=lambda: chain,
evaluation=eval_config,
)
@pytest.fixture(
scope="module",
)
def kv_singleio_dataset_name() -> Iterator[str]:
import pandas as pd
client = Client()
df = pd.DataFrame(
{
"the wackiest input": [
"What's the capital of California?",
"What's the capital of Nevada?",
"What's the capital of Oregon?",
"What's the capital of Washington?",
],
"unthinkable output": ["Sacramento", "Carson City", "Salem", "Olympia"],
}
)
uid = str(uuid4())[-8:]
_dataset_name = f"lcp singleio kv dataset integration tests - {uid}"
client.upload_dataframe(
df,
name=_dataset_name,
input_keys=["the wackiest input"],
output_keys=["unthinkable output"],
description="Integration test dataset",
)
yield _dataset_name
def test_chat_model_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = ChatOpenAI(temperature=0)
eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
run_on_dataset(
dataset_name=kv_singleio_dataset_name,
llm_or_chain_factory=llm,
evaluation=eval_config,
client=client,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_llm_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = OpenAI(temperature=0)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
run_on_dataset(
dataset_name=kv_singleio_dataset_name,
llm_or_chain_factory=llm,
client=client,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
def test_chain_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
llm = ChatOpenAI(temperature=0)
chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
run_on_dataset(
dataset_name=kv_singleio_dataset_name,
llm_or_chain_factory=lambda: chain,
client=client,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
async def test_runnable_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
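    # The prompt variable matches the dataset's single input key, so each example feeds straight into the runnable.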
runnable = (
ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
| ChatOpenAI()
)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
await arun_on_dataset(
dataset_name=kv_singleio_dataset_name,
llm_or_chain_factory=runnable,
client=client,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)
async def test_arb_func_on_kv_singleio_dataset(
kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
runnable = (
ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
| ChatOpenAI()
)
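    # Arbitrary callables are also accepted as the target; this one unwraps the chat response into a plain string.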
def my_func(x: dict) -> str:
content = runnable.invoke(x).content
if isinstance(content, str):
return content
else:
raise ValueError(
f"Expected message with content type string, got {content}"
)
eval_config = RunEvalConfig(custom_evaluators=[not_empty])
await arun_on_dataset(
dataset_name=kv_singleio_dataset_name,
llm_or_chain_factory=my_func,
client=client,
evaluation=eval_config,
project_name=eval_project_name,
tags=["shouldpass"],
)
_check_all_feedback_passed(eval_project_name, client)

View File

@@ -1,9 +0,0 @@
"""Integration test for DallE API Wrapper."""
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
def test_call() -> None:
"""Test that call returns a URL in the output."""
search = DallEAPIWrapper() # type: ignore[call-arg]
output = search.run("volcano island")
assert "https://oaidalleapi" in output

View File

@@ -1,72 +0,0 @@
"""Integration test for embedding-based redundant doc filtering."""
from langchain_community.document_transformers.embeddings_redundant_filter import (
EmbeddingsClusteringFilter,
EmbeddingsRedundantFilter,
_DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
def test_embeddings_redundant_filter() -> None:
texts = [
"What happened to all of my cookies?",
"Where did all of my cookies go?",
"I wish there were better Italian restaurants in my neighborhood.",
]
docs = [Document(page_content=t) for t in texts]
embeddings = OpenAIEmbeddings()
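    # EmbeddingsRedundantFilter drops documents whose embeddings are near-duplicates of earlier ones.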
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
actual = redundant_filter.transform_documents(docs)
assert len(actual) == 2
assert set(texts[:2]).intersection([d.page_content for d in actual])
def test_embeddings_redundant_filter_with_state() -> None:
texts = ["What happened to all of my cookies?", "foo bar baz"]
state = {"embedded_doc": [0.5] * 10}
docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
embeddings = OpenAIEmbeddings()
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
actual = redundant_filter.transform_documents(docs)
assert len(actual) == 1
def test_embeddings_clustering_filter() -> None:
texts = [
"What happened to all of my cookies?",
"A cookie is a small, baked sweet treat and you can find it in the cookie",
"monsters' jar.",
"Cookies are good.",
"I have nightmares about the cookie monster.",
"The most popular pizza styles are: Neapolitan, New York-style and",
"Chicago-style. You can find them on iconic restaurants in major cities.",
"Neapolitan pizza: This is the original pizza style,hailing from Naples,",
"Italy.",
"I wish there were better Italian Pizza restaurants in my neighborhood.",
"New York-style pizza: This is characterized by its large, thin crust, and",
"generous toppings.",
"The first movie to feature a robot was 'A Trip to the Moon' (1902).",
"The first movie to feature a robot that could pass for a human was",
"'Blade Runner' (1982)",
"The first movie to feature a robot that could fall in love with a human",
"was 'Her' (2013)",
"A robot is a machine capable of carrying out complex actions automatically.",
"There are certainly hundreds, if not thousands movies about robots like:",
"'Blade Runner', 'Her' and 'A Trip to the Moon'",
]
docs = [Document(page_content=t) for t in texts]
embeddings = OpenAIEmbeddings()
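    # Cluster the docs into three groups and keep the single doc closest to each cluster center.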
redundant_filter = EmbeddingsClusteringFilter(
embeddings=embeddings,
num_clusters=3,
num_closest=1,
sorted=True,
)
actual = redundant_filter.transform_documents(docs)
assert len(actual) == 3
assert texts[1] in [d.page_content for d in actual]
assert texts[4] in [d.page_content for d in actual]
assert texts[11] in [d.page_content for d in actual]

View File

@@ -1,37 +0,0 @@
"""Integration test for doc reordering."""
from langchain_community.document_transformers.long_context_reorder import (
LongContextReorder,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
def test_long_context_reorder() -> None:
"""Test Lost in the middle reordering get_relevant_docs."""
texts = [
"Basquetball is a great sport.",
"Fly me to the moon is one of my favourite songs.",
"The Celtics are my favourite team.",
"This is a document about the Boston Celtics",
"I simply love going to the movies",
"The Boston Celtics won the game by 20 points",
"This is just a random text.",
"Elden Ring is one of the best games in the last 15 years.",
"L. Kornet is one of the best Celtics players.",
"Larry Bird was an iconic NBA player.",
]
embeddings = OpenAIEmbeddings()
retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
search_kwargs={"k": 10}
)
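    # LongContextReorder places the most relevant docs at the start and end of
    # the list, mitigating the "lost in the middle" effect.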
reordering = LongContextReorder()
docs = retriever.invoke("Tell me about the Celtics")
actual = reordering.transform_documents(docs)
    # The first two and last two elements must contain the most relevant docs.
first_and_last = list(actual[:2]) + list(actual[-2:])
assert len(actual) == 10
assert texts[2] in [d.page_content for d in first_and_last]
assert texts[3] in [d.page_content for d in first_and_last]
assert texts[5] in [d.page_content for d in first_and_last]
assert texts[8] in [d.page_content for d in first_and_last]

View File

@@ -1,62 +0,0 @@
import asyncio
import json
from typing import Any
from unittest import mock
from langchain_community.document_transformers.nuclia_text_transform import (
NucliaTextTransformer,
)
from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
from langchain_core.documents import Document
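# Patch the Nuclia API call with a canned JSON payload so the test needs no network access.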
def fakerun(**args: Any) -> Any:
async def run(self: Any, **args: Any) -> str:
await asyncio.sleep(0.1)
data = {
"extracted_text": [{"body": {"text": "Hello World"}}],
"file_extracted_data": [{"language": "en"}],
"field_metadata": [
{
"metadata": {
"metadata": {
"paragraphs": [
{"end": 66, "sentences": [{"start": 1, "end": 67}]}
]
}
}
}
],
}
return json.dumps(data)
return run
async def test_nuclia_loader() -> None:
with mock.patch(
"langchain_community.tools.nuclia.tool.NucliaUnderstandingAPI._arun",
new_callable=fakerun,
):
with mock.patch("os.environ.get", return_value="_a_key_"):
nua = NucliaUnderstandingAPI(enable_ml=False)
documents = [
Document(page_content="Hello, my name is Alice", metadata={}),
Document(page_content="Hello, my name is Bob", metadata={}),
]
nuclia_transformer = NucliaTextTransformer(nua)
transformed_documents = await nuclia_transformer.atransform_documents(
documents
)
assert len(transformed_documents) == 2
assert (
transformed_documents[0].metadata["nuclia"]["file"]["language"] == "en"
)
assert (
len(
transformed_documents[1].metadata["nuclia"]["metadata"]["metadata"][
"metadata"
]["paragraphs"]
)
== 1
)

View File

@@ -1,19 +0,0 @@
"""Test splitting with page numbers included."""
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
def test_pdf_pagesplitter() -> None:
"""Test splitting with page numbers included."""
script_dir = os.path.dirname(__file__)
loader = PyPDFLoader(os.path.join(script_dir, "examples/hello.pdf"))
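    # PyPDFLoader emits one Document per page, carrying "page" and "source" metadata.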
docs = loader.load()
assert "page" in docs[0].metadata
assert "source" in docs[0].metadata
faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings())
docs = faiss_index.similarity_search("Complete this sentence: Hello", k=1)
assert "Hello world" in docs[0].page_content