mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
Harrison/version 0040 (#366)
This commit is contained in:
parent
50257fce59
commit
a7084ad6e4
@ -61,24 +61,36 @@ class HuggingFacePipeline(LLM, BaseModel):
|
|||||||
) -> LLM:
|
) -> LLM:
|
||||||
"""Construct the pipeline object from model_id and task."""
|
"""Construct the pipeline object from model_id and task."""
|
||||||
try:
|
try:
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
)
|
||||||
from transformers import pipeline as hf_pipeline
|
from transformers import pipeline as hf_pipeline
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
if task == "text-generation":
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
elif task == "text2text-generation":
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got invalid task {task}, "
|
||||||
|
f"currently only {VALID_TASKS} are supported"
|
||||||
|
)
|
||||||
|
_model_kwargs = model_kwargs or {}
|
||||||
pipeline = hf_pipeline(
|
pipeline = hf_pipeline(
|
||||||
task=task, model=model, tokenizer=tokenizer, **model_kwargs
|
task=task, model=model, tokenizer=tokenizer, **_model_kwargs
|
||||||
)
|
)
|
||||||
if pipeline.task not in VALID_TASKS:
|
if pipeline.task not in VALID_TASKS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got invalid task {pipeline.task}, "
|
f"Got invalid task {pipeline.task}, "
|
||||||
f"currently only {VALID_TASKS} are supported"
|
f"currently only {VALID_TASKS} are supported"
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=_model_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -100,7 +112,7 @@ class HuggingFacePipeline(LLM, BaseModel):
|
|||||||
return "huggingface_pipeline"
|
return "huggingface_pipeline"
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
response = self.pipeline(text_inputs=prompt)
|
response = self.pipeline(prompt)
|
||||||
if self.pipeline.task == "text-generation":
|
if self.pipeline.task == "text-generation":
|
||||||
# Text generation return includes the starter text.
|
# Text generation return includes the starter text.
|
||||||
text = response[0]["generated_text"][len(prompt) :]
|
text = response[0]["generated_text"][len(prompt) :]
|
||||||
|
@ -163,6 +163,9 @@ class OpenAI(LLM, BaseModel):
|
|||||||
def stream(self, prompt: str) -> Generator:
|
def stream(self, prompt: str) -> Generator:
|
||||||
"""Call OpenAI with streaming flag and return the resulting generator.
|
"""Call OpenAI with streaming flag and return the resulting generator.
|
||||||
|
|
||||||
|
BETA: this is a beta feature while we figure out the right abstraction.
|
||||||
|
Once that happens, this interface could change.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: The prompts to pass into the model.
|
prompt: The prompts to pass into the model.
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "langchain"
|
name = "langchain"
|
||||||
version = "0.0.39"
|
version = "0.0.40"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
authors = []
|
authors = []
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
@ -18,6 +18,15 @@ def test_huggingface_pipeline_text_generation() -> None:
|
|||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_huggingface_pipeline_text2text_generation() -> None:
|
||||||
|
"""Test valid call to HuggingFace text2text generation model."""
|
||||||
|
llm = HuggingFacePipeline.from_model_id(
|
||||||
|
model_id="google/flan-t5-small", task="text2text-generation"
|
||||||
|
)
|
||||||
|
output = llm("Say foo:")
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||||
"""Test saving/loading an HuggingFaceHub LLM."""
|
"""Test saving/loading an HuggingFaceHub LLM."""
|
||||||
llm = HuggingFacePipeline.from_model_id(
|
llm = HuggingFacePipeline.from_model_id(
|
||||||
|
Loading…
Reference in New Issue
Block a user