"""Test functionality around the simple pipeline chain."""

from typing import Dict, List

import pytest
from pydantic import BaseModel

from langchain.chains.base import Chain
from langchain.chains.simple_pipeline import SimplePipeline


class FakeChain(Chain, BaseModel):
    """Fake chain for testing purposes.

    Every output variable receives the same value: the input values (taken in
    ``input_variables`` order) joined by a single space, followed by the
    literal suffix ``"foo"``.
    """

    # Keys read from the inputs by ``_call``; exposed via ``input_keys``.
    input_variables: List[str]
    # Keys written to the outputs by ``_call``; exposed via ``output_keys``.
    output_variables: List[str]

    @property
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Output keys this chain returns."""
        return self.output_variables

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Map every output variable to the joined inputs plus ``"foo"``.

        Args:
            inputs: Mapping that holds a value for each name in
                ``input_variables``.

        Returns:
            A dict with one entry per name in ``output_variables``.

        Raises:
            KeyError: If an input variable is missing from ``inputs`` and at
                least one output variable is declared.
        """
        # Nothing to produce, so the inputs are deliberately left untouched.
        if not self.output_variables:
            return {}
        # The value does not depend on the output name, so build it once
        # instead of once per output variable.
        value = " ".join(inputs[k] for k in self.input_variables) + "foo"
        return {var: value for var in self.output_variables}


def test_pipeline_functionality() -> None:
    """Check that a two-step pipeline applies both fake chains in order."""
    # The first chain's output key ("bar") is the second chain's input key.
    steps = [
        FakeChain(input_variables=["foo"], output_variables=["bar"]),
        FakeChain(input_variables=["bar"], output_variables=["baz"]),
    ]
    result = SimplePipeline(chains=steps)({"input": "123"})
    # Each fake chain appends "foo", hence the doubled suffix.
    assert result == {"output": "123foofoo", "input": "123"}


def test_multi_input_errors() -> None:
    """Check that a chain with two input variables makes the pipeline fail."""
    # Build the chains up front so only pipeline construction is under test.
    steps = [
        FakeChain(input_variables=["foo"], output_variables=["bar"]),
        FakeChain(input_variables=["bar", "foo"], output_variables=["baz"]),
    ]
    with pytest.raises(ValueError):
        SimplePipeline(chains=steps)


def test_multi_output_errors() -> None:
    """Check that a chain with two output variables makes the pipeline fail."""
    # Build the chains up front so only pipeline construction is under test.
    steps = [
        FakeChain(input_variables=["foo"], output_variables=["bar", "grok"]),
        FakeChain(input_variables=["bar"], output_variables=["baz"]),
    ]
    with pytest.raises(ValueError):
        SimplePipeline(chains=steps)