mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-24 20:20:50 +00:00
109 lines
3.5 KiB
Python
109 lines
3.5 KiB
Python
import gzip
|
|
from os import PathLike
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import pytest
|
|
import yaml
|
|
from vcr import VCR
|
|
from vcr.persisters.filesystem import CassetteNotFoundError
|
|
from vcr.request import Request
|
|
|
|
|
|
class CustomSerializer:
|
|
"""Custom serializer for VCR cassettes using YAML and gzip.
|
|
|
|
We're using a custom serializer to avoid the default yaml serializer
|
|
used by VCR, which is not designed to be safe for untrusted input.
|
|
|
|
This step is an extra precaution necessary because the cassette files
|
|
are in compressed YAML format, which makes it more difficult to inspect
|
|
their contents during development or debugging.
|
|
"""
|
|
|
|
@staticmethod
|
|
def serialize(cassette_dict: dict) -> bytes:
|
|
"""Convert cassette to YAML and compress it."""
|
|
cassette_dict["requests"] = [
|
|
request._to_dict() for request in cassette_dict["requests"]
|
|
]
|
|
yml = yaml.safe_dump(cassette_dict)
|
|
return gzip.compress(yml.encode("utf-8"))
|
|
|
|
@staticmethod
|
|
def deserialize(data: bytes) -> dict:
|
|
"""Decompress data and convert it from YAML."""
|
|
text = gzip.decompress(data).decode("utf-8")
|
|
cassette = yaml.safe_load(text)
|
|
cassette["requests"] = [
|
|
Request._from_dict(request) for request in cassette["requests"]
|
|
]
|
|
return cassette
|
|
|
|
|
|
class CustomPersister:
|
|
"""A custom persister for VCR that uses the CustomSerializer."""
|
|
|
|
@classmethod
|
|
def load_cassette(
|
|
cls, cassette_path: Union[str, PathLike[str]], serializer: CustomSerializer
|
|
) -> tuple[dict, dict]:
|
|
"""Load a cassette from a file."""
|
|
# If cassette path is already Path this is a no-op
|
|
cassette_path = Path(cassette_path)
|
|
if not cassette_path.is_file():
|
|
msg = f"Cassette file {cassette_path} does not exist."
|
|
raise CassetteNotFoundError(msg)
|
|
with cassette_path.open(mode="rb") as f:
|
|
data = f.read()
|
|
deser = serializer.deserialize(data)
|
|
return deser["requests"], deser["responses"]
|
|
|
|
@staticmethod
|
|
def save_cassette(
|
|
cassette_path: Union[str, PathLike[str]],
|
|
cassette_dict: dict,
|
|
serializer: CustomSerializer,
|
|
) -> None:
|
|
"""Save a cassette to a file."""
|
|
data = serializer.serialize(cassette_dict)
|
|
# if cassette path is already Path this is no operation
|
|
cassette_path = Path(cassette_path)
|
|
cassette_folder = cassette_path.parent
|
|
if not cassette_folder.exists():
|
|
cassette_folder.mkdir(parents=True)
|
|
with cassette_path.open("wb") as f:
|
|
f.write(data)
|
|
|
|
|
|
# A list of headers that should be filtered out of the cassettes.
|
|
# These are typically associated with sensitive information and should
|
|
# not be stored in cassettes.
|
|
_BASE_FILTER_HEADERS = [
|
|
("authorization", "PLACEHOLDER"),
|
|
("x-api-key", "PLACEHOLDER"),
|
|
("api-key", "PLACEHOLDER"),
|
|
]
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def _base_vcr_config() -> dict:
|
|
"""Configuration that every cassette will receive.
|
|
|
|
(Anything permitted by vcr.VCR(**kwargs) can be put here.)
|
|
"""
|
|
return {
|
|
"record_mode": "once",
|
|
"filter_headers": _BASE_FILTER_HEADERS.copy(),
|
|
"match_on": ["method", "uri", "body"],
|
|
"allow_playback_repeats": True,
|
|
"decode_compressed_response": True,
|
|
"cassette_library_dir": "tests/cassettes",
|
|
"path_transformer": VCR.ensure_suffix(".yaml"),
|
|
}
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def vcr_config(_base_vcr_config: dict) -> dict:
|
|
return _base_vcr_config
|