introduce output parser (#250)

This commit is contained in:
Harrison Chase
2022-12-03 13:28:07 -08:00
committed by GitHub
parent b4762dfff0
commit db58032973
3 changed files with 57 additions and 3 deletions

View File

@@ -1,11 +1,22 @@
"""Test LLM chain."""
from typing import Dict, List, Union
import pytest
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeOutputParser(BaseOutputParser):
"""Fake output parser class for testing."""
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
"""Parse by splitting."""
return text.split()
@pytest.fixture
def fake_llm_chain() -> LLMChain:
"""Fake LLM chain for testing purposes."""
@@ -34,3 +45,14 @@ def test_predict_method(fake_llm_chain: LLMChain) -> None:
"""Test predict method works."""
output = fake_llm_chain.predict(bar="baz")
assert output == "foo"
def test_predict_and_parse() -> None:
"""Test parsing ability."""
prompt = PromptTemplate(
input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser()
)
llm = FakeLLM(queries={"foo": "foo bar"})
chain = LLMChain(prompt=prompt, llm=llm)
output = chain.predict_and_parse(foo="foo")
assert output == ["foo", "bar"]