mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 03:26:17 +00:00
Harrison/serialize llm chain (#671)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
"""Test LLM chain."""
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict, List, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.loading import load_chain
|
||||
from langchain.prompts.base import BaseOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
@@ -24,6 +27,16 @@ def fake_llm_chain() -> LLMChain:
|
||||
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
|
||||
|
||||
|
||||
@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM})
|
||||
def test_serialization(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test serialization."""
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
file = temp_dir + "/llm.json"
|
||||
fake_llm_chain.save(file)
|
||||
loaded_chain = load_chain(file)
|
||||
assert loaded_chain == fake_llm_chain
|
||||
|
||||
|
||||
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test error is raised if inputs are missing."""
|
||||
with pytest.raises(ValueError):
|
||||
|
@@ -12,7 +12,7 @@ def test_caching() -> None:
|
||||
"""Test caching behavior."""
|
||||
langchain.llm_cache = InMemoryCache()
|
||||
llm = FakeLLM()
|
||||
params = llm._llm_dict()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||
@@ -50,7 +50,7 @@ def test_custom_caching() -> None:
|
||||
engine = create_engine("sqlite://")
|
||||
langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
|
||||
llm = FakeLLM()
|
||||
params = llm._llm_dict()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||
|
Reference in New Issue
Block a user