community: refactor Baseten integration with new API endpoints & docs (#15017)

- **Description:** In response to user feedback, this PR refactors the
Baseten integration with updated model endpoints, as well as updates
relevant documentation. This PR has been tested by end users in
production and works as expected.
  - **Issue:** N/A
- **Dependencies:** This PR actually removes the dependency on the
`baseten` package!
  - **Twitter handle:** https://twitter.com/basetenco
This commit is contained in:
Philip Kiely - Baseten
2023-12-22 14:46:24 -06:00
committed by GitHub
parent 3fc1b3553b
commit 6342da333a
4 changed files with 172 additions and 114 deletions

View File

@@ -1,6 +1,8 @@
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field
@@ -9,29 +11,51 @@ logger = logging.getLogger(__name__)
class Baseten(LLM):
"""Baseten models.
"""Baseten model
To use, you should have the ``baseten`` python package installed,
and run ``baseten.login()`` with your Baseten API key.
This module allows using LLMs hosted on Baseten.
The required ``model`` param can be either a model id or model
version id. Using a model version ID will result in
slightly faster invocation.
Any other model parameters can also
be passed in with the format input={model_param: value, ...}
The LLM deployed on Baseten must have the following properties:
The Baseten model must accept a dictionary of input with the key
"prompt" and return a dictionary with a key "data" which maps
to a list of response strings.
* Must accept input as a dictionary with the key "prompt"
* May accept other input in the dictionary passed through with kwargs
* Must return a string with the model output
Example:
To use this module, you must:
* Export your Baseten API key as the environment variable `BASETEN_API_KEY`
* Get the model ID for your model from your Baseten dashboard
* Identify the model deployment ("production" for all model library models)
These code samples use
[Mistral 7B Instruct](https://app.baseten.co/explore/mistral_7b_instruct)
from Baseten's model library.
Examples:
.. code-block:: python
from langchain_community.llms import Baseten
my_model = Baseten(model="MODEL_ID")
output = my_model("prompt")
# Production deployment
mistral = Baseten(model="MODEL_ID", deployment="production")
mistral("What is the Mistral wind?")
.. code-block:: python
from langchain_community.llms import Baseten
# Development deployment
mistral = Baseten(model="MODEL_ID", deployment="development")
mistral("What is the Mistral wind?")
.. code-block:: python
from langchain_community.llms import Baseten
# Other published deployment
mistral = Baseten(model="MODEL_ID", deployment="DEPLOYMENT_ID")
mistral("What is the Mistral wind?")
"""
model: str
deployment: str
input: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -54,20 +78,17 @@ class Baseten(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to Baseten deployed model endpoint."""
try:
import baseten
except ImportError as exc:
raise ImportError(
"Could not import Baseten Python package. "
"Please install it with `pip install baseten`."
) from exc
# get the model and version
try:
model = baseten.deployed_model_version_id(self.model)
response = model.predict({"prompt": prompt, **kwargs})
except baseten.common.core.ApiError:
model = baseten.deployed_model_id(self.model)
response = model.predict({"prompt": prompt, **kwargs})
return "".join(response)
baseten_api_key = os.environ["BASETEN_API_KEY"]
model_id = self.model
if self.deployment == "production":
model_url = f"https://model-{model_id}.api.baseten.co/production/predict"
elif self.deployment == "development":
model_url = f"https://model-{model_id}.api.baseten.co/development/predict"
else: # try specific deployment ID
model_url = f"https://model-{model_id}.api.baseten.co/deployment/{self.deployment}/predict"
response = requests.post(
model_url,
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={"prompt": prompt, **kwargs},
)
return response.json()