mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-21 06:58:02 +00:00
Consolidate the logic for deciding when a chain can run with a single text input and a single text output. Open to feedback on naming, logic, and usefulness.
51 lines
1.2 KiB
Python
51 lines
1.2 KiB
Python
"""Test logic on base chain class."""
|
|
from typing import Dict, List
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
|
class FakeChain(Chain, BaseModel):
    """Minimal stand-in chain used by the tests in this module.

    It declares exactly one input key (``foo``) and one output key (``bar``).
    With ``be_correct`` set to False, ``_call`` returns a dict keyed by
    ``baz`` instead of ``bar``. Its output then lacks the declared output
    key, which lets the tests exercise that failure path.
    """

    # When False, ``_call`` deliberately omits the declared ``bar`` output key.
    be_correct: bool = True

    @property
    def input_keys(self) -> List[str]:
        """Return the single input key this chain expects: ``foo``."""
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key this chain declares: ``bar``."""
        return ["bar"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Return a fixed output dict; ``inputs`` is ignored.

        Yields ``{"bar": "baz"}`` when ``be_correct`` is True, otherwise
        ``{"baz": "bar"}``, which is missing the declared ``bar`` key.
        """
        return {"bar": "baz"} if self.be_correct else {"baz": "bar"}
|
|
|
|
|
|
def test_bad_inputs() -> None:
    """Check that calling the chain without its ``foo`` key raises ValueError."""
    fake_chain = FakeChain()
    # ``foobar`` is not the declared input key, so input validation must fail.
    missing_foo = {"foobar": "baz"}
    with pytest.raises(ValueError):
        fake_chain(missing_foo)
|
|
|
|
|
|
def test_bad_outputs() -> None:
    """Check that output lacking the declared ``bar`` key raises ValueError."""
    # be_correct=False makes the fake chain return ``baz`` instead of ``bar``.
    misbehaving_chain = FakeChain(be_correct=False)
    valid_inputs = {"foo": "baz"}
    with pytest.raises(ValueError):
        misbehaving_chain(valid_inputs)
|
|
|
|
|
|
def test_correct_call() -> None:
    """Check that a well-formed call returns the inputs merged with the outputs."""
    fake_chain = FakeChain()
    expected = {"foo": "bar", "bar": "baz"}
    assert fake_chain({"foo": "bar"}) == expected
|