Add HuggingFacePipeline LLM (#353)

https://github.com/hwchase17/langchain/issues/354

Add support for running your own HF pipeline locally. This lets you be
much more flexible about which HF features and models you support, since
you are no longer limited to what is hosted on the HF Hub. You could
also use HF Optimum to quantize your models and get fairly fast
inference even when running on a laptop.
This commit is contained in:
mrbean
2022-12-17 10:00:04 -05:00
committed by GitHub
parent 2eef76ed3f
commit fe6695b9e7
7 changed files with 332 additions and 39 deletions

View File

@@ -0,0 +1,41 @@
"""Test HuggingFace Pipeline wrapper."""
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
def test_huggingface_pipeline_text_generation() -> None:
    """Check that a gpt2 text-generation pipeline LLM returns a string."""
    # Cap generation at 10 new tokens so the call stays short.
    generation_kwargs = {"max_new_tokens": 10}
    hf_llm = HuggingFacePipeline.from_model_id(
        model_id="gpt2",
        task="text-generation",
        model_kwargs=generation_kwargs,
    )
    completion = hf_llm("Say foo:")
    assert isinstance(completion, str)
def test_saving_loading_llm(tmp_path: Path) -> None:
    """Round-trip a HuggingFacePipeline LLM through a YAML file and compare.

    The LLM is saved to ``hf.yaml`` under ``tmp_path``, loaded back with
    ``load_llm``, and checked against the original with
    ``assert_llm_equality``.
    """
    config_file = tmp_path / "hf.yaml"
    original = HuggingFacePipeline.from_model_id(
        model_id="gpt2",
        task="text-generation",
        model_kwargs={"max_new_tokens": 10},
    )
    original.save(file_path=config_file)
    restored = load_llm(config_file)
    assert_llm_equality(original, restored)
def test_init_with_pipeline() -> None:
    """Build the LLM from an existing transformers pipeline and call it."""
    checkpoint = "gpt2"
    hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    hf_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    # Assemble the pipeline by hand rather than going through from_model_id.
    text_generator = pipeline(
        "text-generation",
        model=hf_model,
        tokenizer=hf_tokenizer,
        max_new_tokens=10,
    )
    wrapped_llm = HuggingFacePipeline(pipeline=text_generator)
    result = wrapped_llm("Say foo:")
    assert isinstance(result, str)

View File

@@ -10,7 +10,7 @@ def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
# Client field can be session based, so hash is different despite
# all other values being the same, so just assess all other fields
for field in llm.__fields__.keys():
if field != "client":
if field != "client" and field != "pipeline":
val = getattr(llm, field)
new_val = getattr(loaded_llm, field)
assert new_val == val