mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(langchain-classic): fix init_chat_model for HuggingFace models (#33943)
This commit is contained in:
@@ -627,8 +627,54 @@ class ChatHuggingFace(BaseChatModel):
|
||||
HuggingFacePipeline,
|
||||
)
|
||||
|
||||
task = task if task is not None else "text-generation"
|
||||
|
||||
# Separate pipeline-specific kwargs from ChatHuggingFace kwargs
|
||||
# Parameters that should go to HuggingFacePipeline.from_model_id
|
||||
pipeline_specific_kwargs = {}
|
||||
|
||||
# Extract pipeline-specific parameters
|
||||
pipeline_keys = [
|
||||
"backend",
|
||||
"device",
|
||||
"device_map",
|
||||
"model_kwargs",
|
||||
"pipeline_kwargs",
|
||||
"batch_size",
|
||||
]
|
||||
for key in pipeline_keys:
|
||||
if key in kwargs:
|
||||
pipeline_specific_kwargs[key] = kwargs.pop(key)
|
||||
|
||||
# Generation parameters among the remaining kwargs (temperature,
|
||||
# max_tokens, etc.) are moved into pipeline_kwargs, which
|
||||
# ChatHuggingFace will inherit from the LLM
|
||||
if "pipeline_kwargs" not in pipeline_specific_kwargs:
|
||||
pipeline_specific_kwargs["pipeline_kwargs"] = {}
|
||||
|
||||
# Add generation parameters to pipeline_kwargs
|
||||
# Map max_tokens to max_new_tokens for HuggingFace pipeline
|
||||
generation_params = {}
|
||||
for k, v in list(kwargs.items()):
|
||||
if k == "max_tokens":
|
||||
generation_params["max_new_tokens"] = v
|
||||
kwargs.pop(k)
|
||||
elif k in (
|
||||
"temperature",
|
||||
"max_new_tokens",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"repetition_penalty",
|
||||
"do_sample",
|
||||
):
|
||||
generation_params[k] = v
|
||||
kwargs.pop(k)
|
||||
|
||||
pipeline_specific_kwargs["pipeline_kwargs"].update(generation_params)
|
||||
|
||||
# Create the HuggingFacePipeline
|
||||
llm = HuggingFacePipeline.from_model_id(
|
||||
model_id=model_id, task=cast(str, task), **kwargs
|
||||
model_id=model_id, task=task, **pipeline_specific_kwargs
|
||||
)
|
||||
elif backend == "endpoint":
|
||||
from langchain_huggingface.llms.huggingface_endpoint import (
|
||||
|
||||
Reference in New Issue
Block a user