Mirror of https://github.com/hwchase17/langchain.git, synced 2025-04-28 03:51:50 +00:00
Community: Fuse HuggingFace Endpoint-related classes into one (#17254)
## Description
Fuse HuggingFace Endpoint-related classes into one:
- [HuggingFaceHub](5ceaf784f3/libs/community/langchain_community/llms/huggingface_hub.py)
- [HuggingFaceTextGenInference](5ceaf784f3/libs/community/langchain_community/llms/huggingface_text_gen_inference.py)
- [HuggingFaceEndpoint](5ceaf784f3/libs/community/langchain_community/llms/huggingface_endpoint.py)

These are fused into a single class:
- HuggingFaceEndpoint

## Issue
The duplication of classes created a lack of clarity, and the extra effort needed to maintain them led to issues like [this hack](5ceaf784f3/libs/community/langchain_community/llms/huggingface_endpoint.py (L159)).

## Dependencies
None; this change removes dependencies.

## Twitter handle
If you want to post about this: @AymericRoucher

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
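As added context, here is a minimal sketch of how the unified class covers both of the previous use cases. It is not an authoritative example: it only uses parameter names that appear in this PR (`repo_id`, `endpoint_url`, `max_new_tokens`, `temperature`, ...), the repo id and endpoint URL are placeholders, and exactly one of `repo_id` / `endpoint_url` must be provided.

```python
# Sketch only, based on the parameters introduced in this PR.
# Assumes HUGGINGFACEHUB_API_TOKEN is set in the environment.
from langchain_community.llms import HuggingFaceEndpoint

# Serverless Hub model (the old HuggingFaceHub use case): identified by `repo_id`.
serverless_llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model
    max_new_tokens=128,
    temperature=0.5,
)

# Dedicated endpoint or self-hosted TGI server (the old HuggingFaceTextGenInference
# and HuggingFaceEndpoint use cases): identified by `endpoint_url`.
dedicated_llm = HuggingFaceEndpoint(
    endpoint_url="http://localhost:8010/",  # placeholder URL
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
)

print(serverless_llm("What is Deep Learning?"))
print(dedicated_llm("What is Deep Learning?"))
```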
This commit is contained in:
parent
8009be862e
commit
0d294760e7
238
docs/docs/integrations/llms/huggingface_endpoint.ipynb
Normal file
@ -0,0 +1,238 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Huggingface Endpoints\n",
|
||||
"\n",
|
||||
">The [Hugging Face Hub](https://huggingface.co/docs/hub/index) is a platform with over 120k models, 20k datasets, and 50k demo apps (Spaces), all open source and publicly available, in an online platform where people can easily collaborate and build ML together.\n",
|
||||
"\n",
|
||||
"The `Hugging Face Hub` also offers various endpoints to build ML applications.\n",
|
||||
"This example showcases how to connect to the different Endpoints types.\n",
|
||||
"\n",
|
||||
"In particular, text generation inference is powered by [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a custom-built Rust, Python and gRPC server for blazing-faset text generation inference."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceEndpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation and Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To use, you should have the ``huggingface_hub`` python [package installed](https://huggingface.co/docs/huggingface_hub/installation)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --upgrade --quiet huggingface_hub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# get a token: https://huggingface.co/docs/api-inference/quicktour#get-your-api-token\n",
|
||||
"\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"HUGGINGFACEHUB_API_TOKEN = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = HUGGINGFACEHUB_API_TOKEN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prepare Examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceEndpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.prompts import PromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"Who won the FIFA World Cup in the year 1994? \"\n",
|
||||
"\n",
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Examples\n",
|
||||
"\n",
|
||||
"Here is an example of how you can access `HuggingFaceEndpoint` integration of the free [Serverless Endpoints](https://huggingface.co/inference-endpoints/serverless) API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"mistralai/Mistral-7B-Instruct-v0.2\"\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceEndpoint(\n",
|
||||
" repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dedicated Endpoint\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The free serverless API lets you implement solutions and iterate in no time, but it may be rate limited for heavy use cases, since the loads are shared with other requests.\n",
|
||||
"\n",
|
||||
"For enterprise workloads, the best is to use [Inference Endpoints - Dedicated](https://huggingface.co/inference-endpoints/dedicated).\n",
|
||||
"This gives access to a fully managed infrastructure that offer more flexibility and speed. These resoucres come with continuous support and uptime guarantees, as well as options like AutoScaling\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set the url to your Inference Endpoint below\n",
|
||||
"your_endpoint_url = \"https://fayjubiy2xqn36z0.us-east-1.aws.endpoints.huggingface.cloud\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceEndpoint(\n",
|
||||
" endpoint_url=f\"{your_endpoint_url}\",\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" top_k=10,\n",
|
||||
" top_p=0.95,\n",
|
||||
" typical_p=0.95,\n",
|
||||
" temperature=0.01,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
")\n",
|
||||
"llm(\"What did foo say about bar?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"from langchain_community.llms import HuggingFaceEndpoint\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceEndpoint(\n",
|
||||
" endpoint_url=f\"{your_endpoint_url}\",\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" top_k=10,\n",
|
||||
" top_p=0.95,\n",
|
||||
" typical_p=0.95,\n",
|
||||
" temperature=0.01,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
" streaming=True,\n",
|
||||
")\n",
|
||||
"llm(\"What did foo say about bar?\", callbacks=[StreamingStdOutCallbackHandler()])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "agents",
|
||||
"language": "python",
|
||||
"name": "agents"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
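The notebook above covers synchronous and streaming calls. This PR also adds async support to `HuggingFaceEndpoint` (`_acall` / `_astream` in the diff further down). Below is a rough sketch of how that could be used, assuming the standard async methods exposed by `langchain_core` LLMs (`ainvoke`, `astream`); the endpoint URL is a placeholder.

```python
# Sketch only: async usage of the new HuggingFaceEndpoint, which this PR backs with
# huggingface_hub's AsyncInferenceClient. The endpoint URL below is a placeholder.
import asyncio

from langchain_community.llms import HuggingFaceEndpoint


async def main() -> None:
    llm = HuggingFaceEndpoint(
        endpoint_url="http://localhost:8010/",
        max_new_tokens=512,
        temperature=0.01,
    )
    # Single async completion (routed through _acall).
    print(await llm.ainvoke("What is Deep Learning?"))
    # Token-by-token async streaming (routed through _astream).
    async for token in llm.astream("What is Deep Learning?"):
        print(token, end="", flush=True)


asyncio.run(main())
```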
@ -1,466 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "959300d4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Hugging Face Hub\n",
|
||||
"\n",
|
||||
">The [Hugging Face Hub](https://huggingface.co/docs/hub/index) is a platform with over 120k models, 20k datasets, and 50k demo apps (Spaces), all open source and publicly available, in an online platform where people can easily collaborate and build ML together.\n",
|
||||
"\n",
|
||||
"This example showcases how to connect to the `Hugging Face Hub` and use different models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1ddafc6d-7d7c-48fa-838f-0e7f50895ce3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation and Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c1b8450-5eaf-4d34-8341-2d785448a1ff",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"To use, you should have the ``huggingface_hub`` python [package installed](https://huggingface.co/docs/huggingface_hub/installation)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d772b637-de00-4663-bd77-9bc96d798db2",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --upgrade --quiet huggingface_hub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d597a792-354c-4ca5-b483-5965eec5d63d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# get a token: https://huggingface.co/docs/api-inference/quicktour#get-your-api-token\n",
|
||||
"\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"HUGGINGFACEHUB_API_TOKEN = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b8c5b88c-e4b8-4d0d-9a35-6e8f106452c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = HUGGINGFACEHUB_API_TOKEN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84dd44c1-c428-41f3-a911-520281386c94",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prepare Examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3fe7d1d1-241d-426a-acff-e208f1088871",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceHub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "6620f39b-3d32-4840-8931-ff7d2c3e47e8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.prompts import PromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "44adc1a0-9c0a-4f1e-af5a-fe04222e78d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"Who won the FIFA World Cup in the year 1994? \"\n",
|
||||
"\n",
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ddaa06cf-95ec-48ce-b0ab-d892a7909693",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Examples\n",
|
||||
"\n",
|
||||
"Below are some examples of models you can access through the `Hugging Face Hub` integration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c16fded-70d1-42af-8bfa-6ddda9f0bc63",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `Flan`, by `Google`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "39c7eeac-01c4-486b-9480-e828a9e73e78",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"google/flan-t5-xxl\" # See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads for some other options"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "3acf0069",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The FIFA World Cup was held in the year 1994. West Germany won the FIFA World Cup in 1994\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a5c97af-89bc-4e59-95c1-223742a9160b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `Dolly`, by `Databricks`\n",
|
||||
"\n",
|
||||
"See [Databricks](https://huggingface.co/databricks) organization page for a list of available models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "521fcd2b-8e38-4920-b407-5c7d330411c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"databricks/dolly-v2-3b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "9907ec3a-fe0c-4543-81c4-d42f9453f16c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" First of all, the world cup was won by the Germany. Then the Argentina won the world cup in 2022. So, the Argentina won the world cup in 1994.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Question: Who\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "03f6ae52-b5f9-4de6-832c-551cb3fa11ae",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `Camel`, by `Writer`\n",
|
||||
"\n",
|
||||
"See [Writer's](https://huggingface.co/Writer) organization page for a list of available models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "257a091d-750b-4910-ac08-fe1c7b3fd98b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"Writer/camel-5b-hf\" # See https://huggingface.co/Writer for other options"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b06f6838-a11a-4d6a-88e3-91fa1747a2b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2bf838eb-1083-402f-b099-b07c452418c8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `XGen`, by `Salesforce`\n",
|
||||
"\n",
|
||||
"See [more information](https://github.com/salesforce/xgen)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "18c78880-65d7-41d0-9722-18090efb60e9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"Salesforce/xgen-7b-8k-base\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b1150b4-ec30-4674-849e-6a41b085aa2b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0aca9f9e-f333-449c-97b2-10d1dbf17e75",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `Falcon`, by `Technology Innovation Institute (TII)`\n",
|
||||
"\n",
|
||||
"See [more information](https://huggingface.co/tiiuae/falcon-40b)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "496b35ac-5ee2-4b68-a6ce-232608f56c03",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"tiiuae/falcon-40b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff2541ad-e394-4179-93c2-7ae9c4ca2a25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 64}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7e15849b-5561-4bb9-86ec-6412ca10196a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `InternLM-Chat`, by `Shanghai AI Laboratory`\n",
|
||||
"\n",
|
||||
"See [more information](https://huggingface.co/internlm/internlm-7b)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "3b533461-59f8-406e-907b-000841fa60a7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"internlm/internlm-chat-7b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c71210b9-5895-41a2-889a-f430d22fa1aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"max_length\": 128, \"temperature\": 0.8}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4f2e5132-1713-42d7-919a-8c313744ce95",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `Qwen`, by `Alibaba Cloud`\n",
|
||||
"\n",
|
||||
">`Tongyi Qianwen-7B` (`Qwen-7B`) is a model with a scale of 7 billion parameters in the `Tongyi Qianwen` large model series developed by `Alibaba Cloud`. `Qwen-7B` is a large language model based on Transformer, which is trained on ultra-large-scale pre-training data.\n",
|
||||
"\n",
|
||||
"See [more information on HuggingFace](https://huggingface.co/Qwen/Qwen-7B) of on [GitHub](https://github.com/QwenLM/Qwen-7B).\n",
|
||||
"\n",
|
||||
"See here a [big example for LangChain integration and Qwen](https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "f598b1ca-77c7-40f1-a83f-c21ea9910c88",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"Qwen/Qwen-7B\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2c97f4e2-d401-44fb-9da7-b60b2e2cc663",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"max_length\": 128, \"temperature\": 0.5}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e3871376-ed0e-49a8-8d9b-7e60dbbd2b35",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### `Yi` series models, by `01.ai`\n",
|
||||
"\n",
|
||||
">The `Yi` series models are large language models trained from scratch by developers at [01.ai](https://01.ai/). The first public release contains two bilingual(English/Chinese) base models with the parameter sizes of 6B(`Yi-6B`) and 34B(`Yi-34B`). Both of them are trained with 4K sequence length and can be extended to 32K during inference time. The `Yi-6B-200K` and `Yi-34B-200K` are base model with 200K context length.\n",
|
||||
"\n",
|
||||
"Here we test the [Yi-34B](https://huggingface.co/01-ai/Yi-34B) model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "1c9d3125-3f50-48b8-93b6-b50847207afa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_id = \"01-ai/Yi-34B\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b661069-8229-4850-9f13-c4ca28c0c96b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = HuggingFaceHub(\n",
|
||||
" repo_id=repo_id, model_kwargs={\"max_length\": 128, \"temperature\": 0.5}\n",
|
||||
")\n",
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"print(llm_chain.run(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd6f3edc-9f97-47a6-ab2c-116756babbe6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -1,108 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Huggingface TextGen Inference\n",
|
||||
"\n",
|
||||
"[Text Generation Inference](https://github.com/huggingface/text-generation-inference) is a Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co/) to power LLMs api-inference widgets.\n",
|
||||
"\n",
|
||||
"This notebooks goes over how to use a self hosted LLM using `Text Generation Inference`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To use, you should have the `text_generation` python package installed."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !pip3 install text_generation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.llms import HuggingFaceTextGenInference\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceTextGenInference(\n",
|
||||
" inference_server_url=\"http://localhost:8010/\",\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" top_k=10,\n",
|
||||
" top_p=0.95,\n",
|
||||
" typical_p=0.95,\n",
|
||||
" temperature=0.01,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
")\n",
|
||||
"llm(\"What did foo say about bar?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"from langchain_community.llms import HuggingFaceTextGenInference\n",
|
||||
"\n",
|
||||
"llm = HuggingFaceTextGenInference(\n",
|
||||
" inference_server_url=\"http://localhost:8010/\",\n",
|
||||
" max_new_tokens=512,\n",
|
||||
" top_k=10,\n",
|
||||
" top_p=0.95,\n",
|
||||
" typical_p=0.95,\n",
|
||||
" temperature=0.01,\n",
|
||||
" repetition_penalty=1.03,\n",
|
||||
" streaming=True,\n",
|
||||
")\n",
|
||||
"llm(\"What did foo say about bar?\", callbacks=[StreamingStdOutCallbackHandler()])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -1,5 +1,13 @@
|
||||
{
|
||||
"redirects": [
|
||||
{
|
||||
"source": "/docs/integrations/llms/huggingface_textgen_inference",
|
||||
"destination": "/docs/integrations/llms/huggingface_endpoint"
|
||||
},
|
||||
{
|
||||
"source": "/docs/integrations/llms/huggingface_hub",
|
||||
"destination": "/docs/integrations/llms/huggingface_endpoint"
|
||||
},
|
||||
{
|
||||
"source": "/docs/integrations/llms/watsonxllm",
|
||||
"destination": "/docs/integrations/llms/ibm_watsonx"
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
@ -52,6 +53,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
self._resolve_model_id()
|
||||
|
||||
self.tokenizer = (
|
||||
AutoTokenizer.from_pretrained(self.model_id)
|
||||
if self.tokenizer is None
|
||||
@ -90,10 +92,10 @@ class ChatHuggingFace(BaseChatModel):
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
raise ValueError("at least one HumanMessage must be provided")
|
||||
raise ValueError("At least one HumanMessage must be provided!")
|
||||
|
||||
if not isinstance(messages[-1], HumanMessage):
|
||||
raise ValueError("last message must be a HumanMessage")
|
||||
raise ValueError("Last message must be a HumanMessage!")
|
||||
|
||||
messages_dicts = [self._to_chatml_format(m) for m in messages]
|
||||
|
||||
@ -135,20 +137,15 @@ class ChatHuggingFace(BaseChatModel):
|
||||
from huggingface_hub import list_inference_endpoints
|
||||
|
||||
available_endpoints = list_inference_endpoints("*")
|
||||
|
||||
if isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
endpoint_url = self.llm.inference_server_url
|
||||
|
||||
elif isinstance(self.llm, HuggingFaceEndpoint):
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
|
||||
elif isinstance(self.llm, HuggingFaceHub):
|
||||
# no need to look up model_id for HuggingFaceHub LLM
|
||||
if isinstance(self.llm, HuggingFaceHub) or (
|
||||
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
||||
):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
|
||||
elif isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
else:
|
||||
raise ValueError(f"Unknown LLM type: {type(self.llm)}")
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
|
||||
for endpoint in available_endpoints:
|
||||
if endpoint.url == endpoint_url:
|
||||
@ -156,8 +153,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
if not self.model_id:
|
||||
raise ValueError(
|
||||
"Failed to resolve model_id"
|
||||
f"Could not find model id for inference server provided: {endpoint_url}"
|
||||
"Failed to resolve model_id:"
|
||||
f"Could not find model id for inference server: {endpoint_url}"
|
||||
"Make sure that your Hugging Face token has access to the endpoint."
|
||||
)
|
||||
|
||||
|
@ -1,12 +1,17 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_TASKS = (
|
||||
"text2text-generation",
|
||||
@ -17,70 +22,198 @@ VALID_TASKS = (
|
||||
|
||||
|
||||
class HuggingFaceEndpoint(LLM):
|
||||
"""HuggingFace Endpoint models.
|
||||
"""
|
||||
HuggingFace Endpoint.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Only supports `text-generation` and `text2text-generation` for now.
|
||||
To use this class, you should have installed the ``huggingface_hub`` package, and
|
||||
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
|
||||
or given as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import HuggingFaceEndpoint
|
||||
endpoint_url = (
|
||||
"https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
|
||||
)
|
||||
hf = HuggingFaceEndpoint(
|
||||
endpoint_url=endpoint_url,
|
||||
# Basic Example (no streaming)
|
||||
llm = HuggingFaceEndpoint(
|
||||
endpoint_url="http://localhost:8010/",
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
huggingfacehub_api_token="my-api-key"
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
# Streaming response example
|
||||
from langchain_community.callbacks import streaming_stdout
|
||||
|
||||
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
|
||||
llm = HuggingFaceEndpoint(
|
||||
endpoint_url="http://localhost:8010/",
|
||||
max_new_tokens=512,
|
||||
top_k=10,
|
||||
top_p=0.95,
|
||||
typical_p=0.95,
|
||||
temperature=0.01,
|
||||
repetition_penalty=1.03,
|
||||
callbacks=callbacks,
|
||||
streaming=True,
|
||||
huggingfacehub_api_token="my-api-key"
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
"""
|
||||
|
||||
endpoint_url: str = ""
|
||||
endpoint_url: Optional[str] = None
|
||||
"""Endpoint URL to use."""
|
||||
repo_id: Optional[str] = None
|
||||
"""Repo to use."""
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
max_new_tokens: int = 512
|
||||
"""Maximum number of generated tokens"""
|
||||
top_k: Optional[int] = None
|
||||
"""The number of highest probability vocabulary tokens to keep for
|
||||
top-k-filtering."""
|
||||
top_p: Optional[float] = 0.95
|
||||
"""If set to < 1, only the smallest set of most probable tokens with probabilities
|
||||
that add up to `top_p` or higher are kept for generation."""
|
||||
typical_p: Optional[float] = 0.95
|
||||
"""Typical Decoding mass. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information."""
|
||||
temperature: Optional[float] = 0.8
|
||||
"""The value used to module the logits distribution."""
|
||||
repetition_penalty: Optional[float] = None
|
||||
"""The parameter for repetition penalty. 1.0 means no penalty.
|
||||
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
|
||||
return_full_text: bool = False
|
||||
"""Whether to prepend the prompt to the generated text"""
|
||||
truncate: Optional[int] = None
|
||||
"""Truncate inputs tokens to the given size"""
|
||||
stop_sequences: List[str] = Field(default_factory=list)
|
||||
"""Stop generating tokens if a member of `stop_sequences` is generated"""
|
||||
seed: Optional[int] = None
|
||||
"""Random sampling seed"""
|
||||
inference_server_url: str = ""
|
||||
"""text-generation-inference instance base url"""
|
||||
timeout: int = 120
|
||||
"""Timeout in seconds"""
|
||||
streaming: bool = False
|
||||
"""Whether to generate a stream of tokens asynchronously"""
|
||||
do_sample: bool = False
|
||||
"""Activate logits sampling"""
|
||||
watermark: bool = False
|
||||
"""Watermarking with [A Watermark for Large Language Models]
|
||||
(https://arxiv.org/abs/2301.10226)"""
|
||||
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any text-generation-inference server parameters not explicitly specified"""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `call` not explicitly specified"""
|
||||
model: str
|
||||
client: Any
|
||||
async_client: Any
|
||||
task: Optional[str] = None
|
||||
"""Task to call the model with.
|
||||
Should be a task that returns `generated_text` or `summary_text`."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
huggingfacehub_api_token: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please make sure that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
if "endpoint_url" not in values and "repo_id" not in values:
|
||||
raise ValueError(
|
||||
"Please specify an `endpoint_url` or `repo_id` for the model."
|
||||
)
|
||||
if "endpoint_url" in values and "repo_id" in values:
|
||||
raise ValueError(
|
||||
"Please specify either an `endpoint_url` OR a `repo_id`, not both."
|
||||
)
|
||||
values["model"] = values.get("endpoint_url") or values.get("repo_id")
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
"""Validate that package is installed and that the API token is valid."""
|
||||
try:
|
||||
from huggingface_hub.hf_api import HfApi
|
||||
|
||||
try:
|
||||
HfApi(
|
||||
endpoint="https://huggingface.co", # Can be a Private Hub endpoint.
|
||||
token=huggingfacehub_api_token,
|
||||
).whoami()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not authenticate with huggingface_hub. "
|
||||
"Please check your API token."
|
||||
) from e
|
||||
from huggingface_hub import login
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import huggingface_hub python package. "
|
||||
"Please install it with `pip install huggingface_hub`."
|
||||
)
|
||||
values["huggingfacehub_api_token"] = huggingfacehub_api_token
|
||||
try:
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
login(token=huggingfacehub_api_token)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not authenticate with huggingface_hub. "
|
||||
"Please check your API token."
|
||||
) from e
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, InferenceClient
|
||||
|
||||
values["client"] = InferenceClient(
|
||||
model=values["model"],
|
||||
timeout=values["timeout"],
|
||||
token=huggingfacehub_api_token,
|
||||
**values["server_kwargs"],
|
||||
)
|
||||
values["async_client"] = AsyncInferenceClient(
|
||||
model=values["model"],
|
||||
timeout=values["timeout"],
|
||||
token=huggingfacehub_api_token,
|
||||
**values["server_kwargs"],
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling text generation inference API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"typical_p": self.typical_p,
|
||||
"temperature": self.temperature,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"return_full_text": self.return_full_text,
|
||||
"truncate": self.truncate,
|
||||
"stop_sequences": self.stop_sequences,
|
||||
"seed": self.seed,
|
||||
"do_sample": self.do_sample,
|
||||
"watermark": self.watermark,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
@ -95,6 +228,13 @@ class HuggingFaceEndpoint(LLM):
|
||||
"""Return type of llm."""
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _invocation_params(
|
||||
self, runtime_stop: Optional[List[str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
params = {**self._default_params, **kwargs}
|
||||
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
|
||||
return params
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -102,62 +242,129 @@ class HuggingFaceEndpoint(LLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to HuggingFace Hub's inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = hf("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
# payload samples
|
||||
params = {**_model_kwargs, **kwargs}
|
||||
parameter_payload = {"inputs": prompt, "parameters": params}
|
||||
|
||||
# HTTP headers for authorization
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.huggingfacehub_api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = requests.post(
|
||||
self.endpoint_url, headers=headers, json=parameter_payload
|
||||
)
|
||||
except requests.exceptions.RequestException as e: # This is the correct syntax
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
generated_text = response.json()
|
||||
if "error" in generated_text:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {generated_text['error']}"
|
||||
)
|
||||
if self.task == "text-generation":
|
||||
text = generated_text[0]["generated_text"]
|
||||
# Remove prompt if included in generated text.
|
||||
if text.startswith(prompt):
|
||||
text = text[len(prompt) :]
|
||||
elif self.task == "text2text-generation":
|
||||
text = generated_text[0]["generated_text"]
|
||||
elif self.task == "summarization":
|
||||
text = generated_text[0]["summary_text"]
|
||||
elif self.task == "conversational":
|
||||
text = generated_text["response"][1]
|
||||
"""Call out to HuggingFace Hub's inference endpoint."""
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(prompt, stop, run_manager, **invocation_params):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got invalid task {self.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
invocation_params["stop"] = invocation_params[
|
||||
"stop_sequences"
|
||||
] # porting 'stop_sequences' into the 'stop' argument
|
||||
response = self.client.post(
|
||||
json={"inputs": prompt, "parameters": invocation_params},
|
||||
stream=False,
|
||||
task=self.task,
|
||||
)
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to huggingface_hub.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
response_text = json.loads(response.decode())[0]["generated_text"]
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then we remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
async for chunk in self._astream(
|
||||
prompt, stop, run_manager, **invocation_params
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
invocation_params["stop"] = invocation_params["stop_sequences"]
|
||||
response = await self.async_client.post(
|
||||
json={"inputs": prompt, "parameters": invocation_params},
|
||||
stream=False,
|
||||
task=self.task,
|
||||
)
|
||||
response_text = json.loads(response.decode())[0]["generated_text"]
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
for response in self.client.text_generation(
|
||||
prompt, **invocation_params, stream=True
|
||||
):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in response:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
text = response[: response.index(stop_seq_found)]
|
||||
else:
|
||||
text = response
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
chunk = GenerationChunk(text=text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
async for response in await self.async_client.text_generation(
|
||||
prompt, **invocation_params, stream=True
|
||||
):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in response:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
text = response[: response.index(stop_seq_found)]
|
||||
else:
|
||||
text = response
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
chunk = GenerationChunk(text=text)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
@ -19,8 +20,10 @@ VALID_TASKS_DICT = {
|
||||
}
|
||||
|
||||
|
||||
@deprecated("0.0.21", removal="0.2.0", alternative="HuggingFaceEndpoint")
|
||||
class HuggingFaceHub(LLM):
|
||||
"""HuggingFaceHub models.
|
||||
! This class is deprecated, you should use HuggingFaceEndpoint instead.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
|
||||
|
@ -9,8 +9,6 @@ from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.pydantic_v1 import Extra
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
DEFAULT_TASK = "text-generation"
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
@ -201,7 +199,12 @@ class HuggingFacePipeline(BaseLLM):
|
||||
batch_prompts = prompts[i : i + self.batch_size]
|
||||
|
||||
# Process batch of prompts
|
||||
responses = self.pipeline(batch_prompts, **pipeline_kwargs)
|
||||
responses = self.pipeline(
|
||||
batch_prompts,
|
||||
stop_sequence=stop,
|
||||
return_full_text=False,
|
||||
**pipeline_kwargs,
|
||||
)
|
||||
|
||||
# Process each response in the batch
|
||||
for j, response in enumerate(responses):
|
||||
@ -210,23 +213,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
response = response[0]
|
||||
|
||||
if self.pipeline.task == "text-generation":
|
||||
try:
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
|
||||
remove_prompt = (
|
||||
self.pipeline._postprocess_params.get("return_type")
|
||||
!= ReturnType.NEW_TEXT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unable to extract pipeline return_type. "
|
||||
f"Received error:\n\n{e}"
|
||||
)
|
||||
remove_prompt = True
|
||||
if remove_prompt:
|
||||
text = response["generated_text"][len(batch_prompts[j]) :]
|
||||
else:
|
||||
text = response["generated_text"]
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "text2text-generation":
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "summarization":
|
||||
@ -236,9 +223,6 @@ class HuggingFacePipeline(BaseLLM):
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
if stop:
|
||||
# Enforce stop tokens
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
# Append the processed text to results
|
||||
text_generations.append(text)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@ -13,9 +14,11 @@ from langchain_core.utils import get_pydantic_field_names
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@deprecated("0.0.21", removal="0.2.0", alternative="HuggingFaceEndpoint")
|
||||
class HuggingFaceTextGenInference(LLM):
|
||||
"""
|
||||
HuggingFace text generation API.
|
||||
! This class is deprecated, you should use HuggingFaceEndpoint instead !
|
||||
|
||||
To use, you should have the `text-generation` python package installed and
|
||||
a text-generation server running.
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""Test HuggingFace API wrapper."""
|
||||
"""Test HuggingFace Endpoints."""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@ -10,42 +9,9 @@ from langchain_community.llms.loading import load_llm
|
||||
from tests.integration_tests.llms.utils import assert_llm_equality
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"This test requires an inference endpoint. Tested with Hugging Face endpoints"
|
||||
)
|
||||
def test_huggingface_endpoint_text_generation() -> None:
|
||||
"""Test valid call to HuggingFace text generation model."""
|
||||
llm = HuggingFaceEndpoint(
|
||||
endpoint_url="", task="text-generation", model_kwargs={"max_new_tokens": 10}
|
||||
)
|
||||
output = llm("Say foo:")
|
||||
print(output) # noqa: T201
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"This test requires an inference endpoint. Tested with Hugging Face endpoints"
|
||||
)
|
||||
def test_huggingface_endpoint_text2text_generation() -> None:
|
||||
"""Test valid call to HuggingFace text2text model."""
|
||||
llm = HuggingFaceEndpoint(endpoint_url="", task="text2text-generation")
|
||||
output = llm("The capital of New York is")
|
||||
assert output == "Albany"
|
||||
|
||||
|
||||
@unittest.skip(
|
||||
"This test requires an inference endpoint. Tested with Hugging Face endpoints"
|
||||
)
|
||||
def test_huggingface_endpoint_summarization() -> None:
|
||||
"""Test valid call to HuggingFace summarization model."""
|
||||
llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_huggingface_endpoint_call_error() -> None:
|
||||
"""Test valid call to HuggingFace that errors."""
|
||||
llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})
|
||||
llm = HuggingFaceEndpoint(endpoint_url="", model_kwargs={"max_new_tokens": -1})
|
||||
with pytest.raises(ValueError):
|
||||
llm("Say foo:")
|
||||
|
||||
@ -58,3 +24,58 @@ def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
|
||||
llm.save(file_path=tmp_path / "hf.yaml")
|
||||
loaded_llm = load_llm(tmp_path / "hf.yaml")
|
||||
assert_llm_equality(llm, loaded_llm)
|
||||
|
||||
|
||||
def test_huggingface_text_generation() -> None:
|
||||
"""Test valid call to HuggingFace text generation model."""
|
||||
llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
|
||||
output = llm("Say foo:")
|
||||
print(output) # noqa: T201
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_huggingface_text2text_generation() -> None:
|
||||
"""Test valid call to HuggingFace text2text model."""
|
||||
llm = HuggingFaceEndpoint(repo_id="google/flan-t5-xl")
|
||||
output = llm("The capital of New York is")
|
||||
assert output == "Albany"
|
||||
|
||||
|
||||
def test_huggingface_summarization() -> None:
|
||||
"""Test valid call to HuggingFace summarization model."""
|
||||
llm = HuggingFaceEndpoint(repo_id="facebook/bart-large-cnn")
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_huggingface_call_error() -> None:
|
||||
"""Test valid call to HuggingFace that errors."""
|
||||
llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": -1})
|
||||
with pytest.raises(ValueError):
|
||||
llm("Say foo:")
|
||||
|
||||
|
||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||
"""Test saving/loading an HuggingFaceEndpoint LLM."""
|
||||
llm = HuggingFaceEndpoint(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
|
||||
llm.save(file_path=tmp_path / "hf.yaml")
|
||||
loaded_llm = load_llm(tmp_path / "hf.yaml")
|
||||
assert_llm_equality(llm, loaded_llm)
|
||||
|
||||
|
||||
def test_invocation_params_stop_sequences() -> None:
|
||||
llm = HuggingFaceEndpoint()
|
||||
assert llm._default_params["stop_sequences"] == []
|
||||
|
||||
runtime_stop = None
|
||||
assert llm._invocation_params(runtime_stop)["stop_sequences"] == []
|
||||
assert llm._default_params["stop_sequences"] == []
|
||||
|
||||
runtime_stop = ["stop"]
|
||||
assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
|
||||
assert llm._default_params["stop_sequences"] == []
|
||||
|
||||
llm = HuggingFaceEndpoint(stop_sequences=["."])
|
||||
runtime_stop = ["stop"]
|
||||
assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
|
||||
assert llm._default_params["stop_sequences"] == ["."]
|
||||
|