diff --git a/docs/ecosystem/baseten.md b/docs/ecosystem/baseten.md
new file mode 100644
index 00000000000..8a3d8ec1b53
--- /dev/null
+++ b/docs/ecosystem/baseten.md
@@ -0,0 +1,25 @@
+# Baseten
+
+Learn how to use LangChain with models deployed on Baseten.
+
+## Installation and setup
+
+- Create a [Baseten](https://baseten.co) account and [API key](https://docs.baseten.co/settings/api-keys).
+- Install the Baseten Python client with `pip install baseten`
+- Use your API key to authenticate with `baseten login`
+
+## Invoking a model
+
+Baseten integrates with LangChain through the LLM module, which provides a standardized and interoperable interface for models deployed to your Baseten workspace.
+
+You can deploy foundation models like WizardLM and Alpaca with one click from the [Baseten model library](https://app.baseten.co/explore/), or if you have your own model, [deploy it with this tutorial](https://docs.baseten.co/deploying-models/deploy).
+
+In this example, we'll work with WizardLM. [Deploy WizardLM here](https://app.baseten.co/explore/wizardlm) and follow along with the deployed [model's version ID](https://docs.baseten.co/managing-models/manage).
+
+```python
+from langchain.llms import Baseten
+
+wizardlm = Baseten(model="MODEL_VERSION_ID", verbose=True)
+
+wizardlm("What is the difference between a Wizard and a Sorcerer?")
+```
diff --git a/docs/modules/models/llms/integrations/baseten.ipynb b/docs/modules/models/llms/integrations/baseten.ipynb
new file mode 100644
index 00000000000..442a52546c6
--- /dev/null
+++ b/docs/modules/models/llms/integrations/baseten.ipynb
@@ -0,0 +1,196 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Baseten\n",
+    "\n",
+    "[Baseten](https://baseten.co) provides all the infrastructure you need to deploy and serve ML models performantly, scalably, and cost-efficiently.\n",
+    "\n",
+    "This example demonstrates using LangChain with models deployed on Baseten."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Setup\n",
+    "\n",
+    "To run this notebook, you'll need a [Baseten account](https://baseten.co) and an [API key](https://docs.baseten.co/settings/api-keys).\n",
+    "\n",
+    "You'll also need to install the Baseten Python package:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install baseten"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import baseten\n",
+    "\n",
+    "baseten.login(\"YOUR_API_KEY\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Single model call\n",
+    "\n",
+    "First, you'll need to deploy a model to Baseten.\n",
+    "\n",
+    "You can deploy foundation models like WizardLM and Alpaca with one click from the [Baseten model library](https://app.baseten.co/explore/), or if you have your own model, [deploy it with this tutorial](https://docs.baseten.co/deploying-models/deploy).\n",
+    "\n",
+    "In this example, we'll work with WizardLM. [Deploy WizardLM here](https://app.baseten.co/explore/wizardlm) and follow along with the deployed [model's version ID](https://docs.baseten.co/managing-models/manage)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import Baseten"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the model\n",
+    "wizardlm = Baseten(model=\"MODEL_VERSION_ID\", verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Prompt the model\n",
+    "\n",
+    "wizardlm(\"What is the difference between a Wizard and a Sorcerer?\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Chained model calls\n",
+    "\n",
+    "We can chain together multiple calls to one or more models, which is the whole point of LangChain!\n",
+    "\n",
+    "This example uses WizardLM to plan a meal with an entree, three sides, and an alcoholic and non-alcoholic beverage pairing."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains import SimpleSequentialChain\n",
+    "from langchain import PromptTemplate, LLMChain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Build the first link in the chain\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    input_variables=[\"cuisine\"],\n",
+    "    template=\"Name a complex entree for a {cuisine} dinner. Respond with just the name of a single dish.\",\n",
+    ")\n",
+    "\n",
+    "link_one = LLMChain(llm=wizardlm, prompt=prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Build the second link in the chain\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    input_variables=[\"entree\"],\n",
+    "    template=\"What are three sides that would go with {entree}? Respond with only a list of the sides.\",\n",
+    ")\n",
+    "\n",
+    "link_two = LLMChain(llm=wizardlm, prompt=prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Build the third link in the chain\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    input_variables=[\"sides\"],\n",
+    "    template=\"What is one alcoholic and one non-alcoholic beverage that would go well with this list of sides: {sides}? Respond with only the names of the beverages.\",\n",
+    ")\n",
+    "\n",
+    "link_three = LLMChain(llm=wizardlm, prompt=prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run the full chain!\n",
+    "\n",
+    "menu_maker = SimpleSequentialChain(chains=[link_one, link_two, link_three], verbose=True)\n",
+    "menu_maker.run(\"South Indian\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index e551a23f293..42d7c03037d 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -8,6 +8,7 @@ from langchain.llms.anyscale import Anyscale
 from langchain.llms.aviary import Aviary
 from langchain.llms.bananadev import Banana
 from langchain.llms.base import BaseLLM
+from langchain.llms.baseten import Baseten
 from langchain.llms.beam import Beam
 from langchain.llms.bedrock import Bedrock
 from langchain.llms.cerebriumai import CerebriumAI
@@ -50,6 +51,7 @@ __all__ = [
     "Anyscale",
     "Aviary",
     "Banana",
+    "Baseten",
     "Beam",
     "Bedrock",
     "CerebriumAI",
@@ -98,6 +100,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "anyscale": Anyscale,
     "aviary": Aviary,
     "bananadev": Banana,
+    "baseten": Baseten,
     "beam": Beam,
     "cerebriumai": CerebriumAI,
     "cohere": Cohere,
diff --git a/langchain/llms/baseten.py b/langchain/llms/baseten.py
new file mode 100644
index 00000000000..5637fc41570
--- /dev/null
+++ b/langchain/llms/baseten.py
@@ -0,0 +1,74 @@
+"""Wrapper around Baseten deployed model API."""
+import logging
+from typing import Any, Dict, List, Mapping, Optional
+
+from pydantic import Field
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class Baseten(LLM):
+    """Use your Baseten models in LangChain.
+
+    To use, you should have the ``baseten`` python package installed,
+    and run ``baseten.login()`` with your Baseten API key.
+
+    The required ``model`` param can be either a model ID or a model
+    version ID. Using a model version ID will result in slightly
+    faster invocation. Any other model parameters can be passed in
+    with the format ``input={model_param: value, ...}``.
+
+    The Baseten model must accept a dictionary of input with the key
+    "prompt" and return a dictionary with a key "data" which maps
+    to a list of response strings.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import Baseten
+
+            my_model = Baseten(model="MODEL_ID")
+            output = my_model("prompt")
+    """
+
+    model: str
+    input: Dict[str, Any] = Field(default_factory=dict)
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {"model_kwargs": self.model_kwargs}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of model."""
+        return "baseten"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Call to Baseten deployed model endpoint."""
+        try:
+            import baseten
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import Baseten Python package. "
+                "Please install it with `pip install baseten`."
+            ) from exc
+
+        # Resolve the deployment: try the identifier as a model version ID
+        # first, then fall back to treating it as a model ID.
+        try:
+            model = baseten.deployed_model_version_id(self.model)
+            response = model.predict({"prompt": prompt})
+        except baseten.common.core.ApiError:
+            model = baseten.deployed_model_id(self.model)
+            response = model.predict({"prompt": prompt})
+        return "".join(response)
diff --git a/tests/integration_tests/llms/test_baseten.py b/tests/integration_tests/llms/test_baseten.py
new file mode 100644
index 00000000000..0e1226b91b2
--- /dev/null
+++ b/tests/integration_tests/llms/test_baseten.py
@@ -0,0 +1,16 @@
+"""Test Baseten API wrapper."""
+import os
+
+import baseten
+import pytest
+
+from langchain.llms.baseten import Baseten
+
+
+@pytest.mark.requires("baseten")
+def test_baseten_call() -> None:
+    """Test valid call to Baseten."""
+    baseten.login(os.environ["BASETEN_API_KEY"])
+    llm = Baseten(model=os.environ["BASETEN_MODEL_ID"])
+    output = llm("Say foo:")
+    assert isinstance(output, str)
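The `type_to_cls_dict` entry added above is what lets LangChain resolve this provider by its `_llm_type` string, for example when reconstructing an LLM from saved config. A minimal sketch of what the mapping enables, assuming nothing beyond this diff (`MODEL_VERSION_ID` is a placeholder):

```python
from langchain.llms import type_to_cls_dict

# "baseten" is the string returned by Baseten._llm_type, so config-driven
# loaders can map it back to the wrapper class registered in __init__.py.
llm_cls = type_to_cls_dict["baseten"]
llm = llm_cls(model="MODEL_VERSION_ID")
```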
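The integration test requires live credentials (`BASETEN_API_KEY`, `BASETEN_MODEL_ID`). For local development, the fallback logic in `_call` can also be exercised offline by stubbing the `baseten` module. This is only a sketch: the stub module, `FakeModel`, `FakeApiError`, and the test name are hypothetical and not part of this diff.

```python
"""Offline sketch: exercise Baseten._call's ID fallback with a stubbed client."""
import sys
import types

from langchain.llms.baseten import Baseten


def test_baseten_falls_back_to_model_id(monkeypatch) -> None:
    """If the version-ID lookup fails, the wrapper should retry by model ID."""

    class FakeApiError(Exception):
        pass

    class FakeModel:
        def predict(self, payload: dict) -> list:
            # _call sends {"prompt": ...} and joins the returned strings
            # (the real client may unwrap these from a "data" key, as the
            # class docstring describes).
            assert payload == {"prompt": "Say foo:"}
            return ["foo"]

    def fail_version_lookup(_model_id: str) -> None:
        raise FakeApiError("not a model version ID")

    # Stand-in `baseten` module exposing only the attributes _call touches.
    fake = types.ModuleType("baseten")
    fake.common = types.SimpleNamespace(
        core=types.SimpleNamespace(ApiError=FakeApiError)
    )
    fake.deployed_model_version_id = fail_version_lookup
    fake.deployed_model_id = lambda _model_id: FakeModel()
    monkeypatch.setitem(sys.modules, "baseten", fake)

    llm = Baseten(model="MODEL_ID")
    assert llm("Say foo:") == "foo"
```

Because the stub is injected via `sys.modules`, this test runs without the real `baseten` package or any network access.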