Harrison/llm saving (#331)
Co-authored-by: Akash Samant <70665700+asamant21@users.noreply.github.com>
tests/integration_tests/llms/test_ai21.py
@@ -1,6 +1,9 @@
 """Test AI21 API wrapper."""
 
+from pathlib import Path
+
 from langchain.llms.ai21 import AI21
+from langchain.llms.loading import load_llm
 
 
 def test_ai21_call() -> None:
@@ -8,3 +11,11 @@ def test_ai21_call() -> None:
     llm = AI21(maxTokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+
+
+def test_saving_loading_llm(tmp_path: Path) -> None:
+    """Test saving/loading an AI21 LLM."""
+    llm = AI21(maxTokens=10)
+    llm.save(file_path=tmp_path / "ai21.yaml")
+    loaded_llm = load_llm(tmp_path / "ai21.yaml")
+    assert llm == loaded_llm
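
Each of the new tests follows the same round trip: construct an LLM, serialize it with save, and rebuild it with load_llm. A minimal sketch of that pattern outside pytest, assuming AI21 credentials are available in the environment and writing to the working directory instead of a tmp_path fixture:

    from pathlib import Path

    from langchain.llms.ai21 import AI21
    from langchain.llms.loading import load_llm

    # Serialize the LLM's configuration to a YAML file on disk ...
    llm = AI21(maxTokens=10)
    llm.save(file_path=Path("ai21.yaml"))

    # ... then reconstruct an equivalent instance from that file.
    loaded_llm = load_llm(Path("ai21.yaml"))
    assert llm == loaded_llm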

tests/integration_tests/llms/test_cohere.py
@@ -1,6 +1,10 @@
 """Test Cohere API wrapper."""
 
+from pathlib import Path
+
 from langchain.llms.cohere import Cohere
+from langchain.llms.loading import load_llm
+from tests.integration_tests.llms.utils import assert_llm_equality
 
 
 def test_cohere_call() -> None:
@@ -8,3 +12,11 @@ def test_cohere_call() -> None:
     llm = Cohere(max_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+
+
+def test_saving_loading_llm(tmp_path: Path) -> None:
+    """Test saving/loading a Cohere LLM."""
+    llm = Cohere(max_tokens=10)
+    llm.save(file_path=tmp_path / "cohere.yaml")
+    loaded_llm = load_llm(tmp_path / "cohere.yaml")
+    assert_llm_equality(llm, loaded_llm)

tests/integration_tests/llms/test_huggingface_hub.py
@@ -1,8 +1,12 @@
 """Test HuggingFace API wrapper."""
 
+from pathlib import Path
+
 import pytest
 
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.llms.loading import load_llm
+from tests.integration_tests.llms.utils import assert_llm_equality
 
 
 def test_huggingface_text_generation() -> None:
@@ -24,3 +28,11 @@ def test_huggingface_call_error() -> None:
     llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
     with pytest.raises(ValueError):
         llm("Say foo:")
+
+
+def test_saving_loading_llm(tmp_path: Path) -> None:
+    """Test saving/loading a HuggingFaceHub LLM."""
+    llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
+    llm.save(file_path=tmp_path / "hf.yaml")
+    loaded_llm = load_llm(tmp_path / "hf.yaml")
+    assert_llm_equality(llm, loaded_llm)

tests/integration_tests/llms/test_nlpcloud.py
@@ -1,6 +1,10 @@
 """Test NLPCloud API wrapper."""
 
+from pathlib import Path
+
+from langchain.llms.loading import load_llm
 from langchain.llms.nlpcloud import NLPCloud
+from tests.integration_tests.llms.utils import assert_llm_equality
 
 
 def test_nlpcloud_call() -> None:
@@ -8,3 +12,11 @@ def test_nlpcloud_call() -> None:
     llm = NLPCloud(max_length=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+
+
+def test_saving_loading_llm(tmp_path: Path) -> None:
+    """Test saving/loading an NLPCloud LLM."""
+    llm = NLPCloud(max_length=10)
+    llm.save(file_path=tmp_path / "nlpcloud.yaml")
+    loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
+    assert_llm_equality(llm, loaded_llm)

tests/integration_tests/llms/test_openai.py
@@ -1,7 +1,10 @@
 """Test OpenAI API wrapper."""
 
+from pathlib import Path
+
 import pytest
 
+from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI
 
 
@@ -44,3 +47,11 @@ def test_openai_stop_error() -> None:
     llm = OpenAI(stop="3", temperature=0)
     with pytest.raises(ValueError):
         llm("write an ordered list of five items", stop=["\n"])
+
+
+def test_saving_loading_llm(tmp_path: Path) -> None:
+    """Test saving/loading an OpenAI LLM."""
+    llm = OpenAI(max_tokens=10)
+    llm.save(file_path=tmp_path / "openai.yaml")
+    loaded_llm = load_llm(tmp_path / "openai.yaml")
+    assert loaded_llm == llm
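
All of the new tests save to .yaml. Whether load_llm also accepts a .json suffix depends on langchain.llms.loading, which is not shown in this diff, so the following variant is an assumption rather than something these tests cover:

    from pathlib import Path

    from langchain.llms.loading import load_llm
    from langchain.llms.openai import OpenAI

    # Assumption: save/load_llm dispatch on the file suffix and accept
    # .json as well as .yaml; only .yaml is exercised by these tests.
    llm = OpenAI(max_tokens=10)
    llm.save(file_path=Path("openai.json"))
    loaded_llm = load_llm(Path("openai.json"))
    assert loaded_llm == llm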

tests/integration_tests/llms/utils.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+"""Utils for LLM Tests."""
+
+from langchain.llms.base import LLM
+
+
+def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
+    """Assert LLM Equality for tests."""
+    # Check that they are the same type.
+    assert type(llm) == type(loaded_llm)
+    # The client field can be session based, so its hash differs even when
+    # all other values are the same; just compare the remaining fields.
+    for field in llm.__fields__.keys():
+        if field != "client":
+            val = getattr(llm, field)
+            new_val = getattr(loaded_llm, field)
+            assert new_val == val
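
The helper compares pydantic fields one at a time and skips client, since provider SDK clients are created per session and need not compare equal even when the configuration matches. That is why the Cohere, HuggingFaceHub, and NLPCloud tests above use assert_llm_equality rather than plain ==. A usage sketch, assuming valid Cohere credentials in the environment:

    from langchain.llms.cohere import Cohere
    from tests.integration_tests.llms.utils import assert_llm_equality

    # Two separately constructed instances share every configuration field
    # but hold distinct session clients, so plain == may fail while this
    # field-by-field comparison passes.
    first = Cohere(max_tokens=10)
    second = Cohere(max_tokens=10)
    assert_llm_equality(first, second)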