mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
community: refactor Baseten integration with new API endpoints & docs (#15017)
- **Description:** In response to user feedback, this PR refactors the Baseten integration with updated model endpoints, as well as updates relevant documentation. This PR has been tested by end users in production and works as expected. - **Issue:** N/A - **Dependencies:** This PR actually removes the dependency on the `baseten` package! - **Twitter handle:** https://twitter.com/basetenco
This commit is contained in:
parent
3fc1b3553b
commit
6342da333a
@ -7,9 +7,9 @@
|
||||
"source": [
|
||||
"# Baseten\n",
|
||||
"\n",
|
||||
"[Baseten](https://baseten.co) provides all the infrastructure you need to deploy and serve ML models performantly, scalably, and cost-efficiently.\n",
|
||||
"[Baseten](https://baseten.co) is a [Provider](https://python.langchain.com/docs/integrations/providers/baseten) in the LangChain ecosystem that implements the LLMs component.\n",
|
||||
"\n",
|
||||
"This example demonstrates using Langchain with models deployed on Baseten."
|
||||
"This example demonstrates using an LLM — Mistral 7B hosted on Baseten — with LangChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -19,29 +19,16 @@
|
||||
"source": [
|
||||
"# Setup\n",
|
||||
"\n",
|
||||
"To run this notebook, you'll need a [Baseten account](https://baseten.co) and an [API key](https://docs.baseten.co/settings/api-keys).\n",
|
||||
"To run this example, you'll need:\n",
|
||||
"\n",
|
||||
"You'll also need to install the Baseten Python package:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install baseten"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import baseten\n",
|
||||
"* A [Baseten account](https://baseten.co)\n",
|
||||
"* An [API key](https://docs.baseten.co/observability/api-keys)\n",
|
||||
"\n",
|
||||
"baseten.login(\"YOUR_API_KEY\")"
|
||||
"Export your API key to your as an environment variable called `BASETEN_API_KEY`.\n",
|
||||
"\n",
|
||||
"```sh\n",
|
||||
"export BASETEN_API_KEY=\"paste_your_api_key_here\"\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -53,9 +40,9 @@
|
||||
"\n",
|
||||
"First, you'll need to deploy a model to Baseten.\n",
|
||||
"\n",
|
||||
"You can deploy foundation models like WizardLM and Alpaca with one click from the [Baseten model library](https://app.baseten.co/explore/) or if you have your own model, [deploy it with this tutorial](https://docs.baseten.co/deploying-models/deploy).\n",
|
||||
"You can deploy foundation models like Mistral and Llama 2 with one click from the [Baseten model library](https://app.baseten.co/explore/) or if you have your own model, [deploy it with Truss](https://truss.baseten.co/welcome).\n",
|
||||
"\n",
|
||||
"In this example, we'll work with WizardLM. [Deploy WizardLM here](https://app.baseten.co/explore/llama) and follow along with the deployed [model's version ID](https://docs.baseten.co/managing-models/manage)."
|
||||
"In this example, we'll work with Mistral 7B. [Deploy Mistral 7B here](https://app.baseten.co/explore/mistral_7b_instruct) and follow along with the deployed model's ID, found in the model dashboard."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -64,7 +51,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import Baseten"
|
||||
"from langchain_community.llms import Baseten"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -74,7 +61,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the model\n",
|
||||
"wizardlm = Baseten(model=\"MODEL_VERSION_ID\", verbose=True)"
|
||||
"mistral = Baseten(model=\"MODEL_ID\", deployment=\"production\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -84,8 +71,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Prompt the model\n",
|
||||
"\n",
|
||||
"wizardlm(\"What is the difference between a Wizard and a Sorcerer?\")"
|
||||
"mistral(\"What is the Mistral wind?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -97,7 +83,7 @@
|
||||
"\n",
|
||||
"We can chain together multiple calls to one or multiple models, which is the whole point of Langchain!\n",
|
||||
"\n",
|
||||
"This example uses WizardLM to plan a meal with an entree, three sides, and an alcoholic and non-alcoholic beverage pairing."
|
||||
"For example, we can replace GPT with Mistral in this [demo of terminal emulation](https://python.langchain.com/docs/modules/agents/how_to/chatgpt_clone)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -106,24 +92,37 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import LLMChain, SimpleSequentialChain\n",
|
||||
"from langchain.prompts import PromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build the first link in the chain\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.memory import ConversationBufferWindowMemory\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\n",
|
||||
" input_variables=[\"cuisine\"],\n",
|
||||
" template=\"Name a complex entree for a {cuisine} dinner. Respond with just the name of a single dish.\",\n",
|
||||
"template = \"\"\"Assistant is a large language model trained by OpenAI.\n",
|
||||
"\n",
|
||||
"Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.\n",
|
||||
"\n",
|
||||
"Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics.\n",
|
||||
"\n",
|
||||
"Overall, Assistant is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist.\n",
|
||||
"\n",
|
||||
"{history}\n",
|
||||
"Human: {human_input}\n",
|
||||
"Assistant:\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(input_variables=[\"history\", \"human_input\"], template=template)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"chatgpt_chain = LLMChain(\n",
|
||||
" llm=mistral,\n",
|
||||
" llm_kwargs={\"max_length\": 4096},\n",
|
||||
" prompt=prompt,\n",
|
||||
" verbose=True,\n",
|
||||
" memory=ConversationBufferWindowMemory(k=2),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"link_one = LLMChain(llm=wizardlm, prompt=prompt)"
|
||||
"output = chatgpt_chain.predict(\n",
|
||||
" human_input=\"I want you to act as a Linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. Do not write explanations. Do not type commands unless I instruct you to do so. When I need to tell you something in English I will do so by putting text inside curly brackets {like this}. My first command is pwd.\"\n",
|
||||
")\n",
|
||||
"print(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -132,14 +131,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build the second link in the chain\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\n",
|
||||
" input_variables=[\"entree\"],\n",
|
||||
" template=\"What are three sides that would go with {entree}. Respond with only a list of the sides.\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"link_two = LLMChain(llm=wizardlm, prompt=prompt)"
|
||||
"output = chatgpt_chain.predict(human_input=\"ls ~\")\n",
|
||||
"print(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -148,14 +141,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build the third link in the chain\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(\n",
|
||||
" input_variables=[\"sides\"],\n",
|
||||
" template=\"What is one alcoholic and one non-alcoholic beverage that would go well with this list of sides: {sides}. Respond with only the names of the beverages.\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"link_three = LLMChain(llm=wizardlm, prompt=prompt)"
|
||||
"output = chatgpt_chain.predict(human_input=\"cd ~\")\n",
|
||||
"print(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -164,12 +151,17 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run the full chain!\n",
|
||||
"\n",
|
||||
"menu_maker = SimpleSequentialChain(\n",
|
||||
" chains=[link_one, link_two, link_three], verbose=True\n",
|
||||
"output = chatgpt_chain.predict(\n",
|
||||
" human_input=\"\"\"echo -e \"x=lambda y:y*5+3;print('Result:' + str(x(6)))\" > run.py && python3 run.py\"\"\"\n",
|
||||
")\n",
|
||||
"menu_maker.run(\"South Indian\")"
|
||||
"print(output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As we can see from the final example, which outputs a number that may or may not be correct, the model is only approximating likely terminal output, not actually executing provided commands. Still, the example demonstrates Mistral's ample context window, code generation capabilities, and ability to stay on-topic even in conversational sequences."
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -1,25 +1,71 @@
|
||||
# Baseten
|
||||
|
||||
Learn how to use LangChain with models deployed on Baseten.
|
||||
[Baseten](https://baseten.co) provides all the infrastructure you need to deploy and serve ML models performantly, scalably, and cost-efficiently.
|
||||
|
||||
## Installation and setup
|
||||
As a model inference platform, Baseten is a `Provider` in the LangChain ecosystem. The Baseten integration currently implements a single `Component`, LLMs, but more are planned!
|
||||
|
||||
- Create a [Baseten](https://baseten.co) account and [API key](https://docs.baseten.co/settings/api-keys).
|
||||
- Install the Baseten Python client with `pip install baseten`
|
||||
- Use your API key to authenticate with `baseten login`
|
||||
Baseten lets you run both open source models like Llama 2 or Mistral and run proprietary or fine-tuned models on dedicated GPUs. If you're used to a provider like OpenAI, using Baseten has a few differences:
|
||||
|
||||
## Invoking a model
|
||||
* Rather than paying per token, you pay per minute of GPU used.
|
||||
* Every model on Baseten uses [Truss](https://truss.baseten.co/welcome), our open-source model packaging framework, for maximum customizability.
|
||||
* While we have some [OpenAI ChatCompletions-compatible models](https://docs.baseten.co/api-reference/openai), you can define your own I/O spec with Truss.
|
||||
|
||||
Baseten integrates with LangChain through the LLM module, which provides a standardized and interoperable interface for models that are deployed on your Baseten workspace.
|
||||
You can learn more about Baseten in [our docs](https://docs.baseten.co/) or read on for LangChain-specific info.
|
||||
|
||||
You can deploy foundation models like WizardLM and Alpaca with one click from the [Baseten model library](https://app.baseten.co/explore/) or if you have your own model, [deploy it with this tutorial](https://docs.baseten.co/deploying-models/deploy).
|
||||
## Setup: LangChain + Baseten
|
||||
|
||||
In this example, we'll work with WizardLM. [Deploy WizardLM here](https://app.baseten.co/explore/wizardlm) and follow along with the deployed [model's version ID](https://docs.baseten.co/managing-models/manage).
|
||||
You'll need two things to use Baseten models with LangChain:
|
||||
|
||||
- A [Baseten account](https://baseten.co)
|
||||
- An [API key](https://docs.baseten.co/observability/api-keys)
|
||||
|
||||
Export your API key to your as an environment variable called `BASETEN_API_KEY`.
|
||||
|
||||
```sh
|
||||
export BASETEN_API_KEY="paste_your_api_key_here"
|
||||
```
|
||||
|
||||
## Component guide: LLMs
|
||||
|
||||
Baseten integrates with LangChain through the [LLM component](https://python.langchain.com/docs/integrations/llms/baseten), which provides a standardized and interoperable interface for models that are deployed on your Baseten workspace.
|
||||
|
||||
You can deploy foundation models like Mistral and Llama 2 with one click from the [Baseten model library](https://app.baseten.co/explore/) or if you have your own model, [deploy it with Truss](https://truss.baseten.co/welcome).
|
||||
|
||||
In this example, we'll work with Mistral 7B. [Deploy Mistral 7B here](https://app.baseten.co/explore/mistral_7b_instruct) and follow along with the deployed model's ID, found in the model dashboard.
|
||||
|
||||
To use this module, you must:
|
||||
|
||||
* Export your Baseten API key as the environment variable BASETEN_API_KEY
|
||||
* Get the model ID for your model from your Baseten dashboard
|
||||
* Identify the model deployment ("production" for all model library models)
|
||||
|
||||
[Learn more](https://docs.baseten.co/deploy/lifecycle) about model IDs and deployments.
|
||||
|
||||
Production deployment (standard for model library models)
|
||||
|
||||
```python
|
||||
from langchain.llms import Baseten
|
||||
from langchain_community.llms import Baseten
|
||||
|
||||
wizardlm = Baseten(model="MODEL_VERSION_ID", verbose=True)
|
||||
|
||||
wizardlm("What is the difference between a Wizard and a Sorcerer?")
|
||||
mistral = Baseten(model="MODEL_ID", deployment="production")
|
||||
mistral("What is the Mistral wind?")
|
||||
```
|
||||
|
||||
Development deployment
|
||||
|
||||
```python
|
||||
from langchain_community.llms import Baseten
|
||||
|
||||
mistral = Baseten(model="MODEL_ID", deployment="development")
|
||||
mistral("What is the Mistral wind?")
|
||||
```
|
||||
|
||||
Other published deployment
|
||||
|
||||
```python
|
||||
from langchain_community.llms import Baseten
|
||||
|
||||
mistral = Baseten(model="MODEL_ID", deployment="DEPLOYMENT_ID")
|
||||
mistral("What is the Mistral wind?")
|
||||
```
|
||||
|
||||
Streaming LLM output, chat completions, embeddings models, and more are all supported on the Baseten platform and coming soon to our LangChain integration. Contact us at [support@baseten.co](mailto:support@baseten.co) with any questions about using Baseten with LangChain.
|
||||
|
@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
@ -9,29 +11,51 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Baseten(LLM):
|
||||
"""Baseten models.
|
||||
"""Baseten model
|
||||
|
||||
To use, you should have the ``baseten`` python package installed,
|
||||
and run ``baseten.login()`` with your Baseten API key.
|
||||
This module allows using LLMs hosted on Baseten.
|
||||
|
||||
The required ``model`` param can be either a model id or model
|
||||
version id. Using a model version ID will result in
|
||||
slightly faster invocation.
|
||||
Any other model parameters can also
|
||||
be passed in with the format input={model_param: value, ...}
|
||||
The LLM deployed on Baseten must have the following properties:
|
||||
|
||||
The Baseten model must accept a dictionary of input with the key
|
||||
"prompt" and return a dictionary with a key "data" which maps
|
||||
to a list of response strings.
|
||||
* Must accept input as a dictionary with the key "prompt"
|
||||
* May accept other input in the dictionary passed through with kwargs
|
||||
* Must return a string with the model output
|
||||
|
||||
Example:
|
||||
To use this module, you must:
|
||||
|
||||
* Export your Baseten API key as the environment variable `BASETEN_API_KEY`
|
||||
* Get the model ID for your model from your Baseten dashboard
|
||||
* Identify the model deployment ("production" for all model library models)
|
||||
|
||||
These code samples use
|
||||
[Mistral 7B Instruct](https://app.baseten.co/explore/mistral_7b_instruct)
|
||||
from Baseten's model library.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import Baseten
|
||||
my_model = Baseten(model="MODEL_ID")
|
||||
output = my_model("prompt")
|
||||
# Production deployment
|
||||
mistral = Baseten(model="MODEL_ID", deployment="production")
|
||||
mistral("What is the Mistral wind?")
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import Baseten
|
||||
# Development deployment
|
||||
mistral = Baseten(model="MODEL_ID", deployment="development")
|
||||
mistral("What is the Mistral wind?")
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import Baseten
|
||||
# Other published deployment
|
||||
mistral = Baseten(model="MODEL_ID", deployment="DEPLOYMENT_ID")
|
||||
mistral("What is the Mistral wind?")
|
||||
"""
|
||||
|
||||
model: str
|
||||
deployment: str
|
||||
input: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@ -54,20 +78,17 @@ class Baseten(LLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call to Baseten deployed model endpoint."""
|
||||
try:
|
||||
import baseten
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import Baseten Python package. "
|
||||
"Please install it with `pip install baseten`."
|
||||
) from exc
|
||||
|
||||
# get the model and version
|
||||
try:
|
||||
model = baseten.deployed_model_version_id(self.model)
|
||||
response = model.predict({"prompt": prompt, **kwargs})
|
||||
except baseten.common.core.ApiError:
|
||||
model = baseten.deployed_model_id(self.model)
|
||||
response = model.predict({"prompt": prompt, **kwargs})
|
||||
return "".join(response)
|
||||
baseten_api_key = os.environ["BASETEN_API_KEY"]
|
||||
model_id = self.model
|
||||
if self.deployment == "production":
|
||||
model_url = f"https://model-{model_id}.api.baseten.co/production/predict"
|
||||
elif self.deployment == "development":
|
||||
model_url = f"https://model-{model_id}.api.baseten.co/development/predict"
|
||||
else: # try specific deployment ID
|
||||
model_url = f"https://model-{model_id}.api.baseten.co/deployment/{self.deployment}/predict"
|
||||
response = requests.post(
|
||||
model_url,
|
||||
headers={"Authorization": f"Api-Key {baseten_api_key}"},
|
||||
json={"prompt": prompt, **kwargs},
|
||||
)
|
||||
return response.json()
|
||||
|
@ -3,12 +3,11 @@ import os
|
||||
|
||||
from langchain_community.llms.baseten import Baseten
|
||||
|
||||
# This test requires valid BASETEN_MODEL_ID and BASETEN_API_KEY environment variables
|
||||
|
||||
|
||||
def test_baseten_call() -> None:
|
||||
"""Test valid call to Baseten."""
|
||||
import baseten
|
||||
|
||||
baseten.login(os.environ["BASETEN_API_KEY"])
|
||||
llm = Baseten(model=os.environ["BASETEN_MODEL_ID"])
|
||||
output = llm("Say foo:")
|
||||
output = llm("Test prompt, please respond.")
|
||||
assert isinstance(output, str)
|
||||
|
Loading…
Reference in New Issue
Block a user