mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
standard-tests: migrate to pytest-recording (#31425)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -1,25 +1,85 @@
|
||||
import base64
|
||||
import gzip
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from vcr import VCR # type: ignore[import-untyped]
|
||||
from vcr.serializers import yamlserializer # type: ignore[import-untyped]
|
||||
import yaml
|
||||
from vcr import VCR
|
||||
from vcr.persisters.filesystem import CassetteNotFoundError
|
||||
from vcr.request import Request
|
||||
|
||||
|
||||
class YamlGzipSerializer:
|
||||
@staticmethod
|
||||
def serialize(cassette_dict: dict) -> str:
|
||||
raw = yamlserializer.serialize(cassette_dict).encode("utf-8")
|
||||
compressed = gzip.compress(raw)
|
||||
return base64.b64encode(compressed).decode("ascii")
|
||||
class CustomSerializer:
|
||||
"""Custom serializer for VCR cassettes using YAML and gzip.
|
||||
|
||||
We're using a custom serializer to avoid the default yaml serializer
|
||||
used by VCR, which is not designed to be safe for untrusted input.
|
||||
|
||||
This step is an extra precaution necessary because the cassette files
|
||||
are in compressed YAML format, which makes it more difficult to inspect
|
||||
their contents during development or debugging.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def deserialize(data: str) -> dict:
|
||||
compressed = base64.b64decode(data.encode("ascii"))
|
||||
text = gzip.decompress(compressed).decode("utf-8")
|
||||
return yamlserializer.deserialize(text)
|
||||
def serialize(cassette_dict: dict) -> bytes:
    """Convert cassette to YAML and compress it.

    Each entry of ``cassette_dict["requests"]`` is converted with its
    ``_to_dict()`` method, the resulting mapping is dumped with
    ``yaml.safe_dump`` and the UTF-8 encoded YAML is gzip-compressed.

    The conversion is done on a shallow copy, so the caller's
    ``cassette_dict`` is left untouched and can be serialized again.

    Args:
        cassette_dict: Mapping with at least a ``"requests"`` entry holding
            objects that expose ``_to_dict()``.

    Returns:
        The gzip-compressed YAML document as bytes.

    Raises:
        KeyError: If ``cassette_dict`` has no ``"requests"`` entry.
    """
    # Build a new mapping rather than assigning into ``cassette_dict``:
    # overwriting the caller's "requests" list with plain dicts would make a
    # second ``serialize`` call on the same cassette fail.
    serializable = {
        **cassette_dict,
        "requests": [request._to_dict() for request in cassette_dict["requests"]],
    }
    yml = yaml.safe_dump(serializable)
    return gzip.compress(yml.encode("utf-8"))
|
||||
|
||||
@staticmethod
def deserialize(data: bytes) -> dict:
    """Decompress data and convert it from YAML.

    ``data`` is gunzipped, decoded as UTF-8 and parsed with
    ``yaml.safe_load``; every entry under ``"requests"`` is then rebuilt
    through ``Request._from_dict`` before the cassette is returned.
    """
    decoded_yaml = gzip.decompress(data).decode("utf-8")
    loaded = yaml.safe_load(decoded_yaml)
    # Swap the plain mappings for Request objects in the same order.
    loaded["requests"] = list(map(Request._from_dict, loaded["requests"]))
    return loaded
|
||||
|
||||
|
||||
class CustomPersister:
    """A custom persister for VCR that uses the CustomSerializer.

    Cassette files are read and written as raw bytes; translating between
    those bytes and the cassette contents is delegated to the ``serializer``
    object handed to each method.
    """

    @classmethod
    def load_cassette(
        cls, cassette_path: Union[str, PathLike[str]], serializer: CustomSerializer
    ) -> tuple[dict, dict]:
        """Load a cassette from a file.

        Args:
            cassette_path: Location of the cassette file.
            serializer: Object whose ``deserialize`` method turns the file's
                bytes into a mapping with ``"requests"`` and ``"responses"``
                entries.

        Returns:
            The ``"requests"`` and ``"responses"`` entries of the
            deserialized cassette, in that order.

        Raises:
            CassetteNotFoundError: If ``cassette_path`` is not an existing
                regular file.
        """
        # If cassette path is already Path this is a no-op
        cassette_path = Path(cassette_path)
        if not cassette_path.is_file():
            raise CassetteNotFoundError(
                f"Cassette file {cassette_path} does not exist."
            )
        deser = serializer.deserialize(cassette_path.read_bytes())
        return deser["requests"], deser["responses"]

    @staticmethod
    def save_cassette(
        cassette_path: Union[str, PathLike[str]],
        cassette_dict: dict,
        serializer: CustomSerializer,
    ) -> None:
        """Save a cassette to a file.

        The cassette is serialized first, so a serialization error leaves
        the filesystem untouched. Missing parent folders are created.

        Args:
            cassette_path: Destination of the cassette file.
            cassette_dict: Cassette contents passed to ``serializer``.
            serializer: Object whose ``serialize`` method returns the bytes
                to write.
        """
        data = serializer.serialize(cassette_dict)
        # If cassette path is already Path this is a no-op
        cassette_path = Path(cassette_path)
        # ``exist_ok=True`` replaces an exists()-then-mkdir() sequence, which
        # could raise FileExistsError when two test workers create the
        # cassette folder at the same moment.
        cassette_path.parent.mkdir(parents=True, exist_ok=True)
        cassette_path.write_bytes(data)
|
||||
|
||||
|
||||
# A list of headers that should be filtered out of the cassettes.
|
||||
# These are typically associated with sensitive information and should
|
||||
# not be stored in cassettes.
|
||||
_BASE_FILTER_HEADERS = [
|
||||
("authorization", "PLACEHOLDER"),
|
||||
("x-api-key", "PLACEHOLDER"),
|
||||
@@ -29,14 +89,15 @@ _BASE_FILTER_HEADERS = [
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _base_vcr_config() -> dict:
|
||||
"""
|
||||
Configuration that every cassette will receive.
|
||||
"""Configuration that every cassette will receive.
|
||||
|
||||
(Anything permitted by vcr.VCR(**kwargs) can be put here.)
|
||||
"""
|
||||
return {
|
||||
"record_mode": "once",
|
||||
"filter_headers": _BASE_FILTER_HEADERS.copy(),
|
||||
"match_on": ["method", "scheme", "host", "port", "path", "query"],
|
||||
"match_on": ["method", "uri", "body"],
|
||||
"allow_playback_repeats": True,
|
||||
"decode_compressed_response": True,
|
||||
"cassette_library_dir": "tests/cassettes",
|
||||
"path_transformer": VCR.ensure_suffix(".yaml"),
|
||||
|
@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import vcr # type: ignore[import-untyped]
|
||||
from langchain_core._api import warn_deprecated
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
|
||||
@@ -31,6 +30,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import Field as FieldV1
|
||||
from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-untyped]
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from vcr.cassette import Cassette
|
||||
|
||||
from langchain_tests.unit_tests.chat_models import (
|
||||
ChatModelTests,
|
||||
@@ -592,7 +592,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
:caption: tests/conftest.py
|
||||
|
||||
import pytest
|
||||
from langchain_tests.conftest import YamlGzipSerializer
|
||||
from langchain_tests.conftest import CustomPersister, CustomSerializer
|
||||
from langchain_tests.conftest import _base_vcr_config as _base_vcr_config
|
||||
from vcr import VCR
|
||||
|
||||
@@ -621,24 +621,26 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vcr(vcr_config: dict) -> VCR:
|
||||
\"\"\"Override the default vcr fixture to include custom serializers\"\"\"
|
||||
my_vcr = VCR(**vcr_config)
|
||||
my_vcr.register_serializer("yaml.gz", YamlGzipSerializer)
|
||||
return my_vcr
|
||||
def pytest_recording_configure(config: dict, vcr: VCR) -> None:
|
||||
vcr.register_persister(CustomPersister())
|
||||
vcr.register_serializer("yaml.gz", CustomSerializer())
|
||||
|
||||
|
||||
You can inspect the contents of the compressed cassettes (e.g., to
|
||||
ensure no sensitive information is recorded) using the serializer:
|
||||
ensure no sensitive information is recorded) using ``gunzip``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
gunzip -k /path/to/tests/cassettes/TestClass_test.yaml.gz
|
||||
|
||||
or by using the serializer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_tests.conftest import YamlGzipSerializer
|
||||
from langchain_tests.conftest import CustomPersister, CustomSerializer
|
||||
|
||||
with open("/path/to/tests/cassettes/TestClass_test.yaml.gz", "r") as f:
|
||||
data = f.read()
|
||||
|
||||
YamlGzipSerializer.deserialize(data)
|
||||
cassette_path = "/path/to/tests/cassettes/TestClass_test.yaml.gz"
|
||||
requests, responses = CustomPersister().load_cassette(cassette_path, CustomSerializer())
|
||||
|
||||
3. Run tests to generate VCR cassettes.
|
||||
|
||||
@@ -2826,8 +2828,9 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
assert isinstance(response, AIMessage)
|
||||
|
||||
@pytest.mark.benchmark
|
||||
@pytest.mark.vcr
|
||||
def test_stream_time(
|
||||
self, model: BaseChatModel, benchmark: BenchmarkFixture, vcr: vcr.VCR
|
||||
self, model: BaseChatModel, benchmark: BenchmarkFixture, vcr: Cassette
|
||||
) -> None:
|
||||
"""Test that streaming does not introduce undue overhead.
|
||||
|
||||
@@ -2857,12 +2860,13 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
pytest.skip("VCR not set up.")
|
||||
|
||||
def _run() -> None:
|
||||
cassette_name = f"{self.__class__.__name__}_test_stream_time"
|
||||
with vcr.use_cassette(cassette_name, record_mode="once"):
|
||||
for _ in model.stream("Write a story about a cat."):
|
||||
pass
|
||||
for _ in model.stream("Write a story about a cat."):
|
||||
pass
|
||||
|
||||
benchmark(_run)
|
||||
if not vcr.responses:
|
||||
_run()
|
||||
else:
|
||||
benchmark(_run)
|
||||
|
||||
def invoke_with_audio_input(self, *, stream: bool = False) -> AIMessage:
|
||||
""":private:"""
|
||||
|
@@ -693,7 +693,7 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
:caption: tests/conftest.py
|
||||
|
||||
import pytest
|
||||
from langchain_tests.conftest import YamlGzipSerializer
|
||||
from langchain_tests.conftest import CustomPersister, CustomSerializer
|
||||
from langchain_tests.conftest import _base_vcr_config as _base_vcr_config
|
||||
from vcr import VCR
|
||||
|
||||
@@ -722,24 +722,26 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vcr(vcr_config: dict) -> VCR:
|
||||
\"\"\"Override the default vcr fixture to include custom serializers\"\"\"
|
||||
my_vcr = VCR(**vcr_config)
|
||||
my_vcr.register_serializer("yaml.gz", YamlGzipSerializer)
|
||||
return my_vcr
|
||||
def pytest_recording_configure(config: dict, vcr: VCR) -> None:
|
||||
vcr.register_persister(CustomPersister())
|
||||
vcr.register_serializer("yaml.gz", CustomSerializer())
|
||||
|
||||
|
||||
You can inspect the contents of the compressed cassettes (e.g., to
|
||||
ensure no sensitive information is recorded) using the serializer:
|
||||
ensure no sensitive information is recorded) using ``gunzip``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
gunzip -k /path/to/tests/cassettes/TestClass_test.yaml.gz
|
||||
|
||||
or by using the serializer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_tests.conftest import YamlGzipSerializer
|
||||
from langchain_tests.conftest import CustomPersister, CustomSerializer
|
||||
|
||||
with open("/path/to/tests/cassettes/TestClass_test.yaml.gz", "r") as f:
|
||||
data = f.read()
|
||||
|
||||
YamlGzipSerializer.deserialize(data)
|
||||
cassette_path = "/path/to/tests/cassettes/TestClass_test.yaml.gz"
|
||||
requests, responses = CustomPersister().load_cassette(cassette_path, CustomSerializer())
|
||||
|
||||
3. Run tests to generate VCR cassettes.
|
||||
|
||||
|
Reference in New Issue
Block a user