mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
introduce output parser (#250)
This commit is contained in:
@@ -1,11 +1,22 @@
|
||||
"""Test LLM chain."""
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BaseOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
class FakeOutputParser(BaseOutputParser):
|
||||
"""Fake output parser class for testing."""
|
||||
|
||||
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Parse by splitting."""
|
||||
return text.split()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_chain() -> LLMChain:
|
||||
"""Fake LLM chain for testing purposes."""
|
||||
@@ -34,3 +45,14 @@ def test_predict_method(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test predict method works."""
|
||||
output = fake_llm_chain.predict(bar="baz")
|
||||
assert output == "foo"
|
||||
|
||||
|
||||
def test_predict_and_parse() -> None:
|
||||
"""Test parsing ability."""
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["foo"], template="{foo}", output_parser=FakeOutputParser()
|
||||
)
|
||||
llm = FakeLLM(queries={"foo": "foo bar"})
|
||||
chain = LLMChain(prompt=prompt, llm=llm)
|
||||
output = chain.predict_and_parse(foo="foo")
|
||||
assert output == ["foo", "bar"]
|
||||
|
Reference in New Issue
Block a user