Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-04 20:28:10 +00:00)
GradientLLM Docs update and model_id renaming. (#10963)
Related to #10800 - Errors in the Docstring of GradientLLM / Gradient.ai LLM

- Renamed `model_id` to `model` and adapted this in all tests. The reason is to stay in sync with `GradientEmbeddings` and other LLMs.
- Improved the tests so they check the headers in the sent request.
- Made the aiosession a private attribute in the docs, as `pip install gradientai` will replace aiosession in a future release.
- Added an example of how to fine-tune on the PromptTemplate, as suggested in #10800.
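For callers, the rename amounts to the minimal sketch below (hypothetical usage based on the updated docstring; the model id is the docstring's example id, and the old `model_id` keyword keeps validating because the field is declared with a pydantic alias):

```python
from langchain.llms import GradientLLM

# Before this commit:
# llm = GradientLLM(model_id="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model")

# After this commit, in sync with GradientEmbeddings and other LLMs.
# (`model_id=...` still works via Field(alias="model") together with
# allow_population_by_field_name in the pydantic Config.)
llm = GradientLLM(
    model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
    model_kwargs={"max_generated_token_count": 128},
)
```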
This commit is contained in:
parent
6876b02c87
commit
233a904f2e
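The new fine-tuning example boils down to the following sketch, condensed from the notebook diff below. It assumes `gradientai` is installed and `GRADIENT_ACCESS_TOKEN` / `GRADIENT_WORKSPACE_ID` are set; the adapter name is the notebook's example.

```python
import gradientai

# The same prompt template the notebook's LLMChain uses.
template = """Question: {question}

Answer: """

# Create a fine-tunable model adapter on top of a base model.
client = gradientai.Gradient()
base_models = client.list_models(only_base=True)
new_model = base_models[-1].create_model_adapter(name="my_model_adapter")

# Fine-tune on a single sample rendered through the PromptTemplate,
# with the desired completion appended.
dataset = [
    {
        "inputs": template.format(question="What NFL team won the Super Bowl in 1994?")
        + " The Dallas Cowboys!"
    }
]
new_model.fine_tune(samples=dataset)
```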
@@ -24,8 +24,6 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "import os\n",
-   "import requests\n",
    "from langchain.llms import GradientLLM\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.chains import LLMChain"
@@ -46,7 +44,7 @@
   "outputs": [],
   "source": [
    "from getpass import getpass\n",
-   "\n",
+   "import os\n",
    "\n",
    "if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\",None):\n",
    "    # Access token under https://auth.gradient.ai/select-workspace\n",
@@ -61,7 +59,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "Optional: Validate your Environment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models."
+   "Optional: Validate your Environment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models. Using the `gradientai` Python package."
   ]
  },
  {
@@ -73,25 +71,64 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Credentials valid.\n",
-     "Possible values for `model_id` are:\n",
-     " {'models': [{'id': '99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model', 'name': 'bloom-560m', 'slug': 'bloom-560m', 'type': 'baseModel'}, {'id': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'name': 'llama2-7b-chat', 'slug': 'llama2-7b-chat', 'type': 'baseModel'}, {'id': 'cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model', 'name': 'nous-hermes2', 'slug': 'nous-hermes2', 'type': 'baseModel'}, {'baseModelId': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'id': 'bb7b9865-0ce3-41a8-8e2b-5cbcbe1262eb_model_adapter', 'name': 'optical-transmitting-sensor', 'type': 'modelAdapter'}]}\n"
+     "Requirement already satisfied: gradientai in /home/michi/.venv/lib/python3.10/site-packages (1.0.0)\n",
+     "Requirement already satisfied: aenum>=3.1.11 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (3.1.15)\n",
+     "Requirement already satisfied: pydantic<2.0.0,>=1.10.5 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.10.12)\n",
+     "Requirement already satisfied: python-dateutil>=2.8.2 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (2.8.2)\n",
+     "Requirement already satisfied: urllib3>=1.25.3 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.26.16)\n",
+     "Requirement already satisfied: typing-extensions>=4.2.0 in /home/michi/.venv/lib/python3.10/site-packages (from pydantic<2.0.0,>=1.10.5->gradientai) (4.5.0)\n",
+     "Requirement already satisfied: six>=1.5 in /home/michi/.venv/lib/python3.10/site-packages (from python-dateutil>=2.8.2->gradientai) (1.16.0)\n"
     ]
    }
   ],
   "source": [
-   "import requests\n",
+   "!pip install gradientai"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 4,
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\n",
+     "f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model\n",
+     "cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model\n"
+    ]
+   }
+  ],
+  "source": [
+   "import gradientai\n",
    "\n",
-   "resp = requests.get(f'https://api.gradient.ai/api/models', headers={\n",
-   "    \"authorization\": f\"Bearer {os.environ['GRADIENT_ACCESS_TOKEN']}\",\n",
-   "    \"x-gradient-workspace-id\": f\"{os.environ['GRADIENT_WORKSPACE_ID']}\",\n",
-   "    },\n",
-   "  )\n",
-   "if resp.status_code == 200:\n",
-   "    models = resp.json()\n",
-   "    print(\"Credentials valid.\\nPossible values for `model_id` are:\\n\", models)\n",
-   "else:\n",
-   "    print(\"Error when listing models. Are your credentials valid?\", resp.text)"
+   "client = gradientai.Gradient()\n",
+   "\n",
+   "models = client.list_models(only_base=True)\n",
+   "for model in models:\n",
+   "    print(model.id)"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 5,
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "('674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter', 'my_model_adapter')"
+     ]
+    },
+    "execution_count": 5,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "new_model = models[-1].create_model_adapter(name=\"my_model_adapter\")\n",
+   "new_model.id, new_model.name"
   ]
  },
  {
@@ -99,21 +136,24 @@
   "metadata": {},
   "source": [
    "## Create the Gradient instance\n",
-   "You can specify different parameters such as the model name, max tokens generated, temperature, etc."
+   "You can specify different parameters such as the model, max_tokens generated, temperature, etc.\n",
+   "\n",
+   "As we later want to fine-tune our model, we select the model_adapter with the id `674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter`, but you can use any base or fine-tunable model."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = GradientLLM(\n",
    "    # `ID` listed in `$ gradient model list`\n",
-   "    model_id=\"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\",\n",
+   "    model=\"674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter\",\n",
    "    # # optional: set new credentials, they default to environment variables\n",
    "    # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
    "    # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
+   "    model_kwargs=dict(max_generated_token_count=128)\n",
    ")"
   ]
  },
@@ -127,13 +167,13 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Question: {question}\n",
    "\n",
-   "Answer: Let's think step by step.\"\"\"\n",
+   "Answer: \"\"\"\n",
    "\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
   ]
@@ -147,7 +187,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -164,16 +204,16 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "' The first team to win the Super Bowl was the New England Patriots. The Patriots won the'"
+      "'\\nThe San Francisco 49ers won the Super Bowl in 1994.'"
      ]
     },
-    "execution_count": 7,
+    "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -185,6 +225,88 @@
    "    question=question\n",
    ")"
   ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "# Improve the results by fine-tuning (optional)\n",
+   "Well - that is wrong - the San Francisco 49ers did not win.\n",
+   "The correct answer to the question would be `The Dallas Cowboys!`.\n",
+   "\n",
+   "Let's increase the odds for the correct answer, by fine-tuning on the correct answer using the PromptTemplate."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 10,
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "[{'inputs': 'Question: What NFL team won the Super Bowl in 1994?\\n\\nAnswer: The Dallas Cowboys!'}]"
+     ]
+    },
+    "execution_count": 10,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "dataset = [{\"inputs\": template.format(question=\"What NFL team won the Super Bowl in 1994?\") + \" The Dallas Cowboys!\"}]\n",
+   "dataset"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 11,
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "FineTuneResponse(number_of_trainable_tokens=27, sum_loss=78.17996)"
+     ]
+    },
+    "execution_count": 11,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "new_model.fine_tune(\n",
+   "    samples=dataset\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 13,
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "'The Dallas Cowboys'"
+     ]
+    },
+    "execution_count": 13,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "# we can keep the llm_chain, as the registered model just got refreshed on the gradient.ai servers.\n",
+   "llm_chain.run(\n",
+   "    question=question\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": []
  }
 ],
 "metadata": {
@@ -203,7 +325,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.13"
+  "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
docs/docs/integrations/providers/gradient.mdx (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Gradient
+
+>[Gradient](https://gradient.ai/) allows you to fine-tune and get completions on LLMs with a simple web API.
+
+## Installation and Setup
+- Install the Python SDK:
+```bash
+pip install gradientai
+```
+Get a [Gradient access token and workspace](https://gradient.ai/) and set them as environment variables (`GRADIENT_ACCESS_TOKEN` and `GRADIENT_WORKSPACE_ID`).
+
+## LLM
+
+There exists a Gradient LLM wrapper, which you can access with the import below.
+See a [usage example](/docs/integrations/llms/gradient).
+
+```python
+from langchain.llms import GradientLLM
+```
+
+## Text Embedding Model
+
+There exists a Gradient Embeddings model, which you can access with:
+```python
+from langchain.embeddings import GradientEmbeddings
+```
+For a more detailed walkthrough of this, see [this notebook](/docs/integrations/text_embedding/gradient.html).
@@ -12,6 +12,8 @@ from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain.schema.embeddings import Embeddings
 from langchain.utils import get_from_dict_or_env
 
+__all__ = ["GradientEmbeddings"]
+
 
 class GradientEmbeddings(BaseModel, Embeddings):
     """Gradient.ai Embedding models.
@@ -48,7 +50,7 @@ class GradientEmbeddings(BaseModel, Embeddings):
     gradient_api_url: str = "https://api.gradient.ai/api"
     """Endpoint URL to use."""
 
-    client: Any  #: :meta private:
+    client: Any = None  #: :meta private:
     """Gradient client."""
 
     # LLM call kwargs
@@ -143,8 +145,9 @@ class GradientEmbeddings(BaseModel, Embeddings):
         return embeddings[0]
 
 
-class TinyAsyncGradientEmbeddingClient:
-    """A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API.
+class TinyAsyncGradientEmbeddingClient:  #: :meta private:
+    """A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API,
+    direct use discouraged.
 
     To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
     API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
@@ -1,4 +1,7 @@
-from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
 
 import aiohttp
 import requests
@@ -7,9 +10,10 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
 
@@ -17,7 +21,7 @@ class TrainResult(TypedDict):
     loss: float
 
 
-class GradientLLM(LLM):
+class GradientLLM(BaseLLM):
     """Gradient.ai LLM Endpoints.
 
     GradientLLM is a class to interact with LLMs on gradient.ai
@@ -29,11 +33,11 @@ class GradientLLM(LLM):
     Example:
         .. code-block:: python
 
-            from langchain.llms.gradientai_endpoint import GradientAIEndpoint
+            from langchain.llms import GradientLLM
             GradientLLM(
-                model_id="cad6644_base_ml_model",
+                model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
                 model_kwargs={
-                    "max_generated_token_count": 200,
+                    "max_generated_token_count": 128,
                     "temperature": 0.75,
                     "top_p": 0.95,
                     "top_k": 20,
@@ -45,7 +49,7 @@ class GradientLLM(LLM):
 
     """
 
-    model_id: str
+    model_id: str = Field(alias="model", min_length=2)
     "Underlying gradient.ai model id (base or fine-tuned)."
 
     gradient_workspace_id: Optional[str] = None
@@ -63,13 +67,14 @@ class GradientLLM(LLM):
     gradient_api_url: str = "https://api.gradient.ai/api"
     """Endpoint URL to use."""
 
-    aiosession: Optional[aiohttp.ClientSession] = None
-    """ClientSession, in case we want to reuse connection for better performance."""
+    aiosession: Optional[aiohttp.ClientSession] = None  #: :meta private:
+    """ClientSession, private, subject to change in upcoming releases."""
 
     # LLM call kwargs
     class Config:
         """Configuration for this pydantic object."""
 
+        allow_population_by_field_name = True
         extra = Extra.forbid
 
     @root_validator(allow_reuse=True)
@@ -113,6 +118,16 @@ class GradientLLM(LLM):
             values, "gradient_api_url", "GRADIENT_API_URL"
         )
 
+        try:
+            import gradientai  # noqa
+        except ImportError:
+            logging.warning(
+                "DeprecationWarning: `GradientLLM` will use "
+                "`pip install gradientai` in future releases of langchain."
+            )
+        except Exception:
+            pass
+
         return values
 
     @property
@@ -243,8 +258,8 @@ class GradientLLM(LLM):
     async def _acall(
         self,
         prompt: str,
-        stop: Union[List[str], None] = None,
-        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """Async Call to Gradients API `model/{id}/complete`.
@@ -284,6 +299,49 @@
 
         return text
 
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+
+        # same thing with threading
+        def _inner_generate(prompt: str) -> List[Generation]:
+            return [
+                Generation(
+                    text=self._call(
+                        prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                    )
+                )
+            ]
+
+        if len(prompts) <= 1:
+            generations = list(map(_inner_generate, prompts))
+        else:
+            with ThreadPoolExecutor(min(8, len(prompts))) as p:
+                generations = list(p.map(_inner_generate, prompts))
+
+        return LLMResult(generations=generations)
+
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        generations = []
+        for generation in asyncio.gather(
+            [self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)]
+            for prompt in prompts
+        ):
+            generations.append([Generation(text=generation)])
+        return LLMResult(generations=generations)
+
     def train_unsupervised(
         self,
         inputs: Sequence[str],
@@ -6,7 +6,7 @@ You can get it by registering for free at https://gradient.ai/.
 You'll then need to set:
 - `GRADIENT_ACCESS_TOKEN` environment variable to your api key.
 - `GRADIENT_WORKSPACE_ID` environment variable to your workspace id.
-- `GRADIENT_MODEL_ID` environment variable to your workspace id.
+- `GRADIENT_MODEL` environment variable to your model id.
 """
 import os
 
@@ -15,8 +15,14 @@ from langchain.llms import GradientLLM
 
 def test_gradient_acall() -> None:
     """Test simple call to gradient.ai."""
-    model_id = os.environ["GRADIENT_MODEL_ID"]
-    llm = GradientLLM(model_id=model_id)
+    model = os.environ["GRADIENT_MODEL"]
+    gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"]
+    gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"]
+    llm = GradientLLM(
+        model=model,
+        gradient_access_token=gradient_access_token,
+        gradient_workspace_id=gradient_workspace_id,
+    )
     output = llm("Say hello:", temperature=0.2, max_tokens=250)
 
     assert llm._llm_type == "gradient"
@@ -27,8 +33,14 @@ def test_gradient_acall() -> None:
 
 async def test_gradientai_acall() -> None:
     """Test async call to gradient.ai."""
-    model_id = os.environ["GRADIENT_MODEL_ID"]
-    llm = GradientLLM(model_id=model_id)
+    model = os.environ["GRADIENT_MODEL"]
+    gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"]
+    gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"]
+    llm = GradientLLM(
+        model=model,
+        gradient_access_token=gradient_access_token,
+        gradient_workspace_id=gradient_workspace_id,
+    )
     output = await llm.agenerate(["Say hello:"], temperature=0.2, max_tokens=250)
     assert llm._llm_type == "gradient"
 
@@ -1,5 +1,6 @@
 from typing import Dict
 
+import pytest
 from pytest_mock import MockerFixture
 
 from langchain.llms import GradientLLM
@@ -19,32 +20,45 @@ class MockResponse:
         return self.json_data
 
 
-def mocked_requests_post(
-    url: str,
-    headers: dict,
-    json: dict,
-) -> MockResponse:
+def mocked_requests_post(url: str, headers: dict, json: dict) -> MockResponse:
     assert url.startswith(_GRADIENT_BASE_URL)
-    assert headers
+    assert _MODEL_ID in url
     assert json
+    assert headers
+
+    assert headers.get("authorization") == f"Bearer {_GRADIENT_SECRET}"
+    assert headers.get("x-gradient-workspace-id") == f"{_GRADIENT_WORKSPACE_ID}"
+    query = json.get("query")
+    assert query and isinstance(query, str)
+    output = "bar" if "foo" in query else "baz"
+
     return MockResponse(
-        json_data={"generatedOutput": "bar"},
+        json_data={"generatedOutput": output},
         status_code=200,
     )
 
 
-def test_gradient_llm_sync(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
-    llm = GradientLLM(
+@pytest.mark.parametrize(
+    "setup",
+    [
+        dict(
+            gradient_api_url=_GRADIENT_BASE_URL,
+            gradient_access_token=_GRADIENT_SECRET,
+            gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
+            model=_MODEL_ID,
+        ),
+        dict(
             gradient_api_url=_GRADIENT_BASE_URL,
             gradient_access_token=_GRADIENT_SECRET,
             gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
             model_id=_MODEL_ID,
+        ),
+    ],
 )
+def test_gradient_llm_sync(mocker: MockerFixture, setup: dict) -> None:
+    mocker.patch("requests.post", side_effect=mocked_requests_post)
+
+    llm = GradientLLM(**setup)
     assert llm.gradient_access_token == _GRADIENT_SECRET
     assert llm.gradient_api_url == _GRADIENT_BASE_URL
     assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
@@ -54,3 +68,32 @@ def test_gradient_llm_sync(
     want = "bar"
 
     assert response == want
+
+
+@pytest.mark.parametrize(
+    "setup",
+    [
+        dict(
+            gradient_api_url=_GRADIENT_BASE_URL,
+            gradient_access_token=_GRADIENT_SECRET,
+            gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
+            model=_MODEL_ID,
+        )
+    ],
+)
+def test_gradient_llm_sync_batch(mocker: MockerFixture, setup: dict) -> None:
+    mocker.patch("requests.post", side_effect=mocked_requests_post)
+
+    llm = GradientLLM(**setup)
+    assert llm.gradient_access_token == _GRADIENT_SECRET
+    assert llm.gradient_api_url == _GRADIENT_BASE_URL
+    assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
+    assert llm.model_id == _MODEL_ID
+
+    inputs = ["Say foo:", "Say baz:", "Say foo again"]
+    response = llm._generate(inputs)
+
+    want = ["bar", "baz", "bar"]
+    assert len(response.generations) == len(inputs)
+    for i, gen in enumerate(response.generations):
+        assert gen[0].text == want[i]