community: refactor Baseten integration with new API endpoints & docs (#15017)
- **Description:** In response to user feedback, this PR refactors the Baseten integration to use the updated model endpoints and updates the relevant documentation. This PR has been tested by end users in production and works as expected.
- **Issue:** N/A
- **Dependencies:** This PR actually removes the dependency on the `baseten` package!
- **Twitter handle:** https://twitter.com/basetenco
Commit 6342da333a (parent 3fc1b3553b) in https://github.com/hwchase17/langchain, committed via GitHub.
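In short, the refactor drops the `baseten` SDK in favor of plain HTTPS calls to Baseten's hosted-model endpoints, authenticated with an API key from the environment. A minimal usage sketch of the refactored class, assuming `BASETEN_API_KEY` is exported and `"MODEL_ID"` stands in for a real model ID from your Baseten dashboard:

```python
import os

from langchain_community.llms import Baseten

# Assumes BASETEN_API_KEY is already exported; "MODEL_ID" is a
# placeholder for a model deployed in your Baseten workspace.
assert "BASETEN_API_KEY" in os.environ

mistral = Baseten(model="MODEL_ID", deployment="production")
print(mistral("What is the Mistral wind?"))
```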
libs/community/langchain_community/llms/baseten.py

```diff
@@ -1,6 +1,8 @@
 import logging
+import os
 from typing import Any, Dict, List, Mapping, Optional
 
+import requests
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from langchain_core.pydantic_v1 import Field
```
```diff
@@ -9,29 +11,51 @@ logger = logging.getLogger(__name__)
 
 
 class Baseten(LLM):
-    """Baseten models.
+    """Baseten model
 
-    To use, you should have the ``baseten`` python package installed,
-    and run ``baseten.login()`` with your Baseten API key.
+    This module allows using LLMs hosted on Baseten.
 
-    The required ``model`` param can be either a model id or model
-    version id. Using a model version ID will result in
-    slightly faster invocation.
-    Any other model parameters can also
-    be passed in with the format input={model_param: value, ...}
+    The LLM deployed on Baseten must have the following properties:
 
-    The Baseten model must accept a dictionary of input with the key
-    "prompt" and return a dictionary with a key "data" which maps
-    to a list of response strings.
+    * Must accept input as a dictionary with the key "prompt"
+    * May accept other input in the dictionary passed through with kwargs
+    * Must return a string with the model output
 
-    Example:
+    To use this module, you must:
+
+    * Export your Baseten API key as the environment variable `BASETEN_API_KEY`
+    * Get the model ID for your model from your Baseten dashboard
+    * Identify the model deployment ("production" for all model library models)
+
+    These code samples use
+    [Mistral 7B Instruct](https://app.baseten.co/explore/mistral_7b_instruct)
+    from Baseten's model library.
+
+    Examples:
         .. code-block:: python
 
             from langchain_community.llms import Baseten
-            my_model = Baseten(model="MODEL_ID")
-            output = my_model("prompt")
+            # Production deployment
+            mistral = Baseten(model="MODEL_ID", deployment="production")
+            mistral("What is the Mistral wind?")
+
+        .. code-block:: python
+
+            from langchain_community.llms import Baseten
+            # Development deployment
+            mistral = Baseten(model="MODEL_ID", deployment="development")
+            mistral("What is the Mistral wind?")
+
+        .. code-block:: python
+
+            from langchain_community.llms import Baseten
+            # Other published deployment
+            mistral = Baseten(model="MODEL_ID", deployment="DEPLOYMENT_ID")
+            mistral("What is the Mistral wind?")
     """
 
     model: str
+    deployment: str
     input: Dict[str, Any] = Field(default_factory=dict)
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
```
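The "passed through with kwargs" property in the new docstring is worth a concrete illustration. A hedged sketch, where `max_new_tokens` and `temperature` are illustrative model parameters rather than confirmed fields of any particular Baseten model's schema:

```python
from langchain_community.llms import Baseten

# "MODEL_ID" is a placeholder. Extra keyword arguments are forwarded
# verbatim by _call into the JSON body alongside "prompt", so the
# deployed model must recognize them; these names are illustrative.
mistral = Baseten(model="MODEL_ID", deployment="production")
output = mistral(
    "What is the Mistral wind?",
    max_new_tokens=128,
    temperature=0.7,
)
print(output)
```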
```diff
@@ -54,20 +78,17 @@ class Baseten(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """Call to Baseten deployed model endpoint."""
-        try:
-            import baseten
-        except ImportError as exc:
-            raise ImportError(
-                "Could not import Baseten Python package. "
-                "Please install it with `pip install baseten`."
-            ) from exc
-
-        # get the model and version
-        try:
-            model = baseten.deployed_model_version_id(self.model)
-            response = model.predict({"prompt": prompt, **kwargs})
-        except baseten.common.core.ApiError:
-            model = baseten.deployed_model_id(self.model)
-            response = model.predict({"prompt": prompt, **kwargs})
-        return "".join(response)
+        baseten_api_key = os.environ["BASETEN_API_KEY"]
+        model_id = self.model
+        if self.deployment == "production":
+            model_url = f"https://model-{model_id}.api.baseten.co/production/predict"
+        elif self.deployment == "development":
+            model_url = f"https://model-{model_id}.api.baseten.co/development/predict"
+        else:  # try specific deployment ID
+            model_url = f"https://model-{model_id}.api.baseten.co/deployment/{self.deployment}/predict"
+        response = requests.post(
+            model_url,
+            headers={"Authorization": f"Api-Key {baseten_api_key}"},
+            json={"prompt": prompt, **kwargs},
+        )
+        return response.json()
```
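Because the class is now a thin wrapper over one HTTP call, the new endpoint scheme can be sanity-checked without LangChain at all. A standalone sketch of the same request `_call` issues for a production deployment (`"MODEL_ID"` is a placeholder; swap the `production` path segment for `development` or `deployment/<DEPLOYMENT_ID>` to hit the other two variants shown in the diff):

```python
import os

import requests

# Mirrors the request _call builds above; BASETEN_API_KEY must be set
# and "MODEL_ID" replaced with a real model ID.
model_id = "MODEL_ID"
response = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"},
    json={"prompt": "What is the Mistral wind?"},
)
print(response.json())
```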