Harrison/llm saving (#331)

Co-authored-by: Akash Samant <70665700+asamant21@users.noreply.github.com>
Harrison Chase 2022-12-13 06:46:01 -08:00 committed by GitHub
parent 595cc1ae1a
commit 9bb7195085
20 changed files with 279 additions and 28 deletions

langchain/llms/__init__.py

@@ -1,7 +1,19 @@
"""Wrappers on top of large language models APIs."""
from typing import Dict, Type
from langchain.llms.ai21 import AI21
from langchain.llms.base import LLM
from langchain.llms.cohere import Cohere
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import OpenAI
__all__ = ["Cohere", "NLPCloud", "OpenAI", "HuggingFaceHub"]
__all__ = ["Cohere", "NLPCloud", "OpenAI", "HuggingFaceHub", "AI21"]
type_to_cls_dict: Dict[str, Type[LLM]] = {
"ai21": AI21,
"cohere": Cohere,
"huggingface_hub": HuggingFaceHub,
"nlpcloud": NLPCloud,
"openai": OpenAI,
}
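The registry above is what makes deserialization possible: each wrapper advertises a string via its new _llm_type property, and type_to_cls_dict maps that string back to the class. A minimal sketch of the lookup it enables (assumes the relevant API key, e.g. OPENAI_API_KEY, is set in the environment):

from langchain.llms import type_to_cls_dict

llm_cls = type_to_cls_dict["openai"]  # resolve the "_type" string to the OpenAI class
llm = llm_cls(temperature=0.7)        # rebuild the wrapper from saved parameters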

langchain/llms/ai21.py

@@ -19,7 +19,7 @@ class AI21PenaltyData(BaseModel):
applyToEmojis: bool = True
-class AI21(BaseModel, LLM):
+class AI21(LLM, BaseModel):
"""Wrapper around AI21 large language models.
To use, you should have the environment variable ``AI21_API_KEY``
@@ -96,6 +96,11 @@ class AI21(BaseModel, LLM):
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ai21"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to AI21's complete endpoint.

langchain/llms/base.py

@@ -1,6 +1,11 @@
"""Base interface for large language models to expose."""
import json
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Any, List, Mapping, NamedTuple, Optional
+from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Union
import yaml
from pydantic import BaseModel, Extra
class Generation(NamedTuple):
@@ -21,9 +26,14 @@ class LLMResult(NamedTuple):
"""For arbitrary LLM provider specific output."""
-class LLM(ABC):
+class LLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
@@ -68,3 +78,46 @@ class LLM(ABC):
"""Get a string representation of the object for printing."""
cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
return f"{cls_name}\nParams: {self._identifying_params}"
@property
@abstractmethod
def _llm_type(self) -> str:
"""Return type of llm."""
def _llm_dict(self) -> Dict:
"""Return a dictionary of the prompt."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
return starter_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the LLM.
Args:
file_path: Path to file to save the LLM to.
Example:
.. code-block:: python
llm.save(file_path="path/llm.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
llm_dict = self._llm_dict()
if save_path.suffix == ".json":
    with open(save_path, "w") as f:
        json.dump(llm_dict, f, indent=4)
elif save_path.suffix == ".yaml":
    with open(save_path, "w") as f:
        yaml.dump(llm_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
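A short usage sketch of the new save method (parameter values are illustrative, and an OPENAI_API_KEY is assumed):

from langchain.llms import OpenAI

llm = OpenAI(temperature=0.9)
llm.save("llm.yaml")  # writes _identifying_params plus a "_type": "openai" entry
llm.save("llm.json")  # the format is chosen from the file suffix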

langchain/llms/cohere.py

@@ -85,6 +85,11 @@ class Cohere(LLM, BaseModel):
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "cohere"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to Cohere's generate endpoint.

langchain/llms/huggingface_hub.py

@@ -74,7 +74,15 @@ class HuggingFaceHub(LLM, BaseModel):
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {**{"repo_id": self.repo_id}, **_model_kwargs}
return {
**{"repo_id": self.repo_id, "task": self.task},
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "huggingface_hub"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to HuggingFace Hub's inference endpoint.

langchain/llms/loading.py (new file)

@@ -0,0 +1,42 @@
"""Base interface for loading large language models apis."""
import json
from pathlib import Path
from typing import Union
import yaml
from langchain.llms import type_to_cls_dict
from langchain.llms.base import LLM
def load_llm_from_config(config: dict) -> LLM:
"""Load LLM from Config Dict."""
if "_type" not in config:
raise ValueError("Must specify an LLM Type in config")
config_type = config.pop("_type")
if config_type not in type_to_cls_dict:
raise ValueError(f"Loading {config_type} LLM not supported")
llm_cls = type_to_cls_dict[config_type]
return llm_cls(**config)
def load_llm(file: Union[str, Path]) -> LLM:
"""Load LLM from file."""
# Convert file to Path object.
if isinstance(file, str):
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml.
if file_path.suffix == ".json":
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix == ".yaml":
with open(file_path, "r") as f:
config = yaml.safe_load(f)
else:
raise ValueError("File type must be json or yaml")
# Load the LLM from the config now.
return load_llm_from_config(config)
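A round-trip sketch using the new loader (file contents as produced by LLM.save above; values illustrative):

from langchain.llms.loading import load_llm, load_llm_from_config

llm = load_llm("llm.yaml")  # read the file, pop "_type", dispatch via type_to_cls_dict
same = load_llm_from_config({"_type": "openai", "temperature": 0.9})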

langchain/llms/manifest.py

@@ -37,6 +37,11 @@ class ManifestWrapper(LLM, BaseModel):
kwargs = self.llm_kwargs or {}
return {**self.client.client.get_model_params(), **kwargs}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "manifest"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to LLM through Manifest."""
if stop is not None and len(stop) != 1:

langchain/llms/nlpcloud.py

@@ -106,6 +106,11 @@ class NLPCloud(LLM, BaseModel):
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "nlpcloud"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to NLPCloud's create endpoint.

langchain/llms/openai.py

@@ -142,7 +142,12 @@ class OpenAI(LLM, BaseModel):
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model_name}, **self._default_params}
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "openai"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to OpenAI's create endpoint.

tests/integration_tests/llms/test_ai21.py

@@ -1,6 +1,9 @@
"""Test AI21 API wrapper."""
from pathlib import Path
from langchain.llms.ai21 import AI21
from langchain.llms.loading import load_llm
def test_ai21_call() -> None:
@@ -8,3 +11,11 @@ def test_ai21_call() -> None:
llm = AI21(maxTokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an AI21 LLM."""
llm = AI21(maxTokens=10)
llm.save(file_path=tmp_path / "ai21.yaml")
loaded_llm = load_llm(tmp_path / "ai21.yaml")
assert llm == loaded_llm

tests/integration_tests/llms/test_cohere.py

@@ -1,6 +1,10 @@
"""Test Cohere API wrapper."""
from pathlib import Path
from langchain.llms.cohere import Cohere
from langchain.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
def test_cohere_call() -> None:
@@ -8,3 +12,11 @@ def test_cohere_call() -> None:
llm = Cohere(max_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an Cohere LLM."""
llm = Cohere(max_tokens=10)
llm.save(file_path=tmp_path / "cohere.yaml")
loaded_llm = load_llm(tmp_path / "cohere.yaml")
assert_llm_equality(llm, loaded_llm)

tests/integration_tests/llms/test_huggingface_hub.py

@@ -1,8 +1,12 @@
"""Test HuggingFace API wrapper."""
from pathlib import Path
import pytest
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
def test_huggingface_text_generation() -> None:
@@ -24,3 +28,11 @@ def test_huggingface_call_error() -> None:
llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
with pytest.raises(ValueError):
llm("Say foo:")
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an HuggingFaceHub LLM."""
llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
llm.save(file_path=tmp_path / "hf.yaml")
loaded_llm = load_llm(tmp_path / "hf.yaml")
assert_llm_equality(llm, loaded_llm)

tests/integration_tests/llms/test_nlpcloud.py

@@ -1,6 +1,10 @@
"""Test NLPCloud API wrapper."""
from pathlib import Path
from langchain.llms.loading import load_llm
from langchain.llms.nlpcloud import NLPCloud
from tests.integration_tests.llms.utils import assert_llm_equality
def test_nlpcloud_call() -> None:
@@ -8,3 +12,11 @@ def test_nlpcloud_call() -> None:
llm = NLPCloud(max_length=10)
output = llm("Say foo:")
assert isinstance(output, str)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an NLPCloud LLM."""
llm = NLPCloud(max_length=10)
llm.save(file_path=tmp_path / "nlpcloud.yaml")
loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
assert_llm_equality(llm, loaded_llm)

tests/integration_tests/llms/test_openai.py

@@ -1,7 +1,10 @@
"""Test OpenAI API wrapper."""
from pathlib import Path
import pytest
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI
@@ -44,3 +47,11 @@ def test_openai_stop_error() -> None:
llm = OpenAI(stop="3", temperature=0)
with pytest.raises(ValueError):
llm("write an ordered list of five items", stop=["\n"])
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an OpenAPI LLM."""
llm = OpenAI(max_tokens=10)
llm.save(file_path=tmp_path / "openai.yaml")
loaded_llm = load_llm(tmp_path / "openai.yaml")
assert loaded_llm == llm

tests/integration_tests/llms/utils.py (new file)

@@ -0,0 +1,16 @@
"""Utils for LLM Tests."""
from langchain.llms.base import LLM
def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
"""Assert LLM Equality for tests."""
# Check that they are the same type.
assert type(llm) == type(loaded_llm)
# The client field can be session-based, so its hash differs even when
# all other values are the same; compare all other fields instead.
for field in llm.__fields__.keys():
if field != "client":
val = getattr(llm, field)
new_val = getattr(loaded_llm, field)
assert new_val == val
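This helper exists because plain == can fail for logically identical wrappers: the client field holds a session object built at construction time, so its identity differs between the saved and the loaded instance. A sketch of the intended usage (path illustrative):

from langchain.llms.cohere import Cohere
from langchain.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality

llm = Cohere(max_tokens=10)
llm.save(file_path="cohere.yaml")
loaded_llm = load_llm("cohere.yaml")  # builds a fresh client on load
assert_llm_equality(llm, loaded_llm)  # passes even though llm.client is not loaded_llm.client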

tests/unit_tests/agents/test_agent.py

@@ -2,17 +2,17 @@
from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.agents import Tool, initialize_agent
from langchain.llms.base import LLM
-class FakeListLLM(LLM):
+class FakeListLLM(LLM, BaseModel):
"""Fake LLM for testing that outputs elements of a list."""
-def __init__(self, responses: List[str]):
-    """Initialize with list of responses."""
-    self.responses = responses
-    self.i = -1
+responses: List[str]
+i: int = -1
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Increment counter, and then return response in that index."""
@@ -25,6 +25,11 @@ class FakeListLLM(LLM):
def _identifying_params(self) -> Mapping[str, Any]:
return {}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake_list"
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
@@ -33,7 +38,7 @@ def test_agent_bad_action() -> None:
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
-fake_llm = FakeListLLM(responses)
+fake_llm = FakeListLLM(responses=responses)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),

tests/unit_tests/agents/test_react.py

@@ -2,6 +2,8 @@
from typing import Any, List, Mapping, Optional, Union
from pydantic import BaseModel
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool
from langchain.docstore.base import Docstore
@@ -20,13 +22,16 @@ Made in 2022."""
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
-class FakeListLLM(LLM):
+class FakeListLLM(LLM, BaseModel):
"""Fake LLM for testing that outputs elements of a list."""
-def __init__(self, responses: List[str]):
-    """Initialize with list of responses."""
-    self.responses = responses
-    self.i = -1
+responses: List[str]
+i: int = -1
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake_list"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Increment counter, and then return response in that index."""
@@ -50,7 +55,7 @@ class FakeDocstore(Docstore):
def test_predict_until_observation_normal() -> None:
"""Test predict_until_observation when observation is made normally."""
outputs = ["foo\nAction 1: Search[foo]"]
-fake_llm = FakeListLLM(outputs)
+fake_llm = FakeListLLM(responses=outputs)
tools = [
Tool("Search", lambda x: x),
Tool("Lookup", lambda x: x),
@@ -65,7 +70,7 @@ def test_predict_until_observation_normal() -> None:
def test_predict_until_observation_repeat() -> None:
"""Test when no action is generated initially."""
outputs = ["foo", " Search[foo]"]
-fake_llm = FakeListLLM(outputs)
+fake_llm = FakeListLLM(responses=outputs)
tools = [
Tool("Search", lambda x: x),
Tool("Lookup", lambda x: x),
@@ -84,7 +89,7 @@ def test_react_chain() -> None:
"I should probably lookup\nAction 2: Lookup[made]",
"Ah okay now I know the answer\nAction 3: Finish[2022]",
]
-fake_llm = FakeListLLM(responses)
+fake_llm = FakeListLLM(responses=responses)
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
output = react_chain.run("when was langchain made")
assert output == "2022"
@@ -97,7 +102,7 @@ def test_react_chain_bad_action() -> None:
f"I'm turning evil\nAction 1: {bad_action_name}[langchain]",
"Oh well\nAction 2: Finish[curses foiled again]",
]
-fake_llm = FakeListLLM(responses)
+fake_llm = FakeListLLM(responses=responses)
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
output = react_chain.run("when was langchain made")
assert output == "curses foiled again"

tests/unit_tests/chains/test_natbot.py

@@ -2,11 +2,13 @@
from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.chains.natbot.base import NatBotChain
from langchain.llms.base import LLM
-class FakeLLM(LLM):
+class FakeLLM(LLM, BaseModel):
"""Fake LLM wrapper for testing purposes."""
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
@@ -16,6 +18,11 @@ class FakeLLM(LLM):
else:
return "bar"
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake"
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}

tests/unit_tests/llms/fake_llm.py

@@ -1,20 +1,25 @@
"""Fake LLM wrapper for testing purposes."""
from typing import Any, List, Mapping, Optional
from pydantic import BaseModel
from langchain.llms.base import LLM
-class FakeLLM(LLM):
+class FakeLLM(LLM, BaseModel):
"""Fake LLM wrapper for testing purposes."""
-def __init__(self, queries: Optional[Mapping] = None):
-    """Initialize with optional lookup of queries."""
-    self._queries = queries
+queries: Optional[Mapping] = None
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake"
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
-if self._queries is not None:
-    return self._queries[prompt]
+if self.queries is not None:
+    return self.queries[prompt]
if stop is None:
return "foo"
else:

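With queries now a pydantic field instead of an attribute set in __init__, the fake is constructed by keyword; a minimal sketch:

from tests.unit_tests.llms.fake_llm import FakeLLM

llm = FakeLLM(queries={"Say foo:": "foo!"})
assert llm("Say foo:") == "foo!"       # looked up from queries
assert FakeLLM()("anything") == "foo"  # no queries and no stop: falls through to "foo"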
tests/unit_tests/llms/test_loading.py (new file)

@@ -0,0 +1,15 @@
"""Test LLM saving and loading functions."""
from pathlib import Path
from unittest.mock import patch
from langchain.llms.loading import load_llm
from tests.unit_tests.llms.fake_llm import FakeLLM
@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM})
def test_saving_loading_round_trip(tmp_path: Path) -> None:
"""Test saving/loading a Fake LLM."""
fake_llm = FakeLLM()
fake_llm.save(file_path=tmp_path / "fake_llm.yaml")
loaded_llm = load_llm(tmp_path / "fake_llm.yaml")
assert loaded_llm == fake_llm