Harrison/predibase (#8046)

Co-authored-by: Abhay Malik <32989166+Abhay-765@users.noreply.github.com>
Harrison Chase 2023-07-20 19:26:50 -07:00 committed by GitHub
parent 56c6ab1715
commit f99f497b2c
4 changed files with 292 additions and 0 deletions


@@ -0,0 +1,24 @@
# Predibase
Learn how to use LangChain with models on Predibase.
## Setup
- Create a [Predibase](https://predibase.com/) account and an [API key](https://docs.predibase.com/sdk-guide/intro).
- Install the Predibase Python client with `pip install predibase`.
- Use your API key to authenticate.
### LLM
Predibase integrates with LangChain by implementing the LLM module. You can see a short example below, or a full notebook under LLM > Integrations > Predibase.
```python
import os
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from langchain.llms import Predibase
model = Predibase(model="vicuna-13b", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
response = model("Can you recommend me a nice dry wine?")
print(response)
```
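
The model also drops into standard LangChain chains. A minimal sketch, reusing the `vicuna-13b` deployment above (the prompt text is illustrative):

```python
import os

from langchain.chains import LLMChain
from langchain.llms import Predibase
from langchain.prompts import PromptTemplate

# Assumes PREDIBASE_API_TOKEN is already set in the environment.
llm = Predibase(
    model="vicuna-13b", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN")
)

# A one-step chain: fill the prompt template, then send it to Predibase.
prompt = PromptTemplate(
    input_variables=["topic"],
    template="Write a one-sentence summary of {topic}.",
)
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run("the history of sparkling wine"))
```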


@@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predibase\n",
"\n",
"[Predibase](https://predibase.com/) allows you to train, finetune, and deploy any ML model—from linear regression to large language model. \n",
"\n",
"This example demonstrates using Langchain with models deployed on Predibase"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"To run this notebook, you'll need a [Predibase account](https://predibase.com/free-trial/?utm_source=langchain) and an [API key](https://docs.predibase.com/sdk-guide/intro).\n",
"\n",
"You'll also need to install the Predibase Python package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install predibase\n",
"import os\n",
"\n",
"os.environ[\"PREDIBASE_API_TOKEN\"] = \"{PREDIBASE_API_TOKEN}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initial Call"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = model(\"Can you recommend me a nice dry wine?\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chain Call Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = Predibase(\n",
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SequentialChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is an LLMChain to write a synopsis given a title of a play.\n",
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"\n",
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is an LLMChain to write a review of a play given a synopsis.\n",
"template = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
"\n",
"Play Synopsis:\n",
"{synopsis}\n",
"Review from a New York Times play critic of the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
"review_chain = LLMChain(llm=llm, prompt=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is the overall chain where we run these two chains in sequence.\n",
"from langchain.chains import SimpleSequentialChain\n",
"\n",
"overall_chain = SimpleSequentialChain(\n",
" chains=[synopsis_chain, review_chain], verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"review = overall_chain.run(\"Tragedy at sunset on the beach\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-tuned LLM (Use your own fine-tuned LLM from Predibase)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"my-finetuned-LLM\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")\n",
"# replace my-finetuned-LLM with the name of your model in Predibase"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# response = model(\"Can you help categorize the following emails into positive, negative, and neutral?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -43,6 +43,7 @@ from langchain.llms.openllm import OpenLLM
from langchain.llms.openlm import OpenLM
from langchain.llms.petals import Petals
from langchain.llms.pipelineai import PipelineAI
from langchain.llms.predibase import Predibase
from langchain.llms.predictionguard import PredictionGuard
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
from langchain.llms.replicate import Replicate
@@ -100,6 +101,7 @@ __all__ = [
    "OpenLM",
    "Petals",
    "PipelineAI",
    "Predibase",
    "PredictionGuard",
    "PromptLayerOpenAI",
    "PromptLayerOpenAIChat",
@@ -156,6 +158,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "openlm": OpenLM,
    "petals": Petals,
    "pipelineai": PipelineAI,
    "predibase": Predibase,
    "replicate": Replicate,
    "rwkv": RWKV,
    "sagemaker_endpoint": SagemakerEndpoint,


@@ -0,0 +1,51 @@
from typing import Any, Dict, List, Mapping, Optional

from pydantic import Field

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class Predibase(LLM):
    """Use your Predibase models with LangChain.

    To use, you should have the ``predibase`` python package installed,
    and have your Predibase API key.
    """

    model: str
    predibase_api_key: str
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @property
    def _llm_type(self) -> str:
        return "predibase"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            from predibase import PredibaseClient

            pc = PredibaseClient(token=self.predibase_api_key)
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install predibase`."
            ) from e
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e
        # Prompt the deployed model and return the text of the first result.
        results = pc.prompt(prompt, model_name=self.model)
        return results[0].response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_kwargs": self.model_kwargs}
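
A short sketch of how the class behaves once constructed. Note that `model_kwargs` appears in `_identifying_params` (which LangChain uses, e.g., for cache keys and reprs) but, as written above, `_call` does not yet forward it to `pc.prompt`. The model name and kwargs below are placeholders:

```python
import os

from langchain.llms import Predibase

# Placeholder deployment name and kwargs; substitute your own.
llm = Predibase(
    model="vicuna-13b",
    predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
    model_kwargs={"temperature": 0.2},
)

print(llm._llm_type)            # predibase
print(llm._identifying_params)  # {'model_kwargs': {'temperature': 0.2}}
```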