langchain/docs/docs/integrations/llms/predibase.ipynb
Alex Sherstinsky 617a4e617b
community: Fix a bug in handling kwargs overwrites in Predibase integration, and update the documentation. (#25893)
Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [x] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
2024-08-30 12:41:42 -07:00

332 lines
9.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predibase\n",
"\n",
"[Predibase](https://predibase.com/) allows you to train, fine-tune, and deploy any ML model—from linear regression to large language model. \n",
"\n",
"This example demonstrates using Langchain with models deployed on Predibase"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"To run this notebook, you'll need a [Predibase account](https://predibase.com/free-trial/?utm_source=langchain) and an [API key](https://docs.predibase.com/sdk-guide/intro).\n",
"\n",
"You'll also need to install the Predibase Python package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet predibase\n",
"import os\n",
"\n",
"os.environ[\"PREDIBASE_API_TOKEN\"] = \"{PREDIBASE_API_TOKEN}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initial Call"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"e2e_nlg\",\n",
" adapter_version=1,\n",
" **{\n",
" \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n",
" \"max_new_tokens\": 5, # default is 256\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
" **{\n",
" \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n",
" \"max_new_tokens\": 5, # default is 256\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optionally use `kwargs` to dynamically overwrite \"generate()\" settings.\n",
"response = model.invoke(\n",
" \"Can you recommend me a nice dry wine?\",\n",
" **{\"temperature\": 0.5, \"max_new_tokens\": 1024},\n",
")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": [
"## Chain Call Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" **{\n",
" \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n",
" \"max_new_tokens\": 5, # default is 256\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# With a fine-tuned adapter hosted at Predibase (adapter_version must be specified).\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"e2e_nlg\",\n",
" adapter_version=1,\n",
" **{\n",
" \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n",
" \"max_new_tokens\": 5, # default is 256\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With a fine-tuned adapter hosted at HuggingFace (adapter_version does not apply and will be ignored).\n",
"llm = Predibase(\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
" **{\n",
" \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n",
" \"max_new_tokens\": 5, # default is 256\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SequentialChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain_core.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is an LLMChain to write a synopsis given a title of a play.\n",
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"\n",
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is an LLMChain to write a review of a play given a synopsis.\n",
"template = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
"\n",
"Play Synopsis:\n",
"{synopsis}\n",
"Review from a New York Times play critic of the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
"review_chain = LLMChain(llm=llm, prompt=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is the overall chain where we run these two chains in sequence.\n",
"from langchain.chains import SimpleSequentialChain\n",
"\n",
"overall_chain = SimpleSequentialChain(\n",
" chains=[synopsis_chain, review_chain], verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"review = overall_chain.run(\"Tragedy at sunset on the beach\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-tuned LLM (Use your own fine-tuned LLM from Predibase)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"my-base-LLM\",\n",
" predibase_api_key=os.environ.get(\n",
" \"PREDIBASE_API_TOKEN\"\n",
" ), # Adapter argument is optional.\n",
" predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n",
" adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted adapter repositories.\n",
" adapter_version=1, # required for Predibase-hosted adapters (ignored for HuggingFace-hosted adapters)\n",
" **{\n",
" \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n",
" \"max_new_tokens\": 5, # default is 256\n",
" },\n",
")\n",
"# replace my-base-LLM with the name of your choice of a serverless base model in Predibase"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optionally use `kwargs` to dynamically overwrite \"generate()\" settings.\n",
"# response = model.invoke(\"Can you help categorize the following emails into positive, negative, and neutral?\", **{\"temperature\": 0.5, \"max_new_tokens\": 1024})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}