Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-21 14:43:07 +00:00)
Move integration tests
@@ -1,175 +0,0 @@
import asyncio
import os
import time
import urllib.request
import uuid
from enum import Enum
from typing import Any
from urllib.error import HTTPError

import pytest
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools.ainetwork.utils import authenticate

from langchain.agents import AgentType, initialize_agent


class Match(Enum):
    __test__ = False
    ListWildcard = 1
    StrWildcard = 2
    DictWildcard = 3
    IntWildcard = 4
    FloatWildcard = 5
    ObjectWildcard = 6

    @classmethod
    def match(cls, value: Any, template: Any) -> bool:
        if template is cls.ListWildcard:
            return isinstance(value, list)
        elif template is cls.StrWildcard:
            return isinstance(value, str)
        elif template is cls.DictWildcard:
            return isinstance(value, dict)
        elif template is cls.IntWildcard:
            return isinstance(value, int)
        elif template is cls.FloatWildcard:
            return isinstance(value, float)
        elif template is cls.ObjectWildcard:
            return True
        elif type(value) != type(template):
            return False
        elif isinstance(value, dict):
            if len(value) != len(template):
                return False
            for k, v in value.items():
                if k not in template or not cls.match(v, template[k]):
                    return False
            return True
        elif isinstance(value, list):
            if len(value) != len(template):
                return False
            for i in range(len(value)):
                if not cls.match(value[i], template[i]):
                    return False
            return True
        else:
            return value == template


@pytest.mark.requires("ain")
def test_ainetwork_toolkit() -> None:
    def get(path: str, type: str = "value", default: Any = None) -> Any:
        ref = ain.db.ref(path)
        value = asyncio.run(
            {
                "value": ref.getValue,
                "rule": ref.getRule,
                "owner": ref.getOwner,
            }[type]()
        )
        return default if value is None else value

    def validate(path: str, template: Any, type: str = "value") -> bool:
        value = get(path, type)
        return Match.match(value, template)

    if not os.environ.get("AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY", None):
        from ain.account import Account

        account = Account.create()
        os.environ["AIN_BLOCKCHAIN_ACCOUNT_PRIVATE_KEY"] = account.private_key

    interface = authenticate(network="testnet")
    toolkit = AINetworkToolkit(network="testnet", interface=interface)
    llm = ChatOpenAI(model="gpt-4", temperature=0)
    agent = initialize_agent(
        tools=toolkit.get_tools(),
        llm=llm,
        verbose=True,
        agent=AgentType.OPENAI_FUNCTIONS,
    )
    ain = interface
    self_address = ain.wallet.defaultAccount.address
    co_address = "0x6813Eb9362372EEF6200f3b1dbC3f819671cBA69"

    # Test creating an app
    UUID = uuid.UUID(
        int=(int(time.time() * 1000) << 64) | (uuid.uuid4().int & ((1 << 64) - 1))
    )
    app_name = f"_langchain_test__{str(UUID).replace('-', '_')}"
    agent.run(f"""Create app {app_name}""")
    validate(f"/manage_app/{app_name}/config", {"admin": {self_address: True}})
    validate(f"/apps/{app_name}/DB", None, "owner")

    # Test reading owner config
    agent.run(f"""Read owner config of /apps/{app_name}/DB .""")
    assert ...

    # Test granting owner config
    agent.run(
        f"""Grant owner authority to {co_address} for edit write rule permission of /apps/{app_name}/DB_co ."""  # noqa: E501
    )
    validate(
        f"/apps/{app_name}/DB_co",
        {
            ".owner": {
                "owners": {
                    co_address: {
                        "branch_owner": False,
                        "write_function": False,
                        "write_owner": False,
                        "write_rule": True,
                    }
                }
            }
        },
        "owner",
    )

    # Test reading owner config
    agent.run(f"""Read owner config of /apps/{app_name}/DB_co .""")
    assert ...

    # Test reading owner config
    agent.run(f"""Read owner config of /apps/{app_name}/DB .""")
    assert ...  # Check if owner {self_address} exists

    # Test reading a value
    agent.run(f"""Read value in /apps/{app_name}/DB""")
    assert ...  # empty

    # Test writing a value
    agent.run(f"""Write value {{1: 1904, 2: 43}} in /apps/{app_name}/DB""")
    validate(f"/apps/{app_name}/DB", {1: 1904, 2: 43})

    # Test reading a value
    agent.run(f"""Read value in /apps/{app_name}/DB""")
    assert ...  # check value

    # Test reading a rule
    agent.run(f"""Read write rule of app {app_name} .""")
    assert ...  # check rule that self_address exists

    # Test sending AIN
    self_balance = get(f"/accounts/{self_address}/balance", default=0)
    transaction_history = get(f"/transfer/{self_address}/{co_address}", default={})
    if self_balance < 1:
        try:
            with urllib.request.urlopen(
                f"http://faucet.ainetwork.ai/api/test/{self_address}/"
            ) as response:
                try_test = response.getcode()
        except HTTPError as e:
            try_test = e.getcode()
    else:
        try_test = 200

    if try_test == 200:
        agent.run(f"""Send 1 AIN to {co_address}""")
        transaction_update = get(f"/transfer/{self_address}/{co_address}", default={})
        assert any(
            transaction_update[key]["value"] == 1
            for key in transaction_update.keys() - transaction_history.keys()
        )
@@ -1,46 +0,0 @@
import pytest
from langchain_community.agent_toolkits import PowerBIToolkit, create_pbi_agent
from langchain_community.chat_models import ChatOpenAI
from langchain_community.utilities.powerbi import PowerBIDataset
from langchain_core.utils import get_from_env


def azure_installed() -> bool:
    try:
        from azure.core.credentials import TokenCredential  # noqa: F401
        from azure.identity import DefaultAzureCredential  # noqa: F401

        return True
    except Exception as e:
        print(f"azure not installed, skipping test {e}")  # noqa: T201
        return False


@pytest.mark.skipif(not azure_installed(), reason="requires azure package")
def test_daxquery() -> None:
    from azure.identity import DefaultAzureCredential

    DATASET_ID = get_from_env("", "POWERBI_DATASET_ID")
    TABLE_NAME = get_from_env("", "POWERBI_TABLE_NAME")
    NUM_ROWS = get_from_env("", "POWERBI_NUMROWS")

    fast_llm = ChatOpenAI(
        temperature=0.5, max_tokens=1000, model_name="gpt-3.5-turbo", verbose=True
    )  # type: ignore[call-arg]
    smart_llm = ChatOpenAI(
        temperature=0, max_tokens=100, model_name="gpt-4", verbose=True
    )  # type: ignore[call-arg]

    toolkit = PowerBIToolkit(
        powerbi=PowerBIDataset(
            dataset_id=DATASET_ID,
            table_names=[TABLE_NAME],
            credential=DefaultAzureCredential(),
        ),
        llm=smart_llm,
    )

    agent_executor = create_pbi_agent(llm=fast_llm, toolkit=toolkit, verbose=True)

    output = agent_executor.run(f"How many rows are in the table, {TABLE_NAME}")
    assert NUM_ROWS in output
@@ -1,159 +0,0 @@
"""
Test AstraDB caches. Requires an Astra DB vector instance.

Required to run this test:
    - a recent `astrapy` Python package available
    - an Astra DB instance;
    - the two environment variables set:
        export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
        export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
    - optionally this as well (otherwise defaults are used):
        export ASTRA_DB_KEYSPACE="my_keyspace"
"""

import os
from typing import AsyncIterator, Iterator

import pytest
from langchain_community.cache import AstraDBCache, AstraDBSemanticCache
from langchain_community.utilities.astradb import SetupMode
from langchain_core.caches import BaseCache
from langchain_core.language_models import LLM
from langchain_core.outputs import Generation, LLMResult

from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM


def _has_env_vars() -> bool:
    return all(
        [
            "ASTRA_DB_APPLICATION_TOKEN" in os.environ,
            "ASTRA_DB_API_ENDPOINT" in os.environ,
        ]
    )


@pytest.fixture(scope="module")
def astradb_cache() -> Iterator[AstraDBCache]:
    cache = AstraDBCache(
        collection_name="lc_integration_test_cache",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield cache
    cache.collection.astra_db.delete_collection("lc_integration_test_cache")


@pytest.fixture
async def async_astradb_cache() -> AsyncIterator[AstraDBCache]:
    cache = AstraDBCache(
        collection_name="lc_integration_test_cache_async",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield cache
    await cache.async_collection.astra_db.delete_collection(
        "lc_integration_test_cache_async"
    )


@pytest.fixture(scope="module")
def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]:
    fake_embe = FakeEmbeddings()
    sem_cache = AstraDBSemanticCache(
        collection_name="lc_integration_test_sem_cache",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        embedding=fake_embe,
    )
    yield sem_cache
    sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache")


@pytest.fixture
async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]:
    fake_embe = FakeEmbeddings()
    sem_cache = AstraDBSemanticCache(
        collection_name="lc_integration_test_sem_cache_async",
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        embedding=fake_embe,
        setup_mode=SetupMode.ASYNC,
    )
    yield sem_cache
    sem_cache.collection.astra_db.delete_collection(
        "lc_integration_test_sem_cache_async"
    )


@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBCaches:
    def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None:
        self.do_cache_test(FakeLLM(), astradb_cache, "foo")

    async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None:
        await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo")

    def test_astradb_semantic_cache(
        self, astradb_semantic_cache: AstraDBSemanticCache
    ) -> None:
        llm = FakeLLM()
        self.do_cache_test(llm, astradb_semantic_cache, "bar")
        output = llm.generate(["bar"])  # 'fizz' is erased away now
        assert output != LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        astradb_semantic_cache.clear()

    async def test_astradb_semantic_cache_async(
        self, async_astradb_semantic_cache: AstraDBSemanticCache
    ) -> None:
        llm = FakeLLM()
        await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar")
        output = await llm.agenerate(["bar"])  # 'fizz' is erased away now
        assert output != LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        await async_astradb_semantic_cache.aclear()

    @staticmethod
    def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
        set_llm_cache(cache)
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
        output = llm.generate([prompt])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        assert output == expected_output
        # clear the cache
        cache.clear()

    @staticmethod
    async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None:
        set_llm_cache(cache)
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
        output = await llm.agenerate([prompt])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz")]],
            llm_output={},
        )
        assert output == expected_output
        # clear the cache
        await cache.aclear()
@@ -1,359 +0,0 @@
"""Test Azure CosmosDB cache functionality.

Required to run this test:
    - a recent 'pymongo' Python package available
    - an Azure CosmosDB Mongo vCore instance
    - one environment variable set:
        export MONGODB_VCORE_URI="connection string for azure cosmos db mongo vCore"
"""
import os
import uuid

import pytest
from langchain_community.cache import AzureCosmosDBSemanticCache
from langchain_community.vectorstores.azure_cosmos_db import (
    CosmosDBSimilarityType,
    CosmosDBVectorSearchType,
)
from langchain_core.outputs import Generation

from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
    FakeEmbeddings,
)
from tests.unit_tests.llms.fake_llm import FakeLLM

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING: str = os.environ.get("MONGODB_VCORE_URI", "")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")

num_lists = 3
dimensions = 10
similarity_algorithm = CosmosDBSimilarityType.COS
kind = CosmosDBVectorSearchType.VECTOR_IVF
m = 16
ef_construction = 64
ef_search = 40
score_threshold = 0.1
application_name = "LANGCHAIN_CACHING_PYTHON"


def _has_env_vars() -> bool:
    return all(["MONGODB_VCORE_URI" in os.environ])


def random_string() -> str:
    return str(uuid.uuid4())


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=similarity_algorithm,
            kind=kind,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_inner_product() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=CosmosDBSimilarityType.IP,
            kind=kind,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=similarity_algorithm,
            kind=kind,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi_inner_product() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=CosmosDBSimilarityType.IP,
            kind=kind,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_hnsw() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=similarity_algorithm,
            kind=CosmosDBVectorSearchType.VECTOR_HNSW,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_inner_product_hnsw() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=CosmosDBSimilarityType.IP,
            kind=CosmosDBVectorSearchType.VECTOR_HNSW,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi_hnsw() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=similarity_algorithm,
            kind=CosmosDBVectorSearchType.VECTOR_HNSW,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)


@pytest.mark.requires("pymongo")
@pytest.mark.skipif(
    not _has_env_vars(), reason="Missing Azure CosmosDB Mongo vCore env. vars"
)
def test_azure_cosmos_db_semantic_cache_multi_inner_product_hnsw() -> None:
    set_llm_cache(
        AzureCosmosDBSemanticCache(
            cosmosdb_connection_string=CONNECTION_STRING,
            cosmosdb_client=None,
            embedding=FakeEmbeddings(),
            database_name=DB_NAME,
            collection_name=COLLECTION_NAME,
            num_lists=num_lists,
            similarity=CosmosDBSimilarityType.IP,
            kind=CosmosDBVectorSearchType.VECTOR_HNSW,
            dimensions=dimensions,
            m=m,
            ef_construction=ef_construction,
            ef_search=ef_search,
            score_threshold=score_threshold,
            application_name=application_name,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)
@@ -1,177 +0,0 @@
"""Test Cassandra caches. Requires a running vector-capable Cassandra cluster."""
import asyncio
import os
import time
from typing import Any, Iterator, Tuple

import pytest
from langchain_community.cache import CassandraCache, CassandraSemanticCache
from langchain_community.utilities.cassandra import SetupMode
from langchain_core.outputs import Generation, LLMResult

from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture(scope="module")
def cassandra_connection() -> Iterator[Tuple[Any, str]]:
    from cassandra.cluster import Cluster

    keyspace = "langchain_cache_test_keyspace"
    # get db connection
    if "CASSANDRA_CONTACT_POINTS" in os.environ:
        contact_points = os.environ["CASSANDRA_CONTACT_POINTS"].split(",")
        cluster = Cluster(contact_points)
    else:
        cluster = Cluster()
    session = cluster.connect()
    # ensure keyspace exists
    session.execute(
        (
            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
        )
    )

    yield (session, keyspace)


def test_cassandra_cache(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    cache = CassandraCache(session=session, keyspace=keyspace)
    set_llm_cache(cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    cache.clear()


async def test_cassandra_cache_async(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    cache = CassandraCache(
        session=session, keyspace=keyspace, setup_mode=SetupMode.ASYNC
    )
    set_llm_cache(cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
    output = await llm.agenerate(["foo"])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    await cache.aclear()


def test_cassandra_cache_ttl(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    cache = CassandraCache(session=session, keyspace=keyspace, ttl_seconds=2)
    set_llm_cache(cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    output = llm.generate(["foo"])
    assert output == expected_output
    time.sleep(2.5)
    # entry has expired away.
    output = llm.generate(["foo"])
    assert output != expected_output
    cache.clear()


async def test_cassandra_cache_ttl_async(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    cache = CassandraCache(
        session=session, keyspace=keyspace, ttl_seconds=2, setup_mode=SetupMode.ASYNC
    )
    set_llm_cache(cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    output = await llm.agenerate(["foo"])
    assert output == expected_output
    await asyncio.sleep(2.5)
    # entry has expired away.
    output = await llm.agenerate(["foo"])
    assert output != expected_output
    await cache.aclear()


def test_cassandra_semantic_cache(cassandra_connection: Tuple[Any, str]) -> None:
    session, keyspace = cassandra_connection
    sem_cache = CassandraSemanticCache(
        session=session,
        keyspace=keyspace,
        embedding=FakeEmbeddings(),
    )
    set_llm_cache(sem_cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["bar"])  # same embedding as 'foo'
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    sem_cache.clear()
    output = llm.generate(["bar"])  # 'fizz' is erased away now
    assert output != expected_output
    sem_cache.clear()


async def test_cassandra_semantic_cache_async(
    cassandra_connection: Tuple[Any, str],
) -> None:
    session, keyspace = cassandra_connection
    sem_cache = CassandraSemanticCache(
        session=session,
        keyspace=keyspace,
        embedding=FakeEmbeddings(),
        setup_mode=SetupMode.ASYNC,
    )
    set_llm_cache(sem_cache)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")])
    output = await llm.agenerate(["bar"])  # same embedding as 'foo'
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    await sem_cache.aclear()
    output = await llm.agenerate(["bar"])  # 'fizz' is erased away now
    assert output != expected_output
    await sem_cache.aclear()
@@ -1,62 +0,0 @@
import os
from typing import Any, Callable, Union

import pytest
from langchain_community.cache import GPTCache
from langchain_core.outputs import Generation

from langchain.globals import get_llm_cache, set_llm_cache
from tests.unit_tests.llms.fake_llm import FakeLLM

try:
    from gptcache import Cache  # noqa: F401
    from gptcache.manager.factory import get_data_manager
    from gptcache.processor.pre import get_prompt

    gptcache_installed = True
except ImportError:
    gptcache_installed = False


def init_gptcache_map(cache_obj: Any) -> None:
    i = getattr(init_gptcache_map, "_i", 0)
    cache_path = f"data_map_{i}.txt"
    if os.path.isfile(cache_path):
        os.remove(cache_path)
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=get_data_manager(data_path=cache_path),
    )
    init_gptcache_map._i = i + 1  # type: ignore


def init_gptcache_map_with_llm(cache_obj: Any, llm: str) -> None:
    cache_path = f"data_map_{llm}.txt"
    if os.path.isfile(cache_path):
        os.remove(cache_path)
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=get_data_manager(data_path=cache_path),
    )


@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
@pytest.mark.parametrize(
    "init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
)
def test_gptcache_caching(
    init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None],
) -> None:
    """Test gptcache default caching behavior."""
    set_llm_cache(GPTCache(init_func))
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    _ = llm.generate(["foo", "bar", "foo"])
    cache_output = get_llm_cache().lookup("foo", llm_string)
    assert cache_output == [Generation(text="fizz")]

    get_llm_cache().clear()
    assert get_llm_cache().lookup("bar", llm_string) is None
@@ -1,97 +0,0 @@
"""Test Momento cache functionality.

To run tests, set the environment variable MOMENTO_API_KEY to a valid
Momento auth token. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
from __future__ import annotations

import uuid
from datetime import timedelta
from typing import Iterator

import pytest
from langchain_community.cache import MomentoCache
from langchain_core.outputs import Generation, LLMResult

from langchain.globals import set_llm_cache
from tests.unit_tests.llms.fake_llm import FakeLLM


def random_string() -> str:
    return str(uuid.uuid4())


@pytest.fixture(scope="module")
def momento_cache() -> Iterator[MomentoCache]:
    from momento import CacheClient, Configurations, CredentialProvider

    cache_name = f"langchain-test-cache-{random_string()}"
    client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_API_KEY"),
        default_ttl=timedelta(seconds=30),
    )
    try:
        llm_cache = MomentoCache(client, cache_name)
        set_llm_cache(llm_cache)
        yield llm_cache
    finally:
        client.delete_cache(cache_name)


def test_invalid_ttl() -> None:
    from momento import CacheClient, Configurations, CredentialProvider

    client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_API_KEY"),
        default_ttl=timedelta(seconds=30),
    )
    with pytest.raises(ValueError):
        MomentoCache(client, cache_name=random_string(), ttl=timedelta(seconds=-1))


def test_momento_cache_miss(momento_cache: MomentoCache) -> None:
    llm = FakeLLM()
    stub_llm_output = LLMResult(generations=[[Generation(text="foo")]])
    assert llm.generate([random_string()]) == stub_llm_output


@pytest.mark.parametrize(
    "prompts, generations",
    [
        # Single prompt, single generation
        ([random_string()], [[random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Multiple prompts, multiple generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
)
def test_momento_cache_hit(
    momento_cache: MomentoCache, prompts: list[str], generations: list[list[str]]
) -> None:
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))

    llm_generations = [
        [
            Generation(text=generation, generation_info=params)
            for generation in prompt_i_generations
        ]
        for prompt_i_generations in generations
    ]
    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
        momento_cache.update(prompt_i, llm_string, llm_generations_i)

    assert llm.generate(prompts) == LLMResult(
        generations=llm_generations, llm_output={}
    )
@@ -1,59 +0,0 @@
from langchain_community.cache import OpenSearchSemanticCache
from langchain_core.outputs import Generation

from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
    FakeEmbeddings,
)
from tests.unit_tests.llms.fake_llm import FakeLLM

DEFAULT_OPENSEARCH_URL = "http://localhost:9200"


def test_opensearch_semantic_cache() -> None:
    """Test opensearch semantic cache functionality."""
    set_llm_cache(
        OpenSearchSemanticCache(
            embedding=FakeEmbeddings(),
            opensearch_url=DEFAULT_OPENSEARCH_URL,
            score_threshold=0.0,
        )
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz")]

    get_llm_cache().clear(llm_string=llm_string)
    output = get_llm_cache().lookup("bar", llm_string)
    assert output != [Generation(text="fizz")]


def test_opensearch_semantic_cache_multi() -> None:
    set_llm_cache(
        OpenSearchSemanticCache(
            embedding=FakeEmbeddings(),
            opensearch_url=DEFAULT_OPENSEARCH_URL,
            score_threshold=0.0,
        )
    )

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )

    # foo and bar will have the same embedding produced by FakeEmbeddings
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]

    # clear the cache
    get_llm_cache().clear(llm_string=llm_string)
    output = get_llm_cache().lookup("bar", llm_string)
    assert output != [Generation(text="fizz"), Generation(text="Buzz")]
@@ -1,319 +0,0 @@
"""Test Redis cache functionality."""

import uuid
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator, List, Optional, cast

import pytest
from langchain_community.cache import AsyncRedisCache, RedisCache, RedisSemanticCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult

from langchain.globals import get_llm_cache, set_llm_cache
from tests.integration_tests.cache.fake_embeddings import (
    ConsistentFakeEmbeddings,
    FakeEmbeddings,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM

# Using a non-standard port to avoid conflicts with any Redis instance that may
# already be running locally.
# You can spin up a local Redis using docker compose:
#     cd [repository-root]/docker
#     docker-compose up redis
REDIS_TEST_URL = "redis://localhost:6020"


def random_string() -> str:
    return str(uuid.uuid4())


@contextmanager
def get_sync_redis(*, ttl: Optional[int] = 1) -> Generator[RedisCache, None, None]:
    """Get a sync RedisCache instance."""
    import redis

    cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL), ttl=ttl)
    try:
        yield cache
    finally:
        cache.clear()


@asynccontextmanager
async def get_async_redis(
    *, ttl: Optional[int] = 1
) -> AsyncGenerator[AsyncRedisCache, None]:
    """Get an async RedisCache instance."""
    from redis.asyncio import Redis

    cache = AsyncRedisCache(redis_=Redis.from_url(REDIS_TEST_URL), ttl=ttl)
    try:
        yield cache
    finally:
        await cache.aclear()


def test_redis_cache_ttl() -> None:
    from redis import Redis

    with get_sync_redis() as llm_cache:
        set_llm_cache(llm_cache)
        llm_cache.update("foo", "bar", [Generation(text="fizz")])
        key = llm_cache._key("foo", "bar")
        assert isinstance(llm_cache.redis, Redis)
        assert llm_cache.redis.pttl(key) > 0


async def test_async_redis_cache_ttl() -> None:
    from redis.asyncio import Redis as AsyncRedis

    async with get_async_redis() as redis_cache:
        set_llm_cache(redis_cache)
        llm_cache = cast(RedisCache, get_llm_cache())
        await llm_cache.aupdate("foo", "bar", [Generation(text="fizz")])
        key = llm_cache._key("foo", "bar")
        assert isinstance(llm_cache.redis, AsyncRedis)
        assert await llm_cache.redis.pttl(key) > 0


def test_sync_redis_cache() -> None:
    with get_sync_redis() as llm_cache:
        set_llm_cache(llm_cache)
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        llm_cache.update("prompt", llm_string, [Generation(text="fizz0")])
        output = llm.generate(["prompt"])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz0")]],
            llm_output={},
        )
        assert output == expected_output


async def test_sync_in_async_redis_cache() -> None:
    """Test the sync RedisCache invoked with async methods"""
    with get_sync_redis() as llm_cache:
        set_llm_cache(llm_cache)
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        # llm_cache.update("meow", llm_string, [Generation(text="meow")])
        await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz1")])
        output = await llm.agenerate(["prompt"])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz1")]],
            llm_output={},
        )
        assert output == expected_output


async def test_async_redis_cache() -> None:
    async with get_async_redis() as redis_cache:
        set_llm_cache(redis_cache)
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        llm_cache = cast(RedisCache, get_llm_cache())
        await llm_cache.aupdate("prompt", llm_string, [Generation(text="fizz2")])
        output = await llm.agenerate(["prompt"])
        expected_output = LLMResult(
            generations=[[Generation(text="fizz2")]],
            llm_output={},
        )
        assert output == expected_output


async def test_async_in_sync_redis_cache() -> None:
    async with get_async_redis() as redis_cache:
        set_llm_cache(redis_cache)
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        llm_cache = cast(RedisCache, get_llm_cache())
        with pytest.raises(NotImplementedError):
            llm_cache.update("foo", llm_string, [Generation(text="fizz")])


def test_redis_cache_chat() -> None:
    with get_sync_redis() as redis_cache:
        set_llm_cache(redis_cache)
        llm = FakeChatModel()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        prompt: List[BaseMessage] = [HumanMessage(content="foo")]
        llm_cache = cast(RedisCache, get_llm_cache())
        llm_cache.update(
            dumps(prompt),
            llm_string,
            [ChatGeneration(message=AIMessage(content="fizz"))],
        )
        output = llm.generate([prompt])
        expected_output = LLMResult(
            generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
            llm_output={},
        )
        assert output == expected_output


async def test_async_redis_cache_chat() -> None:
    async with get_async_redis() as redis_cache:
        set_llm_cache(redis_cache)
        llm = FakeChatModel()
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        prompt: List[BaseMessage] = [HumanMessage(content="foo")]
        llm_cache = cast(RedisCache, get_llm_cache())
        await llm_cache.aupdate(
            dumps(prompt),
            llm_string,
            [ChatGeneration(message=AIMessage(content="fizz"))],
        )
        output = await llm.agenerate([prompt])
        expected_output = LLMResult(
            generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
            llm_output={},
        )
        assert output == expected_output


def test_redis_semantic_cache() -> None:
    """Test redis semantic cache functionality."""
    set_llm_cache(
        RedisSemanticCache(
            embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
        )
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    llm_cache = cast(RedisSemanticCache, get_llm_cache())
    llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    llm_cache.clear(llm_string=llm_string)
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    # expect different output now without cached result
    assert output != expected_output
    llm_cache.clear(llm_string=llm_string)


def test_redis_semantic_cache_multi() -> None:
    set_llm_cache(
        RedisSemanticCache(
            embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
        )
    )
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    llm_cache = cast(RedisSemanticCache, get_llm_cache())
    llm_cache.update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )
    output = llm.generate(
        ["bar"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    expected_output = LLMResult(
        generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    llm_cache.clear(llm_string=llm_string)


def test_redis_semantic_cache_chat() -> None:
    set_llm_cache(
        RedisSemanticCache(
            embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
        )
    )
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    prompt: List[BaseMessage] = [HumanMessage(content="foo")]
    llm_cache = cast(RedisSemanticCache, get_llm_cache())
    llm_cache.update(
        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
    )
    output = llm.generate([prompt])
    expected_output = LLMResult(
        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
        llm_output={},
    )
    assert output == expected_output
    llm_cache.clear(llm_string=llm_string)


@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@pytest.mark.parametrize(
    "prompts, generations",
    [
        # Single prompt, single generation
        ([random_string()], [[random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Multiple prompts, multiple generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
    ids=[
        "single_prompt_single_generation",
        "single_prompt_multiple_generations",
        "single_prompt_multiple_generations",
        "multiple_prompts_multiple_generations",
    ],
)
def test_redis_semantic_cache_hit(
    embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
    set_llm_cache(RedisSemanticCache(embedding=embedding, redis_url=REDIS_TEST_URL))

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))

    llm_generations = [
        [
            Generation(text=generation, generation_info=params)
            for generation in prompt_i_generations
        ]
        for prompt_i_generations in generations
    ]
    llm_cache = cast(RedisSemanticCache, get_llm_cache())
    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
        print(prompt_i)  # noqa: T201
        print(llm_generations_i)  # noqa: T201
        llm_cache.update(prompt_i, llm_string, llm_generations_i)
    llm.generate(prompts)
    assert llm.generate(prompts) == LLMResult(
        generations=llm_generations, llm_output={}
    )
@@ -1,90 +0,0 @@
"""Test Upstash Redis cache functionality."""
import uuid

import pytest
from langchain_community.cache import UpstashRedisCache
from langchain_core.outputs import Generation, LLMResult

import langchain
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM

URL = "<UPSTASH_REDIS_REST_URL>"
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"


def random_string() -> str:
    return str(uuid.uuid4())


@pytest.mark.requires("upstash_redis")
def test_redis_cache_ttl() -> None:
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    langchain.llm_cache.update("foo", "bar", [Generation(text="fizz")])
    key = langchain.llm_cache._key("foo", "bar")
    assert langchain.llm_cache.redis.pttl(key) > 0


@pytest.mark.requires("upstash_redis")
def test_redis_cache() -> None:
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output

    lookup_output = langchain.llm_cache.lookup("foo", llm_string)
    if lookup_output and len(lookup_output) > 0:
        assert lookup_output == expected_output.generations[0]

    langchain.llm_cache.clear()
    output = llm.generate(["foo"])

    assert output != expected_output
    langchain.llm_cache.redis.flushall()


def test_redis_cache_multi() -> None:
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update(
        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
    )
    output = llm.generate(
        ["foo"]
    )  # foo and bar will have the same embedding produced by FakeEmbeddings
    expected_output = LLMResult(
        generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
        llm_output={},
    )
    assert output == expected_output
    # clear the cache
    langchain.llm_cache.clear()


@pytest.mark.requires("upstash_redis")
def test_redis_cache_chat() -> None:
    from upstash_redis import Redis

    langchain.llm_cache = UpstashRedisCache(redis_=Redis(url=URL, token=TOKEN), ttl=1)
    llm = FakeChatModel()
    params = llm.dict()
    params["stop"] = None
    llm.invoke("foo")
    langchain.llm_cache.redis.flushall()
@@ -1,18 +0,0 @@
"""Integration test for Dall-E image generator agent."""
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_community.llms import OpenAI

from langchain.agents import AgentType, initialize_agent


def test_call() -> None:
    """Test that the agent runs and returns output."""
    llm = OpenAI(temperature=0.9)
    tools = load_tools(["dalle-image-generator"])

    agent = initialize_agent(
        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )

    output = agent.run("Create an image of a volcano island")
    assert output is not None
@@ -1,337 +0,0 @@
|
||||
"""Test Graph Database Chain."""
|
||||
import os
|
||||
|
||||
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_community.llms.openai import OpenAI
|
||||
|
||||
from langchain.chains.loading import load_chain
|
||||
|
||||
|
||||
def test_connect_neo4j() -> None:
|
||||
"""Test that Neo4j database is correctly instantiated and connected."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
|
||||
output = graph.query(
|
||||
"""
|
||||
RETURN "test" AS output
|
||||
"""
|
||||
)
|
||||
expected_output = [{"output": "test"}]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_connect_neo4j_env() -> None:
|
||||
"""Test that Neo4j database environment variables."""
|
||||
graph = Neo4jGraph()
|
||||
|
||||
output = graph.query(
|
||||
"""
|
||||
RETURN "test" AS output
|
||||
"""
|
||||
)
|
||||
expected_output = [{"output": "test"}]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_cypher_generating_run() -> None:
|
||||
"""Test that Cypher statement is correctly generated and executed."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(OpenAI(temperature=0), graph=graph)
|
||||
output = chain.run("Who played in Pulp Fiction?")
|
||||
expected_output = " Bruce Willis played in Pulp Fiction."
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_cypher_top_k() -> None:
|
||||
"""Test top_k parameter correctly limits the number of results in the context."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
TOP_K = 1
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
"<-[:ACTED_IN]-(:Actor {name:'Foo'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, return_direct=True, top_k=TOP_K
|
||||
)
|
||||
output = chain.run("Who played in Pulp Fiction?")
|
||||
assert len(output) == TOP_K
|
||||
|
||||
|
||||
def test_cypher_intermediate_steps() -> None:
|
||||
"""Test the returning of the intermediate steps."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, return_intermediate_steps=True
|
||||
)
|
||||
output = chain("Who played in Pulp Fiction?")
|
||||
|
||||
expected_output = " Bruce Willis played in Pulp Fiction."
|
||||
assert output["result"] == expected_output
|
||||
|
||||
query = output["intermediate_steps"][0]["query"]
|
||||
# LLM can return variations of the same query (a normalization alternative is sketched after this test)
|
||||
expected_queries = [
|
||||
(
|
||||
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
|
||||
"(m:Movie {title: 'Pulp Fiction'}) RETURN a.name"
|
||||
),
|
||||
(
|
||||
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
|
||||
"(m:Movie {title: 'Pulp Fiction'}) RETURN a.name;"
|
||||
),
|
||||
(
|
||||
"\n\nMATCH (a:Actor)-[:ACTED_IN]->"
|
||||
"(m:Movie) WHERE m.title = 'Pulp Fiction' RETURN a.name"
|
||||
),
|
||||
]
|
||||
|
||||
assert query in expected_queries
|
||||
|
||||
context = output["intermediate_steps"][1]["context"]
|
||||
expected_context = [{"a.name": "Bruce Willis"}]
|
||||
assert context == expected_context
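# Hedged alternative, not in the original test: rather than enumerating every query
# variant the LLM might produce (as above), the generated Cypher could be normalized
# before comparison. The helper name below is illustrative only.
def _normalize_cypher(query: str) -> str:
    # Collapse whitespace and drop a trailing semicolon before comparing queries.
    return " ".join(query.split()).rstrip(";").strip()


# Possible usage:
# assert _normalize_cypher(query) in {_normalize_cypher(q) for q in expected_queries}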
|
||||
|
||||
|
||||
def test_cypher_return_direct() -> None:
|
||||
"""Test that chain returns direct results."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, return_direct=True
|
||||
)
|
||||
output = chain.run("Who played in Pulp Fiction?")
|
||||
expected_output = [{"a.name": "Bruce Willis"}]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_cypher_save_load() -> None:
|
||||
"""Test saving and loading."""
|
||||
|
||||
FILE_PATH = "cypher.yaml"
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, return_direct=True
|
||||
)
|
||||
|
||||
chain.save(file_path=FILE_PATH)
|
||||
qa_loaded = load_chain(FILE_PATH, graph=graph)
|
||||
|
||||
assert qa_loaded == chain
|
||||
|
||||
|
||||
def test_exclude_types() -> None:
|
||||
"""Test exclude types from schema."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
"<-[:DIRECTED]-(p:Person {name:'John'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, exclude_types=["Person", "DIRECTED"]
|
||||
)
|
||||
expected_schema = (
|
||||
"Node properties are the following:\n"
|
||||
"Movie {title: STRING},Actor {name: STRING}\n"
|
||||
"Relationship properties are the following:\n\n"
|
||||
"The relationships are the following:\n"
|
||||
"(:Actor)-[:ACTED_IN]->(:Movie)"
|
||||
)
|
||||
assert chain.graph_schema == expected_schema
|
||||
|
||||
|
||||
def test_include_types() -> None:
|
||||
"""Test include types from schema."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
"<-[:DIRECTED]-(p:Person {name:'John'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, include_types=["Movie", "Actor", "ACTED_IN"]
|
||||
)
|
||||
expected_schema = (
|
||||
"Node properties are the following:\n"
|
||||
"Movie {title: STRING},Actor {name: STRING}\n"
|
||||
"Relationship properties are the following:\n\n"
|
||||
"The relationships are the following:\n"
|
||||
"(:Actor)-[:ACTED_IN]->(:Movie)"
|
||||
)
|
||||
|
||||
assert chain.graph_schema == expected_schema
|
||||
|
||||
|
||||
def test_include_types2() -> None:
|
||||
"""Test include types from schema."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Create two nodes and a relationship
|
||||
graph.query(
|
||||
"CREATE (a:Actor {name:'Bruce Willis'})"
|
||||
"-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
|
||||
"<-[:DIRECTED]-(p:Person {name:'John'})"
|
||||
)
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, include_types=["Movie", "ACTED_IN"]
|
||||
)
|
||||
expected_schema = (
|
||||
"Node properties are the following:\n"
|
||||
"Movie {title: STRING}\n"
|
||||
"Relationship properties are the following:\n\n"
|
||||
"The relationships are the following:\n"
|
||||
)
|
||||
assert chain.graph_schema == expected_schema
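# Hedged sketch, not in the original module: the environment lookup and Neo4jGraph
# construction are repeated verbatim in every test above; a helper like this could
# remove the duplication. The helper name is illustrative only.
def _neo4j_graph_from_env() -> Neo4jGraph:
    url = os.environ.get("NEO4J_URI")
    username = os.environ.get("NEO4J_USERNAME")
    password = os.environ.get("NEO4J_PASSWORD")
    assert url is not None
    assert username is not None
    assert password is not None
    return Neo4jGraph(url=url, username=username, password=password)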
|
||||
@@ -1,93 +0,0 @@
|
||||
"""Test Graph Database Chain."""
|
||||
from typing import Any
|
||||
|
||||
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
|
||||
from langchain_community.graphs import ArangoGraph
|
||||
from langchain_community.graphs.arangodb_graph import get_arangodb_client
|
||||
from langchain_community.llms.openai import OpenAI
|
||||
|
||||
|
||||
def populate_arangodb_database(db: Any) -> None:
|
||||
if db.has_graph("GameOfThrones"):
|
||||
return
|
||||
|
||||
db.create_graph(
|
||||
"GameOfThrones",
|
||||
edge_definitions=[
|
||||
{
|
||||
"edge_collection": "ChildOf",
|
||||
"from_vertex_collections": ["Characters"],
|
||||
"to_vertex_collections": ["Characters"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"_key": "NedStark",
|
||||
"name": "Ned",
|
||||
"surname": "Stark",
|
||||
"alive": True,
|
||||
"age": 41,
|
||||
"gender": "male",
|
||||
},
|
||||
{
|
||||
"_key": "AryaStark",
|
||||
"name": "Arya",
|
||||
"surname": "Stark",
|
||||
"alive": True,
|
||||
"age": 11,
|
||||
"gender": "female",
|
||||
},
|
||||
]
|
||||
|
||||
edges = [{"_to": "Characters/NedStark", "_from": "Characters/AryaStark"}]
|
||||
|
||||
db.collection("Characters").import_bulk(documents)
|
||||
db.collection("ChildOf").import_bulk(edges)
|
||||
|
||||
|
||||
def test_connect_arangodb() -> None:
|
||||
"""Test that the ArangoDB database is correctly instantiated and connected."""
|
||||
graph = ArangoGraph(get_arangodb_client())
|
||||
|
||||
sample_aql_result = graph.query("RETURN 'hello world'")
|
||||
assert ["hello_world"] == sample_aql_result
|
||||
|
||||
|
||||
def test_empty_schema_on_no_data() -> None:
|
||||
"""Test that the schema is empty for an empty ArangoDB Database"""
|
||||
db = get_arangodb_client()
|
||||
db.delete_graph("GameOfThrones", drop_collections=True, ignore_missing=True)
|
||||
db.delete_collection("empty_collection", ignore_missing=True)
|
||||
db.create_collection("empty_collection")
|
||||
|
||||
graph = ArangoGraph(db)
|
||||
|
||||
assert graph.schema == {
|
||||
"Graph Schema": [],
|
||||
"Collection Schema": [],
|
||||
}
|
||||
|
||||
|
||||
def test_aql_generation() -> None:
|
||||
"""Test that AQL statement is correctly generated and executed."""
|
||||
db = get_arangodb_client()
|
||||
|
||||
populate_arangodb_database(db)
|
||||
|
||||
graph = ArangoGraph(db)
|
||||
chain = ArangoGraphQAChain.from_llm(OpenAI(temperature=0), graph=graph)
|
||||
chain.return_aql_result = True
|
||||
|
||||
output = chain("Is Ned Stark alive?")
|
||||
assert output["aql_result"] == [True]
|
||||
assert "Yes" in output["result"]
|
||||
|
||||
output = chain("How old is Arya Stark?")
|
||||
assert output["aql_result"] == [11]
|
||||
assert "11" in output["result"]
|
||||
|
||||
output = chain("What is the relationship between Arya Stark and Ned Stark?")
|
||||
assert len(output["aql_result"]) == 1
|
||||
assert "child of" in output["result"]
|
||||
@@ -1,268 +0,0 @@
|
||||
"""Test RDF/ SPARQL Graph Database Chain."""
|
||||
import pathlib
|
||||
import re
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain
|
||||
from langchain_community.graphs import RdfGraph
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
|
||||
"""
|
||||
cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb
|
||||
./start.sh
|
||||
"""
|
||||
|
||||
|
||||
def test_connect_file_rdf() -> None:
|
||||
"""
|
||||
Test loading an online RDF resource into an RdfGraph.
|
||||
"""
|
||||
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
|
||||
|
||||
graph = RdfGraph(
|
||||
source_file=berners_lee_card,
|
||||
standard="rdf",
|
||||
)
|
||||
|
||||
query = """SELECT ?s ?p ?o\n""" """WHERE { ?s ?p ?o }"""
|
||||
|
||||
output = graph.query(query)
|
||||
assert len(output) == 86
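# Hedged sketch, not in the original file: the exact triple count of the online card
# can change over time, so a COUNT-based check is a looser, more robust alternative.
def test_connect_file_rdf_count() -> None:
    graph = RdfGraph(
        source_file="http://www.w3.org/People/Berners-Lee/card",
        standard="rdf",
    )
    output = graph.query("SELECT (COUNT(*) AS ?n) WHERE { ?s ?p ?o }")
    assert len(output) == 1  # a single binding row holding the triple count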
|
||||
|
||||
|
||||
def test_sparql_select() -> None:
|
||||
"""
|
||||
Test generating and executing a simple SPARQL SELECT query.
|
||||
"""
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
|
||||
|
||||
graph = RdfGraph(
|
||||
source_file=berners_lee_card,
|
||||
standard="rdf",
|
||||
)
|
||||
|
||||
question = "What is Tim Berners-Lee's work homepage?"
|
||||
answer = "Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/."
|
||||
|
||||
chain = GraphSparqlQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
)
|
||||
chain.sparql_intent_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_select_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_update_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_intent_chain.run = Mock(return_value="SELECT")
|
||||
chain.sparql_generation_select_chain.run = Mock(
|
||||
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
SELECT ?workHomepage
|
||||
WHERE {
|
||||
?person rdfs:label "Tim Berners-Lee" .
|
||||
?person foaf:workplaceHomepage ?workHomepage .
|
||||
}"""
|
||||
)
|
||||
chain.qa_chain = MagicMock(
|
||||
return_value={
|
||||
"text": answer,
|
||||
"prompt": question,
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
chain.qa_chain.output_key = "text"
|
||||
|
||||
output = chain.invoke({chain.input_key: question})[chain.output_key]
|
||||
assert output == answer
|
||||
|
||||
assert chain.sparql_intent_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_select_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_update_chain.run.call_count == 0
|
||||
assert chain.qa_chain.call_count == 1
|
||||
|
||||
|
||||
def test_sparql_insert(tmp_path: pathlib.Path) -> None:
|
||||
"""
|
||||
Test generating and executing a simple SPARQL INSERT query.
|
||||
"""
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
|
||||
local_copy = tmp_path / "test.ttl"
|
||||
|
||||
graph = RdfGraph(
|
||||
source_file=berners_lee_card,
|
||||
standard="rdf",
|
||||
local_copy=str(local_copy),
|
||||
)
|
||||
|
||||
query = (
|
||||
"Save that the person with the name 'Timothy Berners-Lee' "
|
||||
"has a work homepage at 'http://www.w3.org/foo/bar/'"
|
||||
)
|
||||
|
||||
chain = GraphSparqlQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
)
|
||||
chain.sparql_intent_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_select_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_update_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_intent_chain.run = Mock(return_value="UPDATE")
|
||||
chain.sparql_generation_update_chain.run = Mock(
|
||||
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
|
||||
INSERT {
|
||||
?p foaf:workplaceHomepage <http://www.w3.org/foo/bar/> .
|
||||
}
|
||||
WHERE {
|
||||
?p foaf:name "Timothy Berners-Lee" .
|
||||
}"""
|
||||
)
|
||||
|
||||
output = chain.invoke({chain.input_key: query})[chain.output_key]
|
||||
assert output == "Successfully inserted triples into the graph."
|
||||
|
||||
assert chain.sparql_intent_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_select_chain.run.call_count == 0
|
||||
assert chain.sparql_generation_update_chain.run.call_count == 1
|
||||
assert chain.qa_chain.call_count == 0
|
||||
|
||||
query = (
|
||||
"""PREFIX foaf: <http://xmlns.com/foaf/0.1/>\n"""
|
||||
"""SELECT ?hp\n"""
|
||||
"""WHERE {\n"""
|
||||
""" ?person foaf:name "Timothy Berners-Lee" . \n"""
|
||||
""" ?person foaf:workplaceHomepage ?hp .\n"""
|
||||
"""}"""
|
||||
)
|
||||
output = graph.query(query)
|
||||
assert len(output) == 2
|
||||
|
||||
|
||||
def test_sparql_select_return_query() -> None:
|
||||
"""
|
||||
Test generating and executing a simple SPARQL SELECT query
|
||||
and returning the generated SPARQL query.
|
||||
"""
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
|
||||
|
||||
graph = RdfGraph(
|
||||
source_file=berners_lee_card,
|
||||
standard="rdf",
|
||||
)
|
||||
|
||||
question = "What is Tim Berners-Lee's work homepage?"
|
||||
answer = "Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/."
|
||||
|
||||
chain = GraphSparqlQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
return_sparql_query=True,
|
||||
)
|
||||
chain.sparql_intent_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_select_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_update_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_intent_chain.run = Mock(return_value="SELECT")
|
||||
chain.sparql_generation_select_chain.run = Mock(
|
||||
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
SELECT ?workHomepage
|
||||
WHERE {
|
||||
?person rdfs:label "Tim Berners-Lee" .
|
||||
?person foaf:workplaceHomepage ?workHomepage .
|
||||
}"""
|
||||
)
|
||||
chain.qa_chain = MagicMock(
|
||||
return_value={
|
||||
"text": answer,
|
||||
"prompt": question,
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
chain.qa_chain.output_key = "text"
|
||||
|
||||
output = chain.invoke({chain.input_key: question})
|
||||
assert output[chain.output_key] == answer
|
||||
assert "sparql_query" in output
|
||||
|
||||
assert chain.sparql_intent_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_select_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_update_chain.run.call_count == 0
|
||||
assert chain.qa_chain.call_count == 1
|
||||
|
||||
|
||||
def test_loading_schema_from_ontotext_graphdb() -> None:
|
||||
graph = RdfGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||
graph_kwargs={"bind_namespaces": "none"},
|
||||
)
|
||||
schema = graph.get_schema
|
||||
prefix = (
|
||||
"In the following, each IRI is followed by the local name and "
|
||||
"optionally its description in parentheses. \n"
|
||||
"The RDF graph supports the following node types:"
|
||||
)
|
||||
assert schema.startswith(prefix)
|
||||
|
||||
infix = "The RDF graph supports the following relationships:"
|
||||
assert infix in schema
|
||||
|
||||
classes = schema[len(prefix) : schema.index(infix)]
|
||||
assert len(re.findall("<[^>]+> \\([^)]+\\)", classes)) == 5
|
||||
|
||||
relationships = schema[schema.index(infix) + len(infix) :]
|
||||
assert len(re.findall("<[^>]+> \\([^)]+\\)", relationships)) == 58
|
||||
|
||||
|
||||
def test_graph_qa_chain_with_ontotext_graphdb() -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "What is Tim Berners-Lee's work homepage?"
|
||||
answer = "Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/."
|
||||
|
||||
graph = RdfGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/langchain",
|
||||
graph_kwargs={"bind_namespaces": "none"},
|
||||
)
|
||||
|
||||
chain = GraphSparqlQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
)
|
||||
chain.sparql_intent_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_select_chain = Mock(LLMChain)
|
||||
chain.sparql_generation_update_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_intent_chain.run = Mock(return_value="SELECT")
|
||||
chain.sparql_generation_select_chain.run = Mock(
|
||||
return_value="""PREFIX foaf: <http://xmlns.com/foaf/0.1/>
|
||||
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
|
||||
SELECT ?workHomepage
|
||||
WHERE {
|
||||
?person rdfs:label "Tim Berners-Lee" .
|
||||
?person foaf:workplaceHomepage ?workHomepage .
|
||||
}"""
|
||||
)
|
||||
chain.qa_chain = MagicMock(
|
||||
return_value={
|
||||
"text": answer,
|
||||
"prompt": question,
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
chain.qa_chain.output_key = "text"
|
||||
|
||||
output = chain.invoke({chain.input_key: question})[chain.output_key]
|
||||
assert output == answer
|
||||
|
||||
assert chain.sparql_intent_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_select_chain.run.call_count == 1
|
||||
assert chain.sparql_generation_update_chain.run.call_count == 0
|
||||
assert chain.qa_chain.call_count == 1
|
||||
@@ -1,385 +0,0 @@
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from langchain_community.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain
|
||||
from langchain_community.graphs import OntotextGraphDBGraph
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
|
||||
"""
|
||||
cd libs/langchain/tests/integration_tests/chains/docker-compose-ontotext-graphdb
|
||||
./start.sh
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [-2, -1, 0, 1, 2])
|
||||
def test_valid_sparql(max_fix_retries: int) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "What is Luke Skywalker's home planet?"
|
||||
answer = "Tatooine"
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": "SELECT * {?s ?p ?o} LIMIT 1",
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = MagicMock()
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": answer,
|
||||
"prompt": question,
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
result = chain.invoke({chain.input_key: question})
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == 0
|
||||
assert chain.qa_chain.invoke.call_count == 1
|
||||
assert result == {chain.output_key: answer, chain.input_key: question}
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [-2, -1, 0])
|
||||
def test_invalid_sparql_non_positive_max_fix_retries(
|
||||
max_fix_retries: int,
|
||||
) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "What is Luke Skywalker's home planet?"
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": "```sparql SELECT * {?s ?p ?o} LIMIT 1```",
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = MagicMock()
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
chain.invoke({chain.input_key: question})
|
||||
|
||||
assert str(e.value) == "The generated SPARQL query is invalid."
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == 0
|
||||
assert chain.qa_chain.invoke.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||
def test_valid_sparql_after_first_retry(max_fix_retries: int) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "What is Luke Skywalker's home planet?"
|
||||
answer = "Tatooine"
|
||||
generated_invalid_sparql = "```sparql SELECT * {?s ?p ?o} LIMIT 1```"
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": generated_invalid_sparql,
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": "SELECT * {?s ?p ?o} LIMIT 1",
|
||||
"error_message": "pyparsing.exceptions.ParseException: "
|
||||
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
|
||||
"found '`' (at char 0), (line:1, col:1)",
|
||||
"generated_sparql": generated_invalid_sparql,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": answer,
|
||||
"prompt": question,
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
result = chain.invoke({chain.input_key: question})
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == 1
|
||||
assert chain.qa_chain.invoke.call_count == 1
|
||||
assert result == {chain.output_key: answer, chain.input_key: question}
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||
def test_invalid_sparql_server_response_400(max_fix_retries: int) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "Who is the oldest character?"
|
||||
generated_invalid_sparql = (
|
||||
"PREFIX : <https://swapi.co/vocabulary/> "
|
||||
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
|
||||
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
|
||||
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
|
||||
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
|
||||
"WHERE {"
|
||||
" ?species a :Species ;"
|
||||
" :character ?character ;"
|
||||
" :averageLifespan ?lifespan ."
|
||||
" FILTER(xsd:integer(?lifespan))"
|
||||
"} "
|
||||
"ORDER BY DESC(?maxLifespan) "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": generated_invalid_sparql,
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = MagicMock()
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
chain.invoke({chain.input_key: question})
|
||||
|
||||
assert str(e.value) == "Failed to execute the generated SPARQL query."
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == 0
|
||||
assert chain.qa_chain.invoke.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
|
||||
def test_invalid_sparql_after_all_retries(max_fix_retries: int) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "What is Luke Skywalker's home planet?"
|
||||
generated_invalid_sparql = "```sparql SELECT * {?s ?p ?o} LIMIT 1```"
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": generated_invalid_sparql,
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": generated_invalid_sparql,
|
||||
"error_message": "pyparsing.exceptions.ParseException: "
|
||||
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
|
||||
"found '`' (at char 0), (line:1, col:1)",
|
||||
"generated_sparql": generated_invalid_sparql,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
chain.invoke({chain.input_key: question})
|
||||
|
||||
assert str(e.value) == "The generated SPARQL query is invalid."
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == max_fix_retries
|
||||
assert chain.qa_chain.invoke.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize(
|
||||
"max_fix_retries,number_of_invalid_responses",
|
||||
[(1, 0), (2, 0), (2, 1), (10, 6)],
|
||||
)
|
||||
def test_valid_sparql_after_some_retries(
|
||||
max_fix_retries: int, number_of_invalid_responses: int
|
||||
) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
question = "What is Luke Skywalker's home planet?"
|
||||
answer = "Tatooine"
|
||||
generated_invalid_sparql = "```sparql SELECT * {?s ?p ?o} LIMIT 1```"
|
||||
generated_valid_sparql_query = "SELECT * {?s ?p ?o} LIMIT 1"
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
Mock(ChatOpenAI),
|
||||
graph=graph,
|
||||
max_fix_retries=max_fix_retries,
|
||||
)
|
||||
chain.sparql_generation_chain = Mock(LLMChain)
|
||||
chain.sparql_fix_chain = Mock(LLMChain)
|
||||
chain.qa_chain = Mock(LLMChain)
|
||||
|
||||
chain.sparql_generation_chain.output_key = "text"
|
||||
chain.sparql_generation_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": generated_invalid_sparql,
|
||||
"prompt": question,
|
||||
"schema": "",
|
||||
}
|
||||
)
|
||||
chain.sparql_fix_chain.output_key = "text"
|
||||
chain.sparql_fix_chain.invoke = Mock()
|
||||
chain.sparql_fix_chain.invoke.side_effect = [
|
||||
{
|
||||
"text": generated_invalid_sparql,
|
||||
"error_message": "pyparsing.exceptions.ParseException: "
|
||||
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
|
||||
"found '`' (at char 0), (line:1, col:1)",
|
||||
"generated_sparql": generated_invalid_sparql,
|
||||
"schema": "",
|
||||
}
|
||||
] * number_of_invalid_responses + [
|
||||
{
|
||||
"text": generated_valid_sparql_query,
|
||||
"error_message": "pyparsing.exceptions.ParseException: "
|
||||
"Expected {SelectQuery | ConstructQuery | DescribeQuery | AskQuery}, "
|
||||
"found '`' (at char 0), (line:1, col:1)",
|
||||
"generated_sparql": generated_invalid_sparql,
|
||||
"schema": "",
|
||||
}
|
||||
]
|
||||
chain.qa_chain.output_key = "text"
|
||||
chain.qa_chain.invoke = MagicMock(
|
||||
return_value={
|
||||
"text": answer,
|
||||
"prompt": question,
|
||||
"context": [],
|
||||
}
|
||||
)
|
||||
|
||||
result = chain.invoke({chain.input_key: question})
|
||||
|
||||
assert chain.sparql_generation_chain.invoke.call_count == 1
|
||||
assert chain.sparql_fix_chain.invoke.call_count == number_of_invalid_responses + 1
|
||||
assert chain.qa_chain.invoke.call_count == 1
|
||||
assert result == {chain.output_key: answer, chain.input_key: question}
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "rdflib")
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,question",
|
||||
[
|
||||
("gpt-3.5-turbo-1106", "What is the average height of the Wookiees?"),
|
||||
("gpt-3.5-turbo-1106", "What is the climate on Tatooine?"),
|
||||
("gpt-3.5-turbo-1106", "What is Luke Skywalker's home planet?"),
|
||||
("gpt-4-1106-preview", "What is the average height of the Wookiees?"),
|
||||
("gpt-4-1106-preview", "What is the climate on Tatooine?"),
|
||||
("gpt-4-1106-preview", "What is Luke Skywalker's home planet?"),
|
||||
],
|
||||
)
|
||||
def test_chain(model_name: str, question: str) -> None:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
graph = OntotextGraphDBGraph(
|
||||
query_endpoint="http://localhost:7200/repositories/starwars",
|
||||
query_ontology="CONSTRUCT {?s ?p ?o} "
|
||||
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
|
||||
)
|
||||
chain = OntotextGraphDBQAChain.from_llm(
|
||||
ChatOpenAI(temperature=0, model_name=model_name),
|
||||
graph=graph,
|
||||
verbose=True, # type: ignore[call-arg]
|
||||
)
|
||||
try:
|
||||
chain.invoke({chain.input_key: question})
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Integration test for self ask with search."""
|
||||
|
||||
from langchain_community.docstore import Wikipedia
|
||||
from langchain_community.llms.openai import OpenAI
|
||||
|
||||
from langchain.agents.react.base import ReActChain
|
||||
|
||||
|
||||
def test_react() -> None:
|
||||
"""Test functionality on a prompt."""
|
||||
llm = OpenAI(temperature=0, model_name="gpt-3.5-turbo-instruct") # type: ignore[call-arg]
|
||||
react = ReActChain(llm=llm, docstore=Wikipedia())
|
||||
question = (
|
||||
"Author David Chanoff has collaborated with a U.S. Navy admiral "
|
||||
"who served as the ambassador to the United Kingdom under "
|
||||
"which President?"
|
||||
)
|
||||
output = react.run(question)
|
||||
assert output == "Bill Clinton"
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Test RetrievalQA functionality."""
|
||||
from pathlib import Path
|
||||
|
||||
from langchain_community.document_loaders import TextLoader
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_text_splitters.character import CharacterTextSplitter
|
||||
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.chains.loading import load_chain
|
||||
|
||||
|
||||
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
|
||||
"""Test saving and loading."""
|
||||
loader = TextLoader("docs/extras/modules/state_of_the_union.txt")
|
||||
documents = loader.load()
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
texts = text_splitter.split_documents(documents)
|
||||
embeddings = OpenAIEmbeddings()
|
||||
docsearch = FAISS.from_documents(texts, embeddings)
|
||||
qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())
|
||||
qa.run("What did the president say about Ketanji Brown Jackson?")
|
||||
|
||||
file_path = tmp_path / "RetrievalQA_chain.yaml"
|
||||
qa.save(file_path=file_path)
|
||||
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
|
||||
|
||||
assert qa_loaded == qa
|
||||
@@ -1,39 +0,0 @@
|
||||
"""Test RetrievalQA functionality."""
|
||||
from langchain_community.document_loaders import DirectoryLoader
|
||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain_community.llms import OpenAI
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_text_splitters.character import CharacterTextSplitter
|
||||
|
||||
from langchain.chains import RetrievalQAWithSourcesChain
|
||||
from langchain.chains.loading import load_chain
|
||||
|
||||
|
||||
def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None:
|
||||
"""Test saving and loading."""
|
||||
loader = DirectoryLoader("docs/extras/modules/", glob="*.txt")
|
||||
documents = loader.load()
|
||||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
texts = text_splitter.split_documents(documents)
|
||||
embeddings = OpenAIEmbeddings()
|
||||
docsearch = FAISS.from_documents(texts, embeddings)
|
||||
qa = RetrievalQAWithSourcesChain.from_llm(
|
||||
llm=OpenAI(), retriever=docsearch.as_retriever()
|
||||
)
|
||||
result = qa("What did the president say about Ketanji Brown Jackson?")
|
||||
assert "question" in result.keys()
|
||||
assert "answer" in result.keys()
|
||||
assert "sources" in result.keys()
|
||||
file_path = str(tmp_path) + "/RetrievalQAWithSourcesChain.yaml"
|
||||
qa.save(file_path=file_path)
|
||||
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
|
||||
|
||||
assert qa_loaded == qa
|
||||
|
||||
qa2 = RetrievalQAWithSourcesChain.from_chain_type(
|
||||
llm=OpenAI(), retriever=docsearch.as_retriever(), chain_type="stuff"
|
||||
)
|
||||
result2 = qa2("What did the president say about Ketanji Brown Jackson?")
|
||||
assert "question" in result2.keys()
|
||||
assert "answer" in result2.keys()
|
||||
assert "sources" in result2.keys()
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Integration test for self ask with search."""
|
||||
from langchain_community.llms.openai import OpenAI
|
||||
from langchain_community.utilities.searchapi import SearchApiAPIWrapper
|
||||
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
|
||||
|
||||
def test_self_ask_with_search() -> None:
|
||||
"""Test functionality on a prompt."""
|
||||
question = "What is the hometown of the reigning men's U.S. Open champion?"
|
||||
chain = SelfAskWithSearchChain(
|
||||
llm=OpenAI(temperature=0),
|
||||
search_chain=SearchApiAPIWrapper(),
|
||||
input_key="q",
|
||||
output_key="a",
|
||||
)
|
||||
answer = chain.run(question)
|
||||
final_answer = answer.split("\n")[-1]
|
||||
assert final_answer == "Belgrade, Serbia"
|
||||
@@ -1,202 +0,0 @@
|
||||
import os
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
import pytest
|
||||
from langchain_community.chat_message_histories.astradb import (
|
||||
AstraDBChatMessageHistory,
|
||||
)
|
||||
from langchain_community.utilities.astradb import SetupMode
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
|
||||
def _has_env_vars() -> bool:
|
||||
return all(
|
||||
[
|
||||
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
|
||||
"ASTRA_DB_API_ENDPOINT" in os.environ,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def history1() -> Iterable[AstraDBChatMessageHistory]:
|
||||
history1 = AstraDBChatMessageHistory(
|
||||
session_id="session-test-1",
|
||||
collection_name="langchain_cmh_test",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
yield history1
|
||||
history1.collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def history2() -> Iterable[AstraDBChatMessageHistory]:
|
||||
history2 = AstraDBChatMessageHistory(
|
||||
session_id="session-test-2",
|
||||
collection_name="langchain_cmh_test",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
)
|
||||
yield history2
|
||||
history2.collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
|
||||
history1 = AstraDBChatMessageHistory(
|
||||
session_id="async-session-test-1",
|
||||
collection_name="langchain_cmh_test",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
setup_mode=SetupMode.ASYNC,
|
||||
)
|
||||
yield history1
|
||||
await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
|
||||
history2 = AstraDBChatMessageHistory(
|
||||
session_id="async-session-test-2",
|
||||
collection_name="langchain_cmh_test",
|
||||
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
|
||||
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
|
||||
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
|
||||
setup_mode=SetupMode.ASYNC,
|
||||
)
|
||||
yield history2
|
||||
await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
|
||||
"""Test the memory with a message store."""
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz",
|
||||
chat_memory=history1,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_messages(
|
||||
[
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
)
|
||||
|
||||
messages = memory.chat_memory.messages
|
||||
expected = [
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
# clear the store
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
async def test_memory_with_message_store_async(
|
||||
async_history1: AstraDBChatMessageHistory,
|
||||
) -> None:
|
||||
"""Test the memory with a message store."""
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz",
|
||||
chat_memory=async_history1,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
assert await memory.chat_memory.aget_messages() == []
|
||||
|
||||
# add some messages
|
||||
await memory.chat_memory.aadd_messages(
|
||||
[
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
)
|
||||
|
||||
messages = await memory.chat_memory.aget_messages()
|
||||
expected = [
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
# clear the store
|
||||
await memory.chat_memory.aclear()
|
||||
|
||||
assert await memory.chat_memory.aget_messages() == []
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
def test_memory_separate_session_ids(
|
||||
history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
|
||||
) -> None:
|
||||
"""Test that separate session IDs do not share entries."""
|
||||
memory1 = ConversationBufferMemory(
|
||||
memory_key="mk1",
|
||||
chat_memory=history1,
|
||||
return_messages=True,
|
||||
)
|
||||
memory2 = ConversationBufferMemory(
|
||||
memory_key="mk2",
|
||||
chat_memory=history2,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])
|
||||
|
||||
assert memory2.chat_memory.messages == []
|
||||
|
||||
memory2.chat_memory.clear()
|
||||
|
||||
assert memory1.chat_memory.messages != []
|
||||
|
||||
memory1.chat_memory.clear()
|
||||
|
||||
assert memory1.chat_memory.messages == []
|
||||
|
||||
|
||||
@pytest.mark.requires("astrapy")
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
async def test_memory_separate_session_ids_async(
|
||||
async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
|
||||
) -> None:
|
||||
"""Test that separate session IDs do not share entries."""
|
||||
memory1 = ConversationBufferMemory(
|
||||
memory_key="mk1",
|
||||
chat_memory=async_history1,
|
||||
return_messages=True,
|
||||
)
|
||||
memory2 = ConversationBufferMemory(
|
||||
memory_key="mk2",
|
||||
chat_memory=async_history2,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])
|
||||
|
||||
assert await memory2.chat_memory.aget_messages() == []
|
||||
|
||||
await memory2.chat_memory.aclear()
|
||||
|
||||
assert await memory1.chat_memory.aget_messages() != []
|
||||
|
||||
await memory1.chat_memory.aclear()
|
||||
|
||||
assert await memory1.chat_memory.aget_messages() == []
|
||||
@@ -1,116 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from langchain_community.chat_message_histories.cassandra import (
|
||||
CassandraChatMessageHistory,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
|
||||
def _chat_message_history(
|
||||
session_id: str = "test-session",
|
||||
drop: bool = True,
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> CassandraChatMessageHistory:
|
||||
from cassandra.cluster import Cluster
|
||||
|
||||
keyspace = "cmh_test_keyspace"
|
||||
table_name = "cmh_test_table"
|
||||
# get db connection
|
||||
if "CASSANDRA_CONTACT_POINTS" in os.environ:
|
||||
contact_points = os.environ["CONTACT_POINTS"].split(",")
|
||||
cluster = Cluster(contact_points)
|
||||
else:
|
||||
cluster = Cluster()
|
||||
#
|
||||
session = cluster.connect()
|
||||
# ensure keyspace exists
|
||||
session.execute(
|
||||
(
|
||||
f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
|
||||
f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
|
||||
)
|
||||
)
|
||||
# drop table if required
|
||||
if drop:
|
||||
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}")
|
||||
#
|
||||
return CassandraChatMessageHistory(
|
||||
session_id=session_id,
|
||||
session=session,
|
||||
keyspace=keyspace,
|
||||
table_name=table_name,
|
||||
**({} if ttl_seconds is None else {"ttl_seconds": ttl_seconds}),
|
||||
)
|
||||
|
||||
|
||||
def test_memory_with_message_store() -> None:
|
||||
"""Test the memory with a message store."""
|
||||
# setup cassandra as a message store
|
||||
message_history = _chat_message_history()
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz",
|
||||
chat_memory=message_history,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
messages = memory.chat_memory.messages
|
||||
expected = [
|
||||
AIMessage(content="This is me, the AI"),
|
||||
HumanMessage(content="This is me, the human"),
|
||||
]
|
||||
assert messages == expected
|
||||
|
||||
# clear the store
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
||||
|
||||
|
||||
def test_memory_separate_session_ids() -> None:
|
||||
"""Test that separate session IDs do not share entries."""
|
||||
message_history1 = _chat_message_history(session_id="test-session1")
|
||||
memory1 = ConversationBufferMemory(
|
||||
memory_key="mk1",
|
||||
chat_memory=message_history1,
|
||||
return_messages=True,
|
||||
)
|
||||
message_history2 = _chat_message_history(session_id="test-session2")
|
||||
memory2 = ConversationBufferMemory(
|
||||
memory_key="mk2",
|
||||
chat_memory=message_history2,
|
||||
return_messages=True,
|
||||
)
|
||||
|
||||
memory1.chat_memory.add_ai_message("Just saying.")
|
||||
|
||||
assert memory2.chat_memory.messages == []
|
||||
|
||||
memory1.chat_memory.clear()
|
||||
memory2.chat_memory.clear()
|
||||
|
||||
|
||||
def test_memory_ttl() -> None:
|
||||
"""Test time-to-live feature of the memory."""
|
||||
message_history = _chat_message_history(ttl_seconds=5)
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz",
|
||||
chat_memory=message_history,
|
||||
return_messages=True,
|
||||
)
|
||||
#
|
||||
assert memory.chat_memory.messages == []
|
||||
memory.chat_memory.add_ai_message("Nothing special here.")
|
||||
time.sleep(2)
|
||||
assert memory.chat_memory.messages != []
|
||||
time.sleep(5)
|
||||
assert memory.chat_memory.messages == []
|
||||
@@ -1,45 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from langchain_community.chat_message_histories import CosmosDBChatMessageHistory
|
||||
from langchain_core.messages import message_to_dict
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
# Replace these with your Azure Cosmos DB endpoint and key
|
||||
endpoint = os.environ.get("COSMOS_DB_ENDPOINT", "")
|
||||
credential = os.environ.get("COSMOS_DB_KEY", "")
|
||||
|
||||
|
||||
def test_memory_with_message_store() -> None:
|
||||
"""Test the memory with a message store."""
|
||||
# setup Azure Cosmos DB as a message store
|
||||
message_history = CosmosDBChatMessageHistory(
|
||||
cosmos_endpoint=endpoint,
|
||||
cosmos_database="chat_history",
|
||||
cosmos_container="messages",
|
||||
credential=credential,
|
||||
session_id="my-test-session",
|
||||
user_id="my-test-user",
|
||||
ttl=10,
|
||||
)
|
||||
message_history.prepare_cosmos()
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# remove the record from Azure Cosmos DB, so the next test run won't pick it up
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
||||
@@ -1,91 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Generator, Union
|
||||
|
||||
import pytest
|
||||
from langchain_community.chat_message_histories import ElasticsearchChatMessageHistory
|
||||
from langchain_core.messages import message_to_dict
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
"""
|
||||
cd tests/integration_tests/memory/docker-compose
|
||||
docker-compose -f elasticsearch.yml up
|
||||
|
||||
By default runs against local docker instance of Elasticsearch.
|
||||
To run against Elastic Cloud, set the following environment variables:
|
||||
- ES_CLOUD_ID
|
||||
- ES_USERNAME
|
||||
- ES_PASSWORD
|
||||
"""
|
||||
|
||||
|
||||
class TestElasticsearch:
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def elasticsearch_connection(self) -> Union[dict, Generator[dict, None, None]]: # type: ignore[return]
|
||||
# Run this integration test against Elasticsearch on localhost,
|
||||
# or an Elastic Cloud instance
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
es_url = os.environ.get("ES_URL", "http://localhost:9200")
|
||||
es_cloud_id = os.environ.get("ES_CLOUD_ID")
|
||||
es_username = os.environ.get("ES_USERNAME", "elastic")
|
||||
es_password = os.environ.get("ES_PASSWORD", "changeme")
|
||||
|
||||
if es_cloud_id:
|
||||
es = Elasticsearch(
|
||||
cloud_id=es_cloud_id,
|
||||
basic_auth=(es_username, es_password),
|
||||
)
|
||||
yield {
|
||||
"es_cloud_id": es_cloud_id,
|
||||
"es_user": es_username,
|
||||
"es_password": es_password,
|
||||
}
|
||||
|
||||
else:
|
||||
# Running this integration test with local docker instance
|
||||
es = Elasticsearch(hosts=es_url)
|
||||
yield {"es_url": es_url}
|
||||
|
||||
# Clear all indexes
|
||||
index_names = es.indices.get(index="_all").keys()
|
||||
for index_name in index_names:
|
||||
if index_name.startswith("test_"):
|
||||
es.indices.delete(index=index_name)
|
||||
es.indices.refresh(index="_all")
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_name(self) -> str:
|
||||
"""Return the index name."""
|
||||
return f"test_{uuid.uuid4().hex}"
|
||||
|
||||
def test_memory_with_message_store(
|
||||
self, elasticsearch_connection: dict, index_name: str
|
||||
) -> None:
|
||||
"""Test the memory with a message store."""
|
||||
# setup Elasticsearch as a message store
|
||||
message_history = ElasticsearchChatMessageHistory(
|
||||
**elasticsearch_connection, index=index_name, session_id="test-session"
|
||||
)
|
||||
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# remove the record from Elasticsearch, so the next test run won't pick it up
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
||||
@@ -1,44 +0,0 @@
import json

from langchain_community.chat_message_histories import FirestoreChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""

    message_history = FirestoreChatMessageHistory(
        collection_name="chat_history",
        session_id="my-test-session",
        user_id="my-test-user",
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store
    # and check if the messages are there as expected
    message_history = FirestoreChatMessageHistory(
        collection_name="chat_history",
        session_id="my-test-session",
        user_id="my-test-user",
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from Firestore, so the next test run won't pick it up
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []
@@ -1,71 +0,0 @@
"""Test Momento chat message history functionality.

To run tests, set the environment variable MOMENTO_AUTH_TOKEN to a valid
Momento auth token. This can be obtained by signing up for a free
Momento account at https://gomomento.com/.
"""
import json
import uuid
from datetime import timedelta
from typing import Iterator

import pytest
from langchain_community.chat_message_histories import MomentoChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory


def random_string() -> str:
    return str(uuid.uuid4())


@pytest.fixture(scope="function")
def message_history() -> Iterator[MomentoChatMessageHistory]:
    from momento import CacheClient, Configurations, CredentialProvider

    cache_name = f"langchain-test-cache-{random_string()}"
    client = CacheClient(
        Configurations.Laptop.v1(),
        CredentialProvider.from_environment_variable("MOMENTO_API_KEY"),
        default_ttl=timedelta(seconds=30),
    )
    try:
        chat_message_history = MomentoChatMessageHistory(
            session_id="my-test-session",
            cache_client=client,
            cache_name=cache_name,
        )
        yield chat_message_history
    finally:
        client.delete_cache(cache_name)


def test_memory_empty_on_new_session(
    message_history: MomentoChatMessageHistory,
) -> None:
    memory = ConversationBufferMemory(
        memory_key="foo", chat_memory=message_history, return_messages=True
    )
    assert memory.chat_memory.messages == []


def test_memory_with_message_store(message_history: MomentoChatMessageHistory) -> None:
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # Add some messages to the memory store
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # Verify that the messages are in the store
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # Verify clearing the store
    memory.chat_memory.clear()
    assert memory.chat_memory.messages == []
@@ -1,37 +0,0 @@
import json
import os

from langchain_community.chat_message_histories import MongoDBChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory

# Replace this with your MongoDB connection string
connection_string = os.environ.get("MONGODB_CONNECTION_STRING", "")


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup MongoDB as a message store
    message_history = MongoDBChatMessageHistory(
        connection_string=connection_string, session_id="test-session"
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into a json
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from MongoDB, so the next test run won't pick it up
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []
@@ -1,31 +0,0 @@
import json

from langchain_community.chat_message_histories import Neo4jChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup Neo4j as a message store
    message_history = Neo4jChatMessageHistory(session_id="test-session")
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into a json
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from Neo4j, so the next test run won't pick it up
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []
@@ -1,31 +0,0 @@
import json

from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup Redis as a message store
    message_history = RedisChatMessageHistory(
        url="redis://localhost:6379/0", ttl=10, session_id="my-test-session"
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into a json
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from Redis, so the next test run won't pick it up
    memory.chat_memory.clear()
@@ -1,64 +0,0 @@
"""Tests RocksetChatMessageHistory by creating a collection
for message history, adding to it, and clearing it.

To run these tests, make sure you have the ROCKSET_API_KEY
and ROCKSET_REGION environment variables set.
"""

import json
import os

from langchain_community.chat_message_histories import RocksetChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory

collection_name = "langchain_demo"
session_id = "MySession"


class TestRockset:
    memory: RocksetChatMessageHistory

    @classmethod
    def setup_class(cls) -> None:
        from rockset import DevRegions, Regions, RocksetClient

        assert os.environ.get("ROCKSET_API_KEY") is not None
        assert os.environ.get("ROCKSET_REGION") is not None

        api_key = os.environ.get("ROCKSET_API_KEY")
        region = os.environ.get("ROCKSET_REGION")
        if region == "use1a1":
            host = Regions.use1a1
        elif region == "usw2a1" or not region:
            host = Regions.usw2a1
        elif region == "euc1a1":
            host = Regions.euc1a1
        elif region == "dev":
            host = DevRegions.usw2a1
        else:
            host = region

        client = RocksetClient(host, api_key)
        cls.memory = RocksetChatMessageHistory(
            session_id, client, collection_name, sync=True
        )

    def test_memory_with_message_store(self) -> None:
        memory = ConversationBufferMemory(
            memory_key="messages", chat_memory=self.memory, return_messages=True
        )

        memory.chat_memory.add_ai_message("This is me, the AI")
        memory.chat_memory.add_user_message("This is me, the human")

        messages = memory.chat_memory.messages
        messages_json = json.dumps([message_to_dict(msg) for msg in messages])

        assert "This is me, the AI" in messages_json
        assert "This is me, the human" in messages_json

        memory.chat_memory.clear()

        assert memory.chat_memory.messages == []
@@ -1,37 +0,0 @@
import json

from langchain_community.chat_message_histories import SingleStoreDBChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory

# Replace this with your SingleStoreDB connection string
TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup SingleStoreDB as a message store
    message_history = SingleStoreDBChatMessageHistory(
        session_id="test-session",
        host=TEST_SINGLESTOREDB_URL,
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into a json
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from SingleStoreDB, so the next test run won't pick it up
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []
@@ -1,38 +0,0 @@
import json

import pytest
from langchain_community.chat_message_histories.upstash_redis import (
    UpstashRedisChatMessageHistory,
)
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory

URL = "<UPSTASH_REDIS_REST_URL>"
TOKEN = "<UPSTASH_REDIS_REST_TOKEN>"


@pytest.mark.requires("upstash_redis")
def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup Upstash Redis as a message store
    message_history = UpstashRedisChatMessageHistory(
        url=URL, token=TOKEN, ttl=10, session_id="my-test-session"
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into a json
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from Redis, so the next test run won't pick it up
    memory.chat_memory.clear()
@@ -1,42 +0,0 @@
"""Test Xata chat memory store functionality.

Before running this test, please create a Xata database.
"""

import json
import os

from langchain_community.chat_message_histories import XataChatMessageHistory
from langchain_core.messages import message_to_dict

from langchain.memory import ConversationBufferMemory


class TestXata:
    @classmethod
    def setup_class(cls) -> None:
        assert os.getenv("XATA_API_KEY"), "XATA_API_KEY environment variable is not set"
        assert os.getenv("XATA_DB_URL"), "XATA_DB_URL environment variable is not set"

    def test_xata_chat_memory(self) -> None:
        message_history = XataChatMessageHistory(
            api_key=os.getenv("XATA_API_KEY", ""),
            db_url=os.getenv("XATA_DB_URL", ""),
            session_id="integration-test-session",
        )
        memory = ConversationBufferMemory(
            memory_key="baz", chat_memory=message_history, return_messages=True
        )
        # add some messages
        memory.chat_memory.add_ai_message("This is me, the AI")
        memory.chat_memory.add_user_message("This is me, the human")

        # get the message history from the memory store and turn it into a json
        messages = memory.chat_memory.messages
        messages_json = json.dumps([message_to_dict(msg) for msg in messages])

        assert "This is me, the AI" in messages_json
        assert "This is me, the human" in messages_json

        # remove the record from Xata, so the next test run won't pick it up
        memory.chat_memory.clear()
@@ -1,72 +0,0 @@
"""Test functionality related to ngram overlap based selector."""

import pytest
from langchain_community.example_selectors import (
    NGramOverlapExampleSelector,
    ngram_overlap_score,
)
from langchain_core.prompts import PromptTemplate

EXAMPLES = [
    {"input": "See Spot run.", "output": "foo1"},
    {"input": "My dog barks.", "output": "foo2"},
    {"input": "Spot can run.", "output": "foo3"},
]


@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Get ngram overlap based selector to use in tests."""
    prompts = PromptTemplate(
        input_variables=["input", "output"], template="Input: {input}\nOutput: {output}"
    )
    selector = NGramOverlapExampleSelector(
        examples=EXAMPLES,
        example_prompt=prompts,
    )
    return selector


def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Test NGramOverlapExampleSelector can select examples."""
    sentence = "Spot can run."
    output = selector.select_examples({"input": sentence})
    assert output == [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]


def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Test NGramOverlapExampleSelector can add an example."""
    new_example = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(new_example)
    sentence = "Spot can run."
    output = selector.select_examples({"input": sentence})
    assert output == [EXAMPLES[2], EXAMPLES[0]] + [new_example] + [EXAMPLES[1]]


def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
    """Tests NGramOverlapExampleSelector threshold set to 0.0."""
    selector.threshold = 0.0
    sentence = "Spot can run."
    output = selector.select_examples({"input": sentence})
    assert output == [EXAMPLES[2], EXAMPLES[0]]


def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Tests NGramOverlapExampleSelector threshold greater than 1.0."""
    selector.threshold = 1.0 + 1e-9
    sentence = "Spot can run."
    output = selector.select_examples({"input": sentence})
    assert output == []


def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values."""
    selector.threshold = 1.0 + 1e-9
    none = ngram_overlap_score(["Spot can run."], ["My dog barks."])
    some = ngram_overlap_score(["Spot can run."], ["See Spot run."])
    complete = ngram_overlap_score(["Spot can run."], ["Spot can run."])

    check = [abs(none - 0.0) < 1e-9, 0.0 < some < 1.0, abs(complete - 1.0) < 1e-9]
    assert check == [True, True, True]
@@ -1,29 +0,0 @@
"""Integration test for compression pipelines."""
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters.character import CharacterTextSplitter

from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)


def test_document_compressor_pipeline() -> None:
    embeddings = OpenAIEmbeddings()
    splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ")
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
    pipeline_filter = DocumentCompressorPipeline(
        transformers=[splitter, redundant_filter, relevant_filter]
    )
    texts = [
        "This sentence is about cows",
        "This sentence was about cows",
        "foo bar baz",
    ]
    docs = [Document(page_content=". ".join(texts))]
    actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals")
    assert len(actual) == 1
    assert actual[0].page_content in texts[:2]
@@ -1,45 +0,0 @@
"""Integration test for LLMChainExtractor."""
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document

from langchain.retrievers.document_compressors import LLMChainExtractor


def test_llm_construction_with_kwargs() -> None:
    llm_chain_kwargs = {"verbose": True}
    compressor = LLMChainExtractor.from_llm(
        ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
    )
    assert compressor.llm_chain.verbose is True


def test_llm_chain_extractor() -> None:
    texts = [
        "The Roman Empire followed the Roman Republic.",
        "I love chocolate chip cookies—my mother makes great cookies.",
        "The first Roman emperor was Caesar Augustus.",
        "Don't you just love Caesar salad?",
        "The Roman Empire collapsed in 476 AD after the fall of Rome.",
        "Let's go to Olive Garden!",
    ]
    doc = Document(page_content=" ".join(texts))
    compressor = LLMChainExtractor.from_llm(ChatOpenAI())
    actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")[
        0
    ].page_content
    expected_returned = [0, 2, 4]
    expected_not_returned = [1, 3, 5]
    assert all([texts[i] in actual for i in expected_returned])
    assert all([texts[i] not in actual for i in expected_not_returned])


def test_llm_chain_extractor_empty() -> None:
    texts = [
        "I love chocolate chip cookies—my mother makes great cookies.",
        "Don't you just love Caesar salad?",
        "Let's go to Olive Garden!",
    ]
    doc = Document(page_content=" ".join(texts))
    compressor = LLMChainExtractor.from_llm(ChatOpenAI())
    actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")
    assert len(actual) == 0
@@ -1,18 +0,0 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document

from langchain.retrievers.document_compressors import LLMChainFilter


def test_llm_chain_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    docs = [Document(page_content=t) for t in texts]
    relevant_filter = LLMChainFilter.from_llm(llm=ChatOpenAI())
    actual = relevant_filter.compress_documents(docs, "Things I said related to food")
    assert len(actual) == 2
    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2
@@ -1,43 +0,0 @@
"""Integration test for embedding-based relevant doc filtering."""
import numpy as np
from langchain_community.document_transformers.embeddings_redundant_filter import (
    _DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document

from langchain.retrievers.document_compressors import EmbeddingsFilter


def test_embeddings_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    actual = relevant_filter.compress_documents(docs, "What did I say about food?")
    assert len(actual) == 2
    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2


def test_embeddings_filter_with_state() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    query = "What did I say about food?"
    embeddings = OpenAIEmbeddings()
    embedded_query = embeddings.embed_query(query)
    state = {"embedded_doc": np.zeros(len(embedded_query))}
    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
    docs[-1].state = {"embedded_doc": embedded_query}
    relevant_filter = EmbeddingsFilter(  # type: ignore[call-arg]
        embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True
    )
    actual = relevant_filter.compress_documents(docs, query)
    assert len(actual) == 1
    assert texts[-1] == actual[0].page_content
@@ -1,26 +0,0 @@
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS

from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter


def test_contextual_compression_retriever_get_relevant_docs() -> None:
    """Test get_relevant_docs."""
    texts = [
        "This is a document about the Boston Celtics",
        "The Boston Celtics won the game by 20 points",
        "I simply love going to the movies",
    ]
    embeddings = OpenAIEmbeddings()
    base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
        search_kwargs={"k": len(texts)}
    )
    retriever = ContextualCompressionRetriever(
        base_compressor=base_compressor, base_retriever=base_retriever
    )

    actual = retriever.invoke("Tell me about the Celtics")
    assert len(actual) == 2
    assert texts[-1] not in [d.page_content for d in actual]
@@ -1,33 +0,0 @@
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma

from langchain.retrievers.merger_retriever import MergerRetriever


def test_merger_retriever_get_relevant_docs() -> None:
    """Test get_relevant_docs."""
    texts_group_a = [
        "This is a document about the Boston Celtics",
        "Fly me to the moon is one of my favourite songs."
        "I simply love going to the movies",
    ]
    texts_group_b = [
        "This is a document about the Phoenix Suns",
        "The Boston Celtics won the game by 20 points",
        "Real stupidity beats artificial intelligence every time. TP",
    ]
    embeddings = OpenAIEmbeddings()
    retriever_a = Chroma.from_texts(texts_group_a, embedding=embeddings).as_retriever(
        search_kwargs={"k": 1}
    )
    retriever_b = Chroma.from_texts(texts_group_b, embedding=embeddings).as_retriever(
        search_kwargs={"k": 1}
    )

    # The Lord of the Retrievers.
    lotr = MergerRetriever(retrievers=[retriever_a, retriever_b])

    actual = lotr.invoke("Tell me about the Celtics")
    assert len(actual) == 2
    assert texts_group_a[0] in [d.page_content for d in actual]
    assert texts_group_b[1] in [d.page_content for d in actual]
@@ -1,503 +0,0 @@
from typing import Iterator, List, Optional
from uuid import uuid4

import pytest
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms.openai import OpenAI
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langsmith import Client as Client
from langsmith.evaluation import run_evaluator
from langsmith.schemas import DataType, Example, Run

from langchain.chains.llm import LLMChain
from langchain.evaluation import EvaluatorType
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.smith.evaluation import InputFormatError
from langchain.smith.evaluation.runner_utils import arun_on_dataset


def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
    # Assert that all runs completed, all feedback completed, and that the
    # chain or llm passes for the feedback provided.
    runs = list(client.list_runs(project_name=_project_name, execution_order=1))
    if not runs:
        # Queue delays. We are mainly just smoke checking rn.
        return
    feedback = list(client.list_feedback(run_ids=[run.id for run in runs]))
    if not feedback:
        return
    assert all([bool(f.score) for f in feedback])


@run_evaluator
def not_empty(run: Run, example: Optional[Example] = None) -> dict:
    return {
        "score": run.outputs and next(iter(run.outputs.values())),
        "key": "not_empty",
    }


@pytest.fixture
def eval_project_name() -> str:
    return f"lcp integration tests - {str(uuid4())[-8:]}"


@pytest.fixture(scope="module")
def client() -> Client:
    return Client()


@pytest.fixture(
    scope="module",
)
def kv_dataset_name() -> Iterator[str]:
    import pandas as pd

    client = Client()
    df = pd.DataFrame(
        {
            "some_input": [
                "What's the capital of California?",
                "What's the capital of Nevada?",
                "What's the capital of Oregon?",
                "What's the capital of Washington?",
            ],
            "other_input": [
                "a",
                "b",
                "c",
                "d",
            ],
            "some_output": ["Sacramento", "Carson City", "Salem", "Olympia"],
            "other_output": ["e", "f", "g", "h"],
        }
    )

    uid = str(uuid4())[-8:]
    _dataset_name = f"lcp kv dataset integration tests - {uid}"
    client.upload_dataframe(
        df,
        name=_dataset_name,
        input_keys=["some_input", "other_input"],
        output_keys=["some_output", "other_output"],
        description="Integration test dataset",
    )
    yield _dataset_name


def test_chat_model(
    kv_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = ChatOpenAI(temperature=0)
    eval_config = RunEvalConfig(
        evaluators=[EvaluatorType.QA], custom_evaluators=[not_empty]
    )
    with pytest.raises(ValueError, match="Must specify reference_key"):
        run_on_dataset(
            dataset_name=kv_dataset_name,
            llm_or_chain_factory=llm,
            evaluation=eval_config,
            client=client,
        )
    eval_config = RunEvalConfig(
        evaluators=[EvaluatorType.QA],
        reference_key="some_output",
    )
    with pytest.raises(
        InputFormatError, match="Example inputs do not match language model"
    ):
        run_on_dataset(
            dataset_name=kv_dataset_name,
            llm_or_chain_factory=llm,
            evaluation=eval_config,
            client=client,
        )

    def input_mapper(d: dict) -> List[BaseMessage]:
        return [HumanMessage(content=d["some_input"])]

    run_on_dataset(
        client=client,
        dataset_name=kv_dataset_name,
        llm_or_chain_factory=input_mapper | llm,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_llm(kv_dataset_name: str, eval_project_name: str, client: Client) -> None:
    llm = OpenAI(temperature=0)
    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA])
    with pytest.raises(ValueError, match="Must specify reference_key"):
        run_on_dataset(
            dataset_name=kv_dataset_name,
            llm_or_chain_factory=llm,
            evaluation=eval_config,
            client=client,
        )
    eval_config = RunEvalConfig(
        evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA],
        reference_key="some_output",
    )
    with pytest.raises(InputFormatError, match="Example inputs"):
        run_on_dataset(
            dataset_name=kv_dataset_name,
            llm_or_chain_factory=llm,
            evaluation=eval_config,
            client=client,
        )

    def input_mapper(d: dict) -> str:
        return d["some_input"]

    run_on_dataset(
        client=client,
        dataset_name=kv_dataset_name,
        llm_or_chain_factory=input_mapper | llm,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_chain(kv_dataset_name: str, eval_project_name: str, client: Client) -> None:
    llm = ChatOpenAI(temperature=0)
    chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
    with pytest.raises(ValueError, match="Must specify reference_key"):
        run_on_dataset(
            dataset_name=kv_dataset_name,
            llm_or_chain_factory=lambda: chain,
            evaluation=eval_config,
            client=client,
        )
    eval_config = RunEvalConfig(
        evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA],
        reference_key="some_output",
    )
    with pytest.raises(InputFormatError, match="Example inputs"):
        run_on_dataset(
            dataset_name=kv_dataset_name,
            llm_or_chain_factory=lambda: chain,
            evaluation=eval_config,
            client=client,
        )

    eval_config = RunEvalConfig(
        custom_evaluators=[not_empty],
    )

    def right_input_mapper(d: dict) -> dict:
        return {"question": d["some_input"]}

    run_on_dataset(
        dataset_name=kv_dataset_name,
        llm_or_chain_factory=lambda: right_input_mapper | chain,
        client=client,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


### Testing Chat Datasets


@pytest.fixture(
    scope="module",
)
def chat_dataset_name() -> Iterator[str]:
    def _create_message(txt: str, role: str = "human") -> List[dict]:
        return [{"type": role, "data": {"content": txt}}]

    import pandas as pd

    client = Client()
    df = pd.DataFrame(
        {
            "input": [
                _create_message(txt)
                for txt in (
                    "What's the capital of California?",
                    "What's the capital of Nevada?",
                    "What's the capital of Oregon?",
                    "What's the capital of Washington?",
                )
            ],
            "output": [
                _create_message(txt, role="ai")[0]
                for txt in ("Sacramento", "Carson City", "Salem", "Olympia")
            ],
        }
    )

    uid = str(uuid4())[-8:]
    _dataset_name = f"lcp chat dataset integration tests - {uid}"
    ds = client.create_dataset(
        _dataset_name, description="Integration test dataset", data_type=DataType.chat
    )
    for row in df.itertuples():
        client.create_example(
            dataset_id=ds.id,
            inputs={"input": row.input},
            outputs={"output": row.output},
        )
    yield _dataset_name


def test_chat_model_on_chat_dataset(
    chat_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = ChatOpenAI(temperature=0)
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    run_on_dataset(
        dataset_name=chat_dataset_name,
        llm_or_chain_factory=llm,
        evaluation=eval_config,
        client=client,
        project_name=eval_project_name,
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_llm_on_chat_dataset(
    chat_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = OpenAI(temperature=0)
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    run_on_dataset(
        dataset_name=chat_dataset_name,
        llm_or_chain_factory=llm,
        client=client,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_chain_on_chat_dataset(chat_dataset_name: str, client: Client) -> None:
    llm = ChatOpenAI(temperature=0)
    chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
    with pytest.raises(
        ValueError, match="Cannot evaluate a chain on dataset with data_type=chat"
    ):
        run_on_dataset(
            dataset_name=chat_dataset_name,
            client=client,
            llm_or_chain_factory=lambda: chain,
            evaluation=eval_config,
        )


@pytest.fixture(
    scope="module",
)
def llm_dataset_name() -> Iterator[str]:
    import pandas as pd

    client = Client()
    df = pd.DataFrame(
        {
            "input": [
                "What's the capital of California?",
                "What's the capital of Nevada?",
                "What's the capital of Oregon?",
                "What's the capital of Washington?",
            ],
            "output": ["Sacramento", "Carson City", "Salem", "Olympia"],
        }
    )

    uid = str(uuid4())[-8:]
    _dataset_name = f"lcp llm dataset integration tests - {uid}"
    client.upload_dataframe(
        df,
        name=_dataset_name,
        input_keys=["input"],
        output_keys=["output"],
        description="Integration test dataset",
        data_type=DataType.llm,
    )
    yield _dataset_name


def test_chat_model_on_llm_dataset(
    llm_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = ChatOpenAI(temperature=0)
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    run_on_dataset(
        client=client,
        dataset_name=llm_dataset_name,
        llm_or_chain_factory=llm,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_llm_on_llm_dataset(
    llm_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = OpenAI(temperature=0)
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    run_on_dataset(
        client=client,
        dataset_name=llm_dataset_name,
        llm_or_chain_factory=llm,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_chain_on_llm_dataset(llm_dataset_name: str, client: Client) -> None:
    llm = ChatOpenAI(temperature=0)
    chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
    with pytest.raises(
        ValueError, match="Cannot evaluate a chain on dataset with data_type=llm"
    ):
        run_on_dataset(
            client=client,
            dataset_name=llm_dataset_name,
            llm_or_chain_factory=lambda: chain,
            evaluation=eval_config,
        )


@pytest.fixture(
    scope="module",
)
def kv_singleio_dataset_name() -> Iterator[str]:
    import pandas as pd

    client = Client()
    df = pd.DataFrame(
        {
            "the wackiest input": [
                "What's the capital of California?",
                "What's the capital of Nevada?",
                "What's the capital of Oregon?",
                "What's the capital of Washington?",
            ],
            "unthinkable output": ["Sacramento", "Carson City", "Salem", "Olympia"],
        }
    )

    uid = str(uuid4())[-8:]
    _dataset_name = f"lcp singleio kv dataset integration tests - {uid}"
    client.upload_dataframe(
        df,
        name=_dataset_name,
        input_keys=["the wackiest input"],
        output_keys=["unthinkable output"],
        description="Integration test dataset",
    )
    yield _dataset_name


def test_chat_model_on_kv_singleio_dataset(
    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = ChatOpenAI(temperature=0)
    eval_config = RunEvalConfig(evaluators=[EvaluatorType.QA, EvaluatorType.CRITERIA])
    run_on_dataset(
        dataset_name=kv_singleio_dataset_name,
        llm_or_chain_factory=llm,
        evaluation=eval_config,
        client=client,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_llm_on_kv_singleio_dataset(
    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = OpenAI(temperature=0)
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    run_on_dataset(
        dataset_name=kv_singleio_dataset_name,
        llm_or_chain_factory=llm,
        client=client,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


def test_chain_on_kv_singleio_dataset(
    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    llm = ChatOpenAI(temperature=0)
    chain = LLMChain.from_string(llm, "The answer to the {question} is: ")
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    run_on_dataset(
        dataset_name=kv_singleio_dataset_name,
        llm_or_chain_factory=lambda: chain,
        client=client,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


async def test_runnable_on_kv_singleio_dataset(
    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    runnable = (
        ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
        | ChatOpenAI()
    )
    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    await arun_on_dataset(
        dataset_name=kv_singleio_dataset_name,
        llm_or_chain_factory=runnable,
        client=client,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)


async def test_arb_func_on_kv_singleio_dataset(
    kv_singleio_dataset_name: str, eval_project_name: str, client: Client
) -> None:
    runnable = (
        ChatPromptTemplate.from_messages([("human", "{the wackiest input}")])
        | ChatOpenAI()
    )

    def my_func(x: dict) -> str:
        content = runnable.invoke(x).content
        if isinstance(content, str):
            return content
        else:
            raise ValueError(
                f"Expected message with content type string, got {content}"
            )

    eval_config = RunEvalConfig(custom_evaluators=[not_empty])
    await arun_on_dataset(
        dataset_name=kv_singleio_dataset_name,
        llm_or_chain_factory=my_func,
        client=client,
        evaluation=eval_config,
        project_name=eval_project_name,
        tags=["shouldpass"],
    )
    _check_all_feedback_passed(eval_project_name, client)
@@ -1,9 +0,0 @@
"""Integration test for DallE API Wrapper."""
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper


def test_call() -> None:
    """Test that call returns a URL in the output."""
    search = DallEAPIWrapper()  # type: ignore[call-arg]
    output = search.run("volcano island")
    assert "https://oaidalleapi" in output
@@ -1,72 +0,0 @@
"""Integration test for embedding-based redundant doc filtering."""

from langchain_community.document_transformers.embeddings_redundant_filter import (
    EmbeddingsClusteringFilter,
    EmbeddingsRedundantFilter,
    _DocumentWithState,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document


def test_embeddings_redundant_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "Where did all of my cookies go?",
        "I wish there were better Italian restaurants in my neighborhood.",
    ]
    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 2
    assert set(texts[:2]).intersection([d.page_content for d in actual])


def test_embeddings_redundant_filter_with_state() -> None:
    texts = ["What happened to all of my cookies?", "foo bar baz"]
    state = {"embedded_doc": [0.5] * 10}
    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 1


def test_embeddings_clustering_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "A cookie is a small, baked sweet treat and you can find it in the cookie",
        "monsters' jar.",
        "Cookies are good.",
        "I have nightmares about the cookie monster.",
        "The most popular pizza styles are: Neapolitan, New York-style and",
        "Chicago-style. You can find them on iconic restaurants in major cities.",
        "Neapolitan pizza: This is the original pizza style,hailing from Naples,",
        "Italy.",
        "I wish there were better Italian Pizza restaurants in my neighborhood.",
        "New York-style pizza: This is characterized by its large, thin crust, and",
        "generous toppings.",
        "The first movie to feature a robot was 'A Trip to the Moon' (1902).",
        "The first movie to feature a robot that could pass for a human was",
        "'Blade Runner' (1982)",
        "The first movie to feature a robot that could fall in love with a human",
        "was 'Her' (2013)",
        "A robot is a machine capable of carrying out complex actions automatically.",
        "There are certainly hundreds, if not thousands movies about robots like:",
        "'Blade Runner', 'Her' and 'A Trip to the Moon'",
    ]

    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsClusteringFilter(
        embeddings=embeddings,
        num_clusters=3,
        num_closest=1,
        sorted=True,
    )
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 3
    assert texts[1] in [d.page_content for d in actual]
    assert texts[4] in [d.page_content for d in actual]
    assert texts[11] in [d.page_content for d in actual]
@@ -1,37 +0,0 @@
"""Integration test for doc reordering."""
from langchain_community.document_transformers.long_context_reorder import (
    LongContextReorder,
)
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma


def test_long_context_reorder() -> None:
    """Test Lost in the middle reordering get_relevant_docs."""
    texts = [
        "Basquetball is a great sport.",
        "Fly me to the moon is one of my favourite songs.",
        "The Celtics are my favourite team.",
        "This is a document about the Boston Celtics",
        "I simply love going to the movies",
        "The Boston Celtics won the game by 20 points",
        "This is just a random text.",
        "Elden Ring is one of the best games in the last 15 years.",
        "L. Kornet is one of the best Celtics players.",
        "Larry Bird was an iconic NBA player.",
    ]
    embeddings = OpenAIEmbeddings()
    retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
        search_kwargs={"k": 10}
    )
    reordering = LongContextReorder()
    docs = retriever.invoke("Tell me about the Celtics")
    actual = reordering.transform_documents(docs)

    # First 2 and Last 2 elements must contain the most relevant
    first_and_last = list(actual[:2]) + list(actual[-2:])
    assert len(actual) == 10
    assert texts[2] in [d.page_content for d in first_and_last]
    assert texts[3] in [d.page_content for d in first_and_last]
    assert texts[5] in [d.page_content for d in first_and_last]
    assert texts[8] in [d.page_content for d in first_and_last]
@@ -1,62 +0,0 @@
import asyncio
import json
from typing import Any
from unittest import mock

from langchain_community.document_transformers.nuclia_text_transform import (
    NucliaTextTransformer,
)
from langchain_community.tools.nuclia.tool import NucliaUnderstandingAPI
from langchain_core.documents import Document


def fakerun(**args: Any) -> Any:
    async def run(self: Any, **args: Any) -> str:
        await asyncio.sleep(0.1)
        data = {
            "extracted_text": [{"body": {"text": "Hello World"}}],
            "file_extracted_data": [{"language": "en"}],
            "field_metadata": [
                {
                    "metadata": {
                        "metadata": {
                            "paragraphs": [
                                {"end": 66, "sentences": [{"start": 1, "end": 67}]}
                            ]
                        }
                    }
                }
            ],
        }
        return json.dumps(data)

    return run


async def test_nuclia_loader() -> None:
    with mock.patch(
        "langchain_community.tools.nuclia.tool.NucliaUnderstandingAPI._arun",
        new_callable=fakerun,
    ):
        with mock.patch("os.environ.get", return_value="_a_key_"):
            nua = NucliaUnderstandingAPI(enable_ml=False)
            documents = [
                Document(page_content="Hello, my name is Alice", metadata={}),
                Document(page_content="Hello, my name is Bob", metadata={}),
            ]
            nuclia_transformer = NucliaTextTransformer(nua)
            transformed_documents = await nuclia_transformer.atransform_documents(
                documents
            )
            assert len(transformed_documents) == 2
            assert (
                transformed_documents[0].metadata["nuclia"]["file"]["language"] == "en"
            )
            assert (
                len(
                    transformed_documents[1].metadata["nuclia"]["metadata"]["metadata"][
                        "metadata"
                    ]["paragraphs"]
                )
                == 1
            )
@@ -1,19 +0,0 @@
"""Test splitting with page numbers included."""
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS


def test_pdf_pagesplitter() -> None:
    """Test splitting with page numbers included."""
    script_dir = os.path.dirname(__file__)
    loader = PyPDFLoader(os.path.join(script_dir, "examples/hello.pdf"))
    docs = loader.load()
    assert "page" in docs[0].metadata
    assert "source" in docs[0].metadata

    faiss_index = FAISS.from_documents(docs, OpenAIEmbeddings())
    docs = faiss_index.similarity_search("Complete this sentence: Hello", k=1)
    assert "Hello world" in docs[0].page_content