Harrison/serialize llm chain (#671)

This commit is contained in:
Harrison Chase
2023-01-24 21:36:19 -08:00
committed by GitHub
parent 499e54edda
commit 0ffeabd14f
14 changed files with 578 additions and 14 deletions

View File

@@ -1,9 +1,12 @@
"""Test LLM chain."""
from tempfile import TemporaryDirectory
from typing import Dict, List, Union
from unittest.mock import patch
import pytest
from langchain.chains.llm import LLMChain
from langchain.chains.loading import load_chain
from langchain.prompts.base import BaseOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -24,6 +27,16 @@ def fake_llm_chain() -> LLMChain:
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM})
def test_serialization(fake_llm_chain: LLMChain) -> None:
"""Test serialization."""
with TemporaryDirectory() as temp_dir:
file = temp_dir + "/llm.json"
fake_llm_chain.save(file)
loaded_chain = load_chain(file)
assert loaded_chain == fake_llm_chain
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
"""Test error is raised if inputs are missing."""
with pytest.raises(ValueError):