Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-01 00:49:25 +00:00)
HuggingFacePipeline: Forward model_kwargs. (#696)
Since the tokenizer and model are constructed manually, model_kwargs needs to be passed to their constructors. Additionally, the pipeline accepts a dedicated named parameter for these kwargs, which provides forward compatibility if they are ever used for something other than tokenizer or model construction.
parent 3a30e6daa8
commit 36b6b3cdf6
@@ -68,19 +68,19 @@ class HuggingFacePipeline(LLM, BaseModel):
             )
             from transformers import pipeline as hf_pipeline

-            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            _model_kwargs = model_kwargs or {}
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
             if task == "text-generation":
-                model = AutoModelForCausalLM.from_pretrained(model_id)
+                model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
             elif task == "text2text-generation":
-                model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
                     f"Got invalid task {task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
-            _model_kwargs = model_kwargs or {}
             pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, **_model_kwargs
+                task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
             )
             if pipeline.task not in VALID_TASKS:
                 raise ValueError(
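For context, a minimal usage sketch of the patched path. It assumes the class is constructed with model_id, task, and model_kwargs as in the hunk above; the import path, model id, and kwargs values are illustrative, not taken from this commit.

# Hedged sketch, not part of the commit. After this change, everything in
# model_kwargs reaches AutoTokenizer.from_pretrained, the matching
# AutoModelFor*.from_pretrained call, and the dedicated model_kwargs=
# parameter of transformers.pipeline.
from langchain.llms import HuggingFacePipeline  # import path assumed for this era

llm = HuggingFacePipeline(
    model_id="gpt2",                    # assumed example model id
    task="text-generation",
    model_kwargs={"revision": "main"},  # assumed example from_pretrained kwargs
)
print(llm("Say hello:"))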