fix(langchain-classic): fix init_chat_model for HuggingFace models (#33943)

This commit is contained in:
Georgey
2025-12-12 08:05:48 -08:00
committed by GitHub
parent 13dd115d1d
commit 16c984ef0a
4 changed files with 163 additions and 26 deletions

View File

@@ -627,8 +627,54 @@ class ChatHuggingFace(BaseChatModel):
HuggingFacePipeline,
)
task = task if task is not None else "text-generation"
# Separate pipeline-specific kwargs from ChatHuggingFace kwargs
# Parameters that should go to HuggingFacePipeline.from_model_id
pipeline_specific_kwargs = {}
# Extract pipeline-specific parameters
pipeline_keys = [
"backend",
"device",
"device_map",
"model_kwargs",
"pipeline_kwargs",
"batch_size",
]
for key in pipeline_keys:
if key in kwargs:
pipeline_specific_kwargs[key] = kwargs.pop(key)
# Of the remaining kwargs, only the recognized generation parameters
# (temperature, max_tokens, etc.) are moved into pipeline_kwargs, which
# ChatHuggingFace will inherit from the LLM; any other kwargs stay in kwargs
if "pipeline_kwargs" not in pipeline_specific_kwargs:
pipeline_specific_kwargs["pipeline_kwargs"] = {}
# Add generation parameters to pipeline_kwargs
# Map max_tokens to max_new_tokens for HuggingFace pipeline
generation_params = {}
for k, v in list(kwargs.items()):
if k == "max_tokens":
generation_params["max_new_tokens"] = v
kwargs.pop(k)
elif k in (
"temperature",
"max_new_tokens",
"top_p",
"top_k",
"repetition_penalty",
"do_sample",
):
generation_params[k] = v
kwargs.pop(k)
pipeline_specific_kwargs["pipeline_kwargs"].update(generation_params)
# Create the HuggingFacePipeline
llm = HuggingFacePipeline.from_model_id(
model_id=model_id, task=cast(str, task), **kwargs
model_id=model_id, task=task, **pipeline_specific_kwargs
)
elif backend == "endpoint":
from langchain_huggingface.llms.huggingface_endpoint import (