gradient.ai LLM intregration (#10800)
- **Description:** This PR implements a new LLM API to https://gradient.ai
- **Issue:** Feature request for LLM #10745
- **Dependencies:** No additional dependencies are introduced.
- **Tag maintainer:** I am opening this PR for visibility; once ready for review I'll tag.
- ```make format && make lint && make test``` is running.
- Added an `integration` and a `mock unit` test.

Co-authored-by: michaelfeil <me@michaelfeil.eu>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 5097007407, commit 55570e54e1
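At a glance, the new integration is used as follows (a minimal sketch; the model id is the example id from the notebook below, not a universal default, and `GRADIENT_ACCESS_TOKEN` / `GRADIENT_WORKSPACE_ID` must be set in the environment):

```python
from langchain.llms import GradientLLM

# Credentials default to the GRADIENT_ACCESS_TOKEN / GRADIENT_WORKSPACE_ID
# environment variables, as documented in the notebook below.
llm = GradientLLM(model_id="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model")
print(llm("Say hello:"))
```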
216 docs/extras/integrations/llms/gradient.ipynb Normal file
@@ -0,0 +1,216 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient\n",
    "\n",
    "`Gradient` allows you to fine-tune and get completions on LLMs with a simple web API.\n",
    "\n",
    "This notebook goes over how to use LangChain with [Gradient](https://gradient.ai/).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "from langchain.llms import GradientLLM\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.chains import LLMChain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the Environment API Key\n",
    "Make sure to get your API key from Gradient AI. You are given $10 in free credits to test and fine-tune different models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from getpass import getpass\n",
    "\n",
    "\n",
    "if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\", None):\n",
    "    # Access token under https://auth.gradient.ai/select-workspace\n",
    "    os.environ[\"GRADIENT_ACCESS_TOKEN\"] = getpass(\"gradient.ai access token:\")\n",
    "if not os.environ.get(\"GRADIENT_WORKSPACE_ID\", None):\n",
    "    # `ID` listed in `$ gradient workspace list`\n",
    "    # also displayed after login at https://auth.gradient.ai/select-workspace\n",
    "    os.environ[\"GRADIENT_WORKSPACE_ID\"] = getpass(\"gradient.ai workspace id:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional: Validate your environment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` by listing the currently deployed models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Credentials valid.\n",
      "Possible values for `model_id` are:\n",
      " {'models': [{'id': '99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model', 'name': 'bloom-560m', 'slug': 'bloom-560m', 'type': 'baseModel'}, {'id': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'name': 'llama2-7b-chat', 'slug': 'llama2-7b-chat', 'type': 'baseModel'}, {'id': 'cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model', 'name': 'nous-hermes2', 'slug': 'nous-hermes2', 'type': 'baseModel'}, {'baseModelId': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'id': 'bb7b9865-0ce3-41a8-8e2b-5cbcbe1262eb_model_adapter', 'name': 'optical-transmitting-sensor', 'type': 'modelAdapter'}]}\n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "\n",
    "resp = requests.get(\n",
    "    \"https://api.gradient.ai/api/models\",\n",
    "    headers={\n",
    "        \"authorization\": f\"Bearer {os.environ['GRADIENT_ACCESS_TOKEN']}\",\n",
    "        \"x-gradient-workspace-id\": f\"{os.environ['GRADIENT_WORKSPACE_ID']}\",\n",
    "    },\n",
    ")\n",
    "if resp.status_code == 200:\n",
    "    models = resp.json()\n",
    "    print(\"Credentials valid.\\nPossible values for `model_id` are:\\n\", models)\n",
    "else:\n",
    "    print(\"Error when listing models. Are your credentials valid?\", resp.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the Gradient instance\n",
    "You can specify different parameters such as the model id, max tokens generated, temperature, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = GradientLLM(\n",
    "    # `ID` listed in `$ gradient model list`\n",
    "    model_id=\"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\",\n",
    "    # optional: set new credentials, they default to environment variables\n",
    "    # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
    "    # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Prompt Template\n",
    "We will create a prompt template for Question and Answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the LLMChain\n",
    "Provide a question and run the LLMChain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' The first team to win the Super Bowl was the New England Patriots. The Patriots won the'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question = \"What NFL team won the Super Bowl in 1994?\"\n",
    "\n",
    "llm_chain.run(question=question)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
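As the notebook's markdown notes, parameters beyond the model id can be set; a sketch passing them through `model_kwargs` (the keys match those validated in `gradient_ai.py` below, and the values here are illustrative, not defaults):

```python
from langchain.llms import GradientLLM

llm = GradientLLM(
    model_id="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
    model_kwargs={
        "max_generated_token_count": 200,  # must be positive
        "temperature": 0.75,  # must be in [0.0, 1.0]
        "top_p": 0.95,  # must be in [0.0, 1.0]
        "top_k": 20,  # must be positive
    },
)
```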
libs/langchain/langchain/llms/__init__.py
@@ -49,6 +49,7 @@ from langchain.llms.forefrontai import ForefrontAI
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.gooseai import GooseAI
 from langchain.llms.gpt4all import GPT4All
+from langchain.llms.gradient_ai import GradientLLM
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
@@ -119,6 +120,7 @@ __all__ = [
     "GPT4All",
     "GooglePalm",
     "GooseAI",
+    "GradientLLM",
     "HuggingFaceEndpoint",
     "HuggingFaceHub",
     "HuggingFacePipeline",
@@ -193,6 +195,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "forefrontai": ForefrontAI,
     "google_palm": GooglePalm,
     "gooseai": GooseAI,
+    "gradient": GradientLLM,
     "gpt4all": GPT4All,
     "huggingface_endpoint": HuggingFaceEndpoint,
     "huggingface_hub": HuggingFaceHub,
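The third hunk registers the new class under the `"gradient"` key, which matches `GradientLLM._llm_type` and is what lets generic loading code resolve the class by name. A sketch, assuming `type_to_cls_dict` remains importable from `langchain.llms`:

```python
from langchain.llms import type_to_cls_dict

llm_cls = type_to_cls_dict["gradient"]  # resolves to GradientLLM
llm = llm_cls(model_id="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model")
assert llm._llm_type == "gradient"
```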
236 libs/langchain/langchain/llms/gradient_ai.py Normal file
@@ -0,0 +1,236 @@
from typing import Any, Dict, List, Mapping, Optional, Union

import aiohttp
import requests

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator
from langchain.utils import get_from_dict_or_env


class GradientLLM(LLM):
    """Gradient.ai LLM Endpoints.

    GradientLLM is a class to interact with LLMs on gradient.ai.

    To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
    API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
    or alternatively provide them as keyword arguments to the constructor
    of this class.

    Example:
        .. code-block:: python

            from langchain.llms import GradientLLM
            GradientLLM(
                model_id="cad6644_base_ml_model",
                model_kwargs={
                    "max_generated_token_count": 200,
                    "temperature": 0.75,
                    "top_p": 0.95,
                    "top_k": 20,
                    "stop": [],
                },
                gradient_workspace_id="12345614fc0_workspace",
                gradient_access_token="gradientai-access_token",
            )

    """

    model_id: str
    "Underlying gradient.ai model id (base or fine-tuned)."

    gradient_workspace_id: Optional[str] = None
    "Underlying gradient.ai workspace_id."

    gradient_access_token: Optional[str] = None
    """gradient.ai API Token, which can be generated by going to
    https://auth.gradient.ai/select-workspace
    and selecting "Access tokens" under the profile drop-down.
    """

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    gradient_api_url: str = "https://api.gradient.ai/api"
    """Endpoint URL to use."""

    aiosession: Optional[aiohttp.ClientSession] = None
    """ClientSession, in case we want to reuse the connection for better performance."""

    # LLM call kwargs
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""

        values["gradient_access_token"] = get_from_dict_or_env(
            values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
        )
        values["gradient_workspace_id"] = get_from_dict_or_env(
            values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
        )

        if (
            values["gradient_access_token"] is None
            or len(values["gradient_access_token"]) < 10
        ):
            raise ValueError("env variable `GRADIENT_ACCESS_TOKEN` must be set")

        if (
            values["gradient_workspace_id"] is None
            or len(values["gradient_workspace_id"]) < 3
        ):
            raise ValueError("env variable `GRADIENT_WORKSPACE_ID` must be set")

        if values["model_kwargs"]:
            kw = values["model_kwargs"]
            if not 0 <= kw.get("temperature", 0.5) <= 1:
                raise ValueError("`temperature` must be in the range [0.0, 1.0]")

            if not 0 <= kw.get("top_p", 0.5) <= 1:
                raise ValueError("`top_p` must be in the range [0.0, 1.0]")

            if 0 >= kw.get("top_k", 0.5):
                raise ValueError("`top_k` must be positive")

            if 0 >= kw.get("max_generated_token_count", 1):
                raise ValueError("`max_generated_token_count` must be positive")

        values["gradient_api_url"] = get_from_dict_or_env(
            values, "gradient_api_url", "GRADIENT_API_URL"
        )

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"gradient_api_url": self.gradient_api_url},
            **{"model_kwargs": _model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "gradient"

    def _kwargs_post_request(
        self, prompt: str, kwargs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Build the kwargs for the POST request, used by the sync and async calls.

        Args:
            prompt (str): prompt used in query
            kwargs (dict): model kwargs in payload

        Returns:
            Dict[str, Union[str, dict]]: url, headers, and json payload
                for the request.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        return dict(
            url=f"{self.gradient_api_url}/models/{self.model_id}/complete",
            headers={
                "authorization": f"Bearer {self.gradient_access_token}",
                "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
                "accept": "application/json",
                "content-type": "application/json",
            },
            json=dict(
                query=prompt,
                maxGeneratedTokenCount=_params.get("max_generated_token_count", None),
                temperature=_params.get("temperature", None),
                topK=_params.get("top_k", None),
                topP=_params.get("top_p", None),
            ),
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to Gradient's API `models/{id}/complete`.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        try:
            response = requests.post(**self._kwargs_post_request(prompt, kwargs))
            if response.status_code != 200:
                raise Exception(
                    f"Gradient returned an unexpected response with status "
                    f"{response.status_code}: {response.text}"
                )
        except requests.exceptions.RequestException as e:
            raise Exception(f"RequestException while calling Gradient Endpoint: {e}")

        text = response.json()["generatedOutput"]

        if stop is not None:
            # Apply stop tokens when making calls to Gradient
            text = enforce_stop_tokens(text, stop)

        return text

    async def _acall(
        self,
        prompt: str,
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> str:
        """Async call to Gradient's API `models/{id}/complete`.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.
        """
        if not self.aiosession:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    **self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
                ) as response:
                    if response.status != 200:
                        raise Exception(
                            f"Gradient returned an unexpected response with status "
                            f"{response.status}: {await response.text()}"
                        )
                    text = (await response.json())["generatedOutput"]
        else:
            async with self.aiosession.post(
                **self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
            ) as response:
                if response.status != 200:
                    raise Exception(
                        f"Gradient returned an unexpected response with status "
                        f"{response.status}: {await response.text()}"
                    )
                text = (await response.json())["generatedOutput"]

        if stop is not None:
            # Apply stop tokens when making calls to Gradient
            text = enforce_stop_tokens(text, stop)

        return text
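One note on `aiosession`: `_acall` opens a fresh `aiohttp.ClientSession` per call unless one is supplied, so callers issuing many concurrent completions can share a session. A sketch (the model id is again the notebook's example, not a default):

```python
import asyncio

import aiohttp

from langchain.llms import GradientLLM


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        llm = GradientLLM(
            model_id="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
            aiosession=session,  # reuse one connection pool across all calls
        )
        result = await llm.agenerate(["Say foo:", "Say bar:"])
        for generation in result.generations:
            print(generation[0].text)


asyncio.run(main())
```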
36 libs/langchain/tests/integration_tests/llms/test_gradient_ai.py Normal file
@@ -0,0 +1,36 @@
"""Test Gradient AI API wrapper.

In order to run this test, you need to have a Gradient AI API key.
You can get one by registering for free at https://gradient.ai/.

You'll then need to set:
- `GRADIENT_ACCESS_TOKEN` environment variable to your API key.
- `GRADIENT_WORKSPACE_ID` environment variable to your workspace id.
- `GRADIENT_MODEL_ID` environment variable to your model id.
"""
import os

from langchain.llms import GradientLLM


def test_gradient_call() -> None:
    """Test simple call to gradient.ai."""
    model_id = os.environ["GRADIENT_MODEL_ID"]
    llm = GradientLLM(model_id=model_id)
    output = llm("Say hello:", temperature=0.2, max_generated_token_count=250)

    assert llm._llm_type == "gradient"

    assert isinstance(output, str)
    assert len(output)


async def test_gradientai_acall() -> None:
    """Test async call to gradient.ai."""
    model_id = os.environ["GRADIENT_MODEL_ID"]
    llm = GradientLLM(model_id=model_id)
    result = await llm.agenerate(
        ["Say hello:"], temperature=0.2, max_generated_token_count=250
    )
    assert llm._llm_type == "gradient"

    output = result.generations[0][0].text
    assert isinstance(output, str)
    assert len(output)
56 libs/langchain/tests/unit_tests/llms/test_gradient_ai.py Normal file
@@ -0,0 +1,56 @@
from typing import Dict

from pytest_mock import MockerFixture

from langchain.llms import GradientLLM

_MODEL_ID = "my_model_valid_id"
_GRADIENT_SECRET = "secret_valid_token_123456"
_GRADIENT_WORKSPACE_ID = "valid_workspace_12345"
_GRADIENT_BASE_URL = "https://api.gradient.ai/api"


class MockResponse:
    def __init__(self, json_data: Dict, status_code: int):
        self.json_data = json_data
        self.status_code = status_code

    def json(self) -> Dict:
        return self.json_data


def mocked_requests_post(
    url: str,
    headers: dict,
    json: dict,
) -> MockResponse:
    assert url.startswith(_GRADIENT_BASE_URL)
    assert headers
    assert json

    return MockResponse(
        json_data={"generatedOutput": "bar"},
        status_code=200,
    )


def test_gradient_llm_sync(
    mocker: MockerFixture,
) -> None:
    mocker.patch("requests.post", side_effect=mocked_requests_post)

    llm = GradientLLM(
        gradient_api_url=_GRADIENT_BASE_URL,
        gradient_access_token=_GRADIENT_SECRET,
        gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
        model_id=_MODEL_ID,
    )
    assert llm.gradient_access_token == _GRADIENT_SECRET
    assert llm.gradient_api_url == _GRADIENT_BASE_URL
    assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
    assert llm.model_id == _MODEL_ID

    response = llm("Say foo:")
    want = "bar"

    assert response == want
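The unit test above only exercises the sync path. A possible async counterpart, not part of this PR, would stub the `aiohttp` context manager that `_acall` enters; a sketch assuming `pytest-asyncio` is installed (`MockAsyncResponse` and `test_gradient_llm_async` are hypothetical names):

```python
from typing import Any, Dict

import pytest
from pytest_mock import MockerFixture

from langchain.llms import GradientLLM


class MockAsyncResponse:
    """Mimics the response object yielded by `aiohttp.ClientSession.post`."""

    status = 200

    async def json(self) -> Dict:
        return {"generatedOutput": "bar"}

    async def __aenter__(self) -> "MockAsyncResponse":
        return self

    async def __aexit__(self, *args: Any) -> None:
        return None


@pytest.mark.asyncio
async def test_gradient_llm_async(mocker: MockerFixture) -> None:
    # `_acall` uses `async with session.post(...)`, so the patched `post`
    # must return an object implementing `__aenter__` / `__aexit__`.
    mocker.patch("aiohttp.ClientSession.post", return_value=MockAsyncResponse())

    llm = GradientLLM(
        gradient_api_url="https://api.gradient.ai/api",
        gradient_access_token="secret_valid_token_123456",
        gradient_workspace_id="valid_workspace_12345",
        model_id="my_model_valid_id",
    )
    result = await llm.agenerate(["Say foo:"])
    assert result.generations[0][0].text == "bar"
```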