huggingface[patch]: Support for HuggingFacePipeline in ChatHuggingFace. (#22194)
- **Description:** Added support for using HuggingFacePipeline in ChatHuggingFace (previously it was only usable with API endpoints, probably by oversight).
- **Issue:** #19997
- **Dependencies:** none
- **Twitter handle:** none

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
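In practice, the change enables roughly the following usage (a minimal sketch based on the docs notebook updated below; the model choice and generation settings are illustrative, and running it downloads the model weights locally):

```python
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

# Build a locally executed LLM backed by a transformers pipeline.
llm = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    pipeline_kwargs=dict(max_new_tokens=512),
)

# ChatHuggingFace previously required an endpoint-backed LLM; after this
# change it also accepts a HuggingFacePipeline directly.
chat = ChatHuggingFace(llm=llm)
print(chat.invoke("Briefly explain what a tokenizer does.").content)
```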
@@ -58,6 +58,62 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `HuggingFacePipeline`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_huggingface import HuggingFacePipeline\n",
+    "\n",
+    "llm = HuggingFacePipeline.from_model_id(\n",
+    "    model_id=\"HuggingFaceH4/zephyr-7b-beta\",\n",
+    "    task=\"text-generation\",\n",
+    "    pipeline_kwargs=dict(\n",
+    "        max_new_tokens=512,\n",
+    "        do_sample=False,\n",
+    "        repetition_penalty=1.03,\n",
+    "    ),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To run a quantized version, you might specify a `bitsandbytes` quantization config as follows:\n",
+    "\n",
+    "```python\n",
+    "from transformers import BitsAndBytesConfig\n",
+    "\n",
+    "quantization_config = BitsAndBytesConfig(\n",
+    "    load_in_4bit=True,\n",
+    "    bnb_4bit_quant_type=\"nf4\",\n",
+    "    bnb_4bit_compute_dtype=\"float16\",\n",
+    "    bnb_4bit_use_double_quant=True\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "and pass it to the `HuggingFacePipeline` as a part of its `model_kwargs`:\n",
+    "\n",
+    "```python\n",
+    "pipeline = HuggingFacePipeline(\n",
+    "    ...\n",
+    "\n",
+    "    model_kwargs={\"quantization_config\": quantization_config},\n",
+    "\n",
+    "    ...\n",
+    ")\n",
+    "```"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
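The hunk above extends the docs notebook. Its quantized example elides some arguments with `...`; under the assumption that `bitsandbytes` and a CUDA-capable GPU are available, a fuller sketch using `from_model_id` (which forwards `model_kwargs` to the underlying `from_pretrained` call) might look like this:

```python
from transformers import BitsAndBytesConfig

from langchain_huggingface import HuggingFacePipeline

# 4-bit NF4 quantization, mirroring the config in the notebook cell above.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
)

llm = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    pipeline_kwargs=dict(max_new_tokens=512),
    model_kwargs={"quantization_config": quantization_config},
)
```

The remaining hunks update `ChatHuggingFace` itself to recognize the new LLM type.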
@@ -35,6 +35,7 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
 
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
 
@@ -135,6 +136,10 @@ def _is_huggingface_endpoint(llm: Any) -> bool:
     return isinstance(llm, HuggingFaceEndpoint)
 
 
+def _is_huggingface_pipeline(llm: Any) -> bool:
+    return isinstance(llm, HuggingFacePipeline)
+
+
 class ChatHuggingFace(BaseChatModel):
     """
     Wrapper for using Hugging Face LLM's as ChatModels.
@@ -150,8 +155,8 @@ class ChatHuggingFace(BaseChatModel):
     """
 
     llm: Any
-    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
-    HuggingFaceHub."""
+    """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint,
+    HuggingFaceHub, or HuggingFacePipeline."""
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
     model_id: Optional[str] = None
@@ -175,10 +180,12 @@ class ChatHuggingFace(BaseChatModel):
             not _is_huggingface_hub(values["llm"])
             and not _is_huggingface_textgen_inference(values["llm"])
             and not _is_huggingface_endpoint(values["llm"])
+            and not _is_huggingface_pipeline(values["llm"])
         ):
             raise TypeError(
                 "Expected llm to be one of HuggingFaceTextGenInference, "
-                f"HuggingFaceEndpoint, HuggingFaceHub, received {type(values['llm'])}"
+                "HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
+                f"received {type(values['llm'])}"
             )
         return values
 
@@ -293,6 +300,9 @@ class ChatHuggingFace(BaseChatModel):
             return
         elif _is_huggingface_textgen_inference(self.llm):
             endpoint_url: Optional[str] = self.llm.inference_server_url
+        elif _is_huggingface_pipeline(self.llm):
+            self.model_id = self.llm.model_id
+            return
         else:
             endpoint_url = self.llm.endpoint_url
 
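Taken together, the validator and `_resolve_model_id` changes behave roughly as follows (a hypothetical check; model and tokenizer downloads happen at construction time):

```python
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceH4/zephyr-7b-beta", task="text-generation"
)
chat = ChatHuggingFace(llm=llm)

# model_id is now resolved from the wrapped pipeline rather than an endpoint URL.
assert chat.model_id == "HuggingFaceH4/zephyr-7b-beta"

# Unsupported llm types still fail validation; the TypeError may surface
# wrapped in a pydantic ValidationError depending on the pydantic version.
try:
    ChatHuggingFace(llm="not-an-llm")
except Exception as err:
    print(err)
```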