diff --git a/docs/modules/llms/integrations/hf_text_generation.ipynb b/docs/modules/llms/integrations/hf_text_generation.ipynb index 3dd3dabba32..8e9e662e2ba 100644 --- a/docs/modules/llms/integrations/hf_text_generation.ipynb +++ b/docs/modules/llms/integrations/hf_text_generation.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c24e8282", + "id": "16ae872d", "metadata": {}, "source": [ "# Hugging Face Text Generation\n", @@ -13,7 +13,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "1f80bc2b", + "id": "aed972a5", "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "1565057d", + "id": "09e693b8", "metadata": {}, "outputs": [], "source": [ @@ -34,23 +34,23 @@ { "cell_type": "code", "execution_count": 3, - "id": "106d666e", + "id": "6ce9486b", "metadata": {}, "outputs": [], "source": [ - "model = HFTextGeneration(repo_id=\"bigscience/bloomz\")" + "model = HFTextGeneration(repo_id=\"google/flan-ul2\", model_kwargs={\"temperature\":0.1, \"max_new_tokens\":200})" ] }, { "cell_type": "code", "execution_count": 4, - "id": "06674c54", + "id": "2af51b5e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'? Rayleigh scattering'" + "\"i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man i'm a music man\"" ] }, "execution_count": 4, @@ -59,12 +59,12 @@ } ], "source": [ - "model(\"why is the sky blue\")" + "model(\"write me a song about music\")" ] }, { "cell_type": "markdown", - "id": "8cc64861", + "id": "d47e1e10", "metadata": {}, "source": [ "## Streaming\n", @@ -74,7 +74,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "8d3cad0a", + "id": "3ffe8d7e", "metadata": {}, "outputs": [], "source": [ @@ -85,7 +85,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "f51ccaac", + "id": "5a84bdf3", "metadata": {}, "outputs": [], "source": [ @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "803790e0", + "id": "6d460fe9", "metadata": {}, "outputs": [ { @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cc200116", + "id": "6ed74872", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/llms/huggingface_text_generation.py b/langchain/llms/huggingface_text_generation.py index 1b06c753892..443b91cae0d 100644 --- a/langchain/llms/huggingface_text_generation.py +++ b/langchain/llms/huggingface_text_generation.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic import root_validator +from pydantic import root_validator, Field from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -8,10 +8,11 @@ from langchain.utils import get_from_dict_or_env class HFTextGeneration(LLM): repo_id: str - client: Any + token: str streaming: bool = False + model_kwargs: dict = Field(default_factory=dict) - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" huggingfacehub_api_token = get_from_dict_or_env( @@ -25,20 +26,20 @@ class HFTextGeneration(LLM): "Could not import huggingface_hub python package. " "Please it install it with `pip install huggingface_hub`." ) - repo_id = values["repo_id"] - client = InferenceAPIClient( - repo_id=repo_id, - token=huggingfacehub_api_token, - ) - values["client"] = client + values["token"] = huggingfacehub_api_token return values def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + from text_generation import InferenceAPIClient + client = InferenceAPIClient( + repo_id=self.repo_id, + token=self.token, + ) if not self.streaming: - return self.client.generate(prompt).generated_text + return client.generate(prompt, **self.model_kwargs).generated_text if self.streaming: text = "" - for response in self.client.generate_stream(prompt): + for response in client.generate_stream(prompt, **self.model_kwargs): if not response.token.special: self.callback_manager.on_llm_new_token( response.token.text, verbose=self.verbose @@ -47,7 +48,23 @@ class HFTextGeneration(LLM): return text async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: - raise NotImplementedError + from text_generation import InferenceAPIAsyncClient + client = InferenceAPIAsyncClient( + repo_id=self.repo_id, + token=self.token, + ) + if not self.streaming: + response = await client.generate(prompt, **self.model_kwargs) + return response.generated_text + if self.streaming: + text = "" + async for response in client.generate_stream(prompt, **self.model_kwargs): + if not response.token.special: + await self.callback_manager.on_llm_new_token( + response.token.text, verbose=self.verbose + ) + text += response.token.text + return text @property def _llm_type(self) -> str: