mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
Add HuggingFacePipeline LLM (#353)
https://github.com/hwchase17/langchain/issues/354 Add support for running your own HF pipeline locally. This would allow you to get a lot more dynamic with what HF features and models you support since you wouldn't be beholden to what is hosted in HF hub. You could also do stuff with HF Optimum to quantize your models and stuff to get pretty fast inference even running on a laptop.
This commit is contained in:
parent
2eef76ed3f
commit
fe6695b9e7
@ -18,6 +18,7 @@ from langchain.chains import (
|
||||
)
|
||||
from langchain.docstore import InMemoryDocstore, Wikipedia
|
||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.logger import BaseLogger, StdOutLogger
|
||||
from langchain.prompts import (
|
||||
BasePromptTemplate,
|
||||
@ -50,6 +51,7 @@ __all__ = [
|
||||
"ReActChain",
|
||||
"Wikipedia",
|
||||
"HuggingFaceHub",
|
||||
"HuggingFacePipeline",
|
||||
"SQLDatabase",
|
||||
"SQLDatabaseChain",
|
||||
"FAISS",
|
||||
|
@ -5,10 +5,18 @@ from langchain.llms.ai21 import AI21
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.cohere import Cohere
|
||||
from langchain.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.llms.nlpcloud import NLPCloud
|
||||
from langchain.llms.openai import OpenAI
|
||||
|
||||
__all__ = ["Cohere", "NLPCloud", "OpenAI", "HuggingFaceHub", "AI21"]
|
||||
__all__ = [
|
||||
"Cohere",
|
||||
"NLPCloud",
|
||||
"OpenAI",
|
||||
"HuggingFaceHub",
|
||||
"HuggingFacePipeline",
|
||||
"AI21",
|
||||
]
|
||||
|
||||
type_to_cls_dict: Dict[str, Type[LLM]] = {
|
||||
"ai21": AI21,
|
||||
@ -16,4 +24,5 @@ type_to_cls_dict: Dict[str, Type[LLM]] = {
|
||||
"huggingface_hub": HuggingFaceHub,
|
||||
"nlpcloud": NLPCloud,
|
||||
"openai": OpenAI,
|
||||
"huggingface_pipeline": HuggingFacePipeline,
|
||||
}
|
||||
|
118
langchain/llms/huggingface_pipeline.py
Normal file
118
langchain/llms/huggingface_pipeline.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""Wrapper around HuggingFace Pipeline APIs."""
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
DEFAULT_TASK = "text-generation"
|
||||
VALID_TASKS = ("text2text-generation", "text-generation")
|
||||
|
||||
|
||||
class HuggingFacePipeline(LLM, BaseModel):
|
||||
"""Wrapper around HuggingFace Pipeline API.
|
||||
|
||||
To use, you should have the ``transformers`` python package installed.
|
||||
|
||||
Only supports `text-generation` and `text2text-generation` for now.
|
||||
|
||||
Example using from_model_id:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
hf = HuggingFacePipeline.from_model_id(
|
||||
model_id="gpt2", task="text-generation"
|
||||
)
|
||||
Example passing pipeline in directly:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
model_id = "gpt2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
pipe = pipeline(
|
||||
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
|
||||
)
|
||||
hf = HuggingFacePipeline(pipeline=pipe
|
||||
"""
|
||||
|
||||
pipeline: Any #: :meta private:
|
||||
model_id: str = DEFAULT_MODEL_ID
|
||||
"""Model name to use."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@classmethod
|
||||
def from_model_id(
|
||||
cls,
|
||||
model_id: str,
|
||||
task: str,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLM:
|
||||
"""Construct the pipeline object from model_id and task."""
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import pipeline as hf_pipeline
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
pipeline = hf_pipeline(
|
||||
task=task, model=model, tokenizer=tokenizer, **model_kwargs
|
||||
)
|
||||
if pipeline.task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
f"Got invalid task {pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
|
||||
return cls(
|
||||
pipeline=pipeline,
|
||||
model_id=model_id,
|
||||
model_kwargs=model_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"Please it install it with `pip install transformers`."
|
||||
)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"model_id": self.model_id},
|
||||
**{"model_kwargs": self.model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface_pipeline"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
response = self.pipeline(text_inputs=prompt)
|
||||
if self.pipeline.task == "text-generation":
|
||||
# Text generation return includes the starter text.
|
||||
text = response[0]["generated_text"][len(prompt) :]
|
||||
elif self.pipeline.task == "text2text-generation":
|
||||
text = response[0]["generated_text"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to huggingface_hub.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
192
poetry.lock
generated
192
poetry.lock
generated
@ -640,7 +640,7 @@ arrow = ">=0.15.0"
|
||||
|
||||
[[package]]
|
||||
name = "isort"
|
||||
version = "5.11.1"
|
||||
version = "5.11.2"
|
||||
description = "A Python utility / library to sort Python imports."
|
||||
category = "dev"
|
||||
optional = false
|
||||
@ -1181,6 +1181,54 @@ category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cublas-cu11"
|
||||
version = "11.10.3.66"
|
||||
description = "CUBLAS native runtime libraries"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
|
||||
[package.dependencies]
|
||||
setuptools = "*"
|
||||
wheel = "*"
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvrtc-cu11"
|
||||
version = "11.7.99"
|
||||
description = "NVRTC native runtime libraries"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
|
||||
[package.dependencies]
|
||||
setuptools = "*"
|
||||
wheel = "*"
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-runtime-cu11"
|
||||
version = "11.7.99"
|
||||
description = "CUDA Runtime native Libraries"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
|
||||
[package.dependencies]
|
||||
setuptools = "*"
|
||||
wheel = "*"
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cudnn-cu11"
|
||||
version = "8.5.0.96"
|
||||
description = "cuDNN runtime libraries"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
|
||||
[package.dependencies]
|
||||
setuptools = "*"
|
||||
wheel = "*"
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "22.0"
|
||||
@ -1750,7 +1798,7 @@ python-versions = ">=3.6"
|
||||
|
||||
[[package]]
|
||||
name = "spacy"
|
||||
version = "3.4.3"
|
||||
version = "3.4.4"
|
||||
description = "Industrial-strength Natural Language Processing (NLP) in Python"
|
||||
category = "main"
|
||||
optional = true
|
||||
@ -1769,6 +1817,7 @@ preshed = ">=3.0.2,<3.1.0"
|
||||
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<1.11.0"
|
||||
requests = ">=2.13.0,<3.0.0"
|
||||
setuptools = "*"
|
||||
smart-open = ">=5.2.1,<7.0.0"
|
||||
spacy-legacy = ">=3.0.10,<3.1.0"
|
||||
spacy-loggers = ">=1.0.0,<2.0.0"
|
||||
srsly = ">=2.4.3,<3.0.0"
|
||||
@ -1997,6 +2046,24 @@ category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "1.13.1"
|
||||
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3.7.0"
|
||||
|
||||
[package.dependencies]
|
||||
nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""}
|
||||
nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
|
||||
nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
|
||||
nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""}
|
||||
typing-extensions = "*"
|
||||
|
||||
[package.extras]
|
||||
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.2"
|
||||
@ -2227,6 +2294,17 @@ docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"]
|
||||
optional = ["python-socks", "wsaccel"]
|
||||
test = ["websockets"]
|
||||
|
||||
[[package]]
|
||||
name = "wheel"
|
||||
version = "0.38.4"
|
||||
description = "A built-package format for Python"
|
||||
category = "main"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.extras]
|
||||
test = ["pytest (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "widgetsnbextension"
|
||||
version = "4.0.4"
|
||||
@ -2260,13 +2338,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
|
||||
testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
||||
|
||||
[extras]
|
||||
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken"]
|
||||
llms = ["manifest-ml"]
|
||||
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"]
|
||||
llms = ["manifest-ml", "torch", "transformers"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "45b0665b465cf551601e1a18f5b99c9a66c99313de21292f2b897d2ec42d2f6a"
|
||||
content-hash = "d9ed2c5e1b2c51d7f8a9f74c858ab5058db14b7d6ca542777d7ecd07ccef5ee8"
|
||||
|
||||
[metadata.files]
|
||||
anyio = [
|
||||
@ -2757,8 +2835,8 @@ isoduration = [
|
||||
{file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"},
|
||||
]
|
||||
isort = [
|
||||
{file = "isort-5.11.1-py3-none-any.whl", hash = "sha256:bf02c95f1fe615ebbe13a619cfed1619ddfe8941274c9e3de3143adca406cb02"},
|
||||
{file = "isort-5.11.1.tar.gz", hash = "sha256:7c5bd998504826b6f1e6f2f98b533976b066baba29b8bae83fdeefd0b89c6b70"},
|
||||
{file = "isort-5.11.2-py3-none-any.whl", hash = "sha256:e486966fba83f25b8045f8dd7455b0a0d1e4de481e1d7ce4669902d9fb85e622"},
|
||||
{file = "isort-5.11.2.tar.gz", hash = "sha256:dd8bbc5c0990f2a095d754e50360915f73b4c26fc82733eb5bfc6b48396af4d2"},
|
||||
]
|
||||
jedi = [
|
||||
{file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"},
|
||||
@ -3084,6 +3162,23 @@ numpy = [
|
||||
{file = "numpy-1.23.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d"},
|
||||
{file = "numpy-1.23.5.tar.gz", hash = "sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a"},
|
||||
]
|
||||
nvidia-cublas-cu11 = [
|
||||
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"},
|
||||
{file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"},
|
||||
]
|
||||
nvidia-cuda-nvrtc-cu11 = [
|
||||
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"},
|
||||
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"},
|
||||
{file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"},
|
||||
]
|
||||
nvidia-cuda-runtime-cu11 = [
|
||||
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"},
|
||||
{file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"},
|
||||
]
|
||||
nvidia-cudnn-cu11 = [
|
||||
{file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"},
|
||||
{file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"},
|
||||
]
|
||||
packaging = [
|
||||
{file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"},
|
||||
{file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"},
|
||||
@ -3622,34 +3717,34 @@ soupsieve = [
|
||||
{file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"},
|
||||
]
|
||||
spacy = [
|
||||
{file = "spacy-3.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e546b314f619502ae03e5eb9a0cfd09ca7a9db265bcdd8a3af83cfb0f1432e55"},
|
||||
{file = "spacy-3.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ded11aa8966236aab145b4d2d024b3eb61ac50078362d77d9ed7d8c240ef0f4a"},
|
||||
{file = "spacy-3.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:462e141f514d78cff85685b5b12eb8cadac0bad2f7820149cbe18d03ccb2e59c"},
|
||||
{file = "spacy-3.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c966d25b3f3e49f5de08546b3638928f49678c365cbbebd0eec28f74e0adb539"},
|
||||
{file = "spacy-3.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:2ddba486c4c981abe6f1e3fd72648dc8811966e5f0e05808f9c9fab155c388d7"},
|
||||
{file = "spacy-3.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c87117dd335fba44d1c0d77602f0763c3addf4e7ef9bdbe9a495466c3484c69"},
|
||||
{file = "spacy-3.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ce3938720f48eaeeb360a7f623f15a0d9efd1a688d5d740e3d4cdcd6f6da8a3"},
|
||||
{file = "spacy-3.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ad6bf5e4e7f0bc2ef94b7ff6fe59abd766f74c192bca2f17430a3b3cd5bda5a"},
|
||||
{file = "spacy-3.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6644c678bd7af567c6ce679f71d64119282e7d6f1a6f787162a91be3ea39333"},
|
||||
{file = "spacy-3.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:e6b871de8857a6820140358db3943180fdbe03d44ed792155cee6cb95f4ac4ea"},
|
||||
{file = "spacy-3.4.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d211c2b8894354bf8d961af9a9dcab38f764e1dcddd7b80760e438fcd4c9fe43"},
|
||||
{file = "spacy-3.4.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ea41f9de30435456235c4182d8bc2eb54a0a64719856e66e780350bb4c8cfbe"},
|
||||
{file = "spacy-3.4.3-cp36-cp36m-win_amd64.whl", hash = "sha256:afaf6e716cbac4a0fbfa9e9bf95decff223936597ddd03ea869118a7576aa1b1"},
|
||||
{file = "spacy-3.4.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7115da36369b3c537caf2fe08e0b45528bd091c7f56ba3580af1e6fdfa9b1081"},
|
||||
{file = "spacy-3.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b3e629c889cac9656151286ec1232c6a948ce0d44a39f1ef5e60fed4f183a10"},
|
||||
{file = "spacy-3.4.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9277cd0fcb96ee5dd885f7e96c639f21afd96198d61ca32100446afbff4dfbef"},
|
||||
{file = "spacy-3.4.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a36bd06a5a147350e5f5f6903c4777296c37b18199251bb41056c3a73aa4494f"},
|
||||
{file = "spacy-3.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bdafcd0823ca804c39d0bed9e677eb7d0235b1259563d0fd4d3a201c71108af8"},
|
||||
{file = "spacy-3.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cdc23a48e6543402b4c56ebf2d36246001175c29fd56d3081efcec684651abc"},
|
||||
{file = "spacy-3.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:455c2fbd1de24b6fe34fa121d87525134d7498f9f458ebc8274d7940b473999e"},
|
||||
{file = "spacy-3.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1c85279fbb6b75d7fb8d7c59c2b734502e51271cad90926e8df1d21b67da5aa"},
|
||||
{file = "spacy-3.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:5c0d65f39184f522b4e67b965a42d121a3b2d799362682fe8847b64b0ce5bc7c"},
|
||||
{file = "spacy-3.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a7b97ec21ed773edb2479ae5d6c7686b8034f418df6bccd9218f5c3c2b7cf888"},
|
||||
{file = "spacy-3.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:36a9a506029842795099fd97ad95f0da2845c319020fcc7164cbf33650726f83"},
|
||||
{file = "spacy-3.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ab293eb1423fa05c7ee71b2fedda57c2b4a4ca8dc054ce678809457287b01dc"},
|
||||
{file = "spacy-3.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb6d0f185126decc8392cde7d28eb6e85ba4bca15424713288cccc49c2a3c52b"},
|
||||
{file = "spacy-3.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:676ab9ab2cf94ba48caa306f185a166e85bd35b388ec24512c8ba7dfcbc7517e"},
|
||||
{file = "spacy-3.4.3.tar.gz", hash = "sha256:22698cf5175e2b697e82699fcccee3092b42137a57d352df208d71657fd693bb"},
|
||||
{file = "spacy-3.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:07a10999a3e37f896758a92c2eed263638bcbf2747dc3a4aeea929aaa20ea28c"},
|
||||
{file = "spacy-3.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e6d98511dc8a88d3a96bcae13971a284459362076738c85053d1a3791f6cde92"},
|
||||
{file = "spacy-3.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2cad9c5543f03b3375c252e4dd45670ee8ed99c925dca15eadab5084fd1b033"},
|
||||
{file = "spacy-3.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ade19c1e676cac2546f268db22bc5eba08d12beafabe80f1b9f06028b3a0b52"},
|
||||
{file = "spacy-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:e782c8a7c4805cc1b34ed2b11f72a5cf2b9851e20f7afe3e97caf206f19f761b"},
|
||||
{file = "spacy-3.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aa027e69ef9fe42c8b02b940872e5bde0ce1bf66b6bf488c6493e3ce660c4b3a"},
|
||||
{file = "spacy-3.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddeb5d725b6fa9c9009b1ff645db8f5caab9ed8956ee3a84b8379951caad1d36"},
|
||||
{file = "spacy-3.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29d6bb428a6bb19e026d8bbb9d4385c25b21e1ce51fcaabadfb5599b2390a79c"},
|
||||
{file = "spacy-3.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a21187ad4c44e166dc3deed23992ea1a74d731c9a6bdd9fca306d455181577fa"},
|
||||
{file = "spacy-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:10643c6d335a02805f6676738a3e992323cfd9438115cc253435e5053dc93824"},
|
||||
{file = "spacy-3.4.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:486228cfa7ced18ec99008388028bd2329262ab8108e7c19252c1a67b2801909"},
|
||||
{file = "spacy-3.4.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb7a213178c298b95532075d6dddfb374bbe56ef8d2687212763b4583048da2"},
|
||||
{file = "spacy-3.4.4-cp36-cp36m-win_amd64.whl", hash = "sha256:15e5c41d408d1d30d8f3dd8e4eed9ed28e6174e011b8d61c1345981562e2e8f5"},
|
||||
{file = "spacy-3.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8979dbd3594c5c268cedad53f456a3ec3a0a2b78a1199788aacedcd68eef3a00"},
|
||||
{file = "spacy-3.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f4736fea2630e696422dfe38bfb3d0a7864bc6a9072d6e49a906af46870e36e"},
|
||||
{file = "spacy-3.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498bf01e8c7ab601c3f8d6c51497817b40a3322a3967c032536b18ce9ea26d0a"},
|
||||
{file = "spacy-3.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:95f880c6fea57d51c448ad84f96d79d8758e5e18bdbaaee060c15af11641079b"},
|
||||
{file = "spacy-3.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ccbede9be470c5d795168bf3be41fc86e18892a9247a742b394ba866c005391"},
|
||||
{file = "spacy-3.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2f1edbecfde9c11b17e87768bb5f2c33948fb1e3bf54b2197031ff9053607277"},
|
||||
{file = "spacy-3.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66eaf4764e95699934cbd8f38717b283db185c896cfd3d1fb1ad5c6552e8b3c9"},
|
||||
{file = "spacy-3.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0bb7d53f1a780bb8cc1b27a81e02e8b9bc71abb959f4dc13c21af4041fdd2c7a"},
|
||||
{file = "spacy-3.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:c1a5ce5c9b19cdfb4469079e710e72bb09c3cab855f21ef6a614b84c765e0311"},
|
||||
{file = "spacy-3.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f7044dca3542579ea1e3ac6cdd821640c2f65dd0c56230688f36e15aca1b8217"},
|
||||
{file = "spacy-3.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8a495b0fc00910fb5c1fbe64fdbfe1d3c11b09f421d1ae4e30cdb4c2388a91e4"},
|
||||
{file = "spacy-3.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31e9a637960b60c1bb7a36a187271425717e97c14e9d1df613dc4efeffefcbec"},
|
||||
{file = "spacy-3.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71f9449ffadef85b048c9735ee235da5dca9d0a87038dba6d4ed20c5188e0f13"},
|
||||
{file = "spacy-3.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:1b7791a6c0592615b0566001596cc48c72325d1b97e46e574c91bff34f4e3f4c"},
|
||||
{file = "spacy-3.4.4.tar.gz", hash = "sha256:e500cf2cb5f1849461a7928fa269703756069bdfb71559065240af6d0208b08c"},
|
||||
]
|
||||
spacy-legacy = [
|
||||
{file = "spacy-legacy-3.0.10.tar.gz", hash = "sha256:16104595d8ab1b7267f817a449ad1f986eb1f2a2edf1050748f08739a479679a"},
|
||||
@ -3837,6 +3932,29 @@ tomli = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
torch = [
|
||||
{file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"},
|
||||
{file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"},
|
||||
{file = "torch-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76"},
|
||||
{file = "torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566"},
|
||||
{file = "torch-1.13.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5"},
|
||||
{file = "torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b"},
|
||||
{file = "torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11"},
|
||||
{file = "torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419"},
|
||||
{file = "torch-1.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93"},
|
||||
{file = "torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380"},
|
||||
{file = "torch-1.13.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28"},
|
||||
{file = "torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd"},
|
||||
{file = "torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe"},
|
||||
{file = "torch-1.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea"},
|
||||
{file = "torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312"},
|
||||
{file = "torch-1.13.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b"},
|
||||
{file = "torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038"},
|
||||
{file = "torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc"},
|
||||
{file = "torch-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7"},
|
||||
{file = "torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf"},
|
||||
{file = "torch-1.13.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949"},
|
||||
]
|
||||
tornado = [
|
||||
{file = "tornado-6.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:20f638fd8cc85f3cbae3c732326e96addff0a15e22d80f049e00121651e82e72"},
|
||||
{file = "tornado-6.2-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:87dcafae3e884462f90c90ecc200defe5e580a7fbbb4365eda7c7c1eb809ebc9"},
|
||||
@ -3914,6 +4032,10 @@ websocket-client = [
|
||||
{file = "websocket-client-1.4.2.tar.gz", hash = "sha256:d6e8f90ca8e2dd4e8027c4561adeb9456b54044312dba655e7cae652ceb9ae59"},
|
||||
{file = "websocket_client-1.4.2-py3-none-any.whl", hash = "sha256:d6b06432f184438d99ac1f456eaf22fe1ade524c3dd16e661142dc54e9cba574"},
|
||||
]
|
||||
wheel = [
|
||||
{file = "wheel-0.38.4-py3-none-any.whl", hash = "sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8"},
|
||||
{file = "wheel-0.38.4.tar.gz", hash = "sha256:965f5259b566725405b05e7cf774052044b1ed30119b5d586b2703aafe8719ac"},
|
||||
]
|
||||
widgetsnbextension = [
|
||||
{file = "widgetsnbextension-4.0.4-py3-none-any.whl", hash = "sha256:fa0e840719ec95dd2ec85c3a48913f1a0c29d323eacbcdb0b29bfed0cc6da678"},
|
||||
{file = "widgetsnbextension-4.0.4.tar.gz", hash = "sha256:44c69f18237af0f610557d6c1c7ef76853f5856a0e604c0a517f2320566bb775"},
|
||||
|
@ -22,6 +22,7 @@ spacy = {version = "^3", optional = true}
|
||||
nltk = {version = "^3", optional = true}
|
||||
transformers = {version = "^4", optional = true}
|
||||
beautifulsoup4 = {version = "^4", optional = true}
|
||||
torch = {version = "^1.13.1", optional = true}
|
||||
tiktoken = {version = "^0", optional = true, python="^3.9"}
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
@ -49,8 +50,8 @@ jupyter = "^1.0.0"
|
||||
playwright = "^1.28.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml"]
|
||||
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken"]
|
||||
llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
|
||||
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
41
tests/integration_tests/llms/test_huggingface_pipeline.py
Normal file
41
tests/integration_tests/llms/test_huggingface_pipeline.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Test HuggingFace Pipeline wrapper."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||
from langchain.llms.loading import load_llm
|
||||
from tests.integration_tests.llms.utils import assert_llm_equality
|
||||
|
||||
|
||||
def test_huggingface_pipeline_text_generation() -> None:
|
||||
"""Test valid call to HuggingFace text generation model."""
|
||||
llm = HuggingFacePipeline.from_model_id(
|
||||
model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10}
|
||||
)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||
"""Test saving/loading an HuggingFaceHub LLM."""
|
||||
llm = HuggingFacePipeline.from_model_id(
|
||||
model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10}
|
||||
)
|
||||
llm.save(file_path=tmp_path / "hf.yaml")
|
||||
loaded_llm = load_llm(tmp_path / "hf.yaml")
|
||||
assert_llm_equality(llm, loaded_llm)
|
||||
|
||||
|
||||
def test_init_with_pipeline() -> None:
|
||||
"""Test initialization with a HF pipeline."""
|
||||
model_id = "gpt2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
pipe = pipeline(
|
||||
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
|
||||
)
|
||||
llm = HuggingFacePipeline(pipeline=pipe)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
@ -10,7 +10,7 @@ def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
|
||||
# Client field can be session based, so hash is different despite
|
||||
# all other values being the same, so just assess all other fields
|
||||
for field in llm.__fields__.keys():
|
||||
if field != "client":
|
||||
if field != "client" and field != "pipeline":
|
||||
val = getattr(llm, field)
|
||||
new_val = getattr(loaded_llm, field)
|
||||
assert new_val == val
|
||||
|
Loading…
Reference in New Issue
Block a user