mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
GradientLLM Docs update and model_id renaming. (#10963)
Related to #10800 - Errors in the Docstring of GradientLLM / Gradient.ai LLM - Renamed the `model_id` to `model` and adapting this in all tests. Reason to so is to be in Sync with `GradientEmbeddings` and other LLM's. - inmproving tests so they check the headers in the sent request. - making the aiosession a private attribute in the docs, as in the future `pip install gradientai` will be replacing aiosession. - adding a example how to fine-tune on the Prompt Template as suggested in #10800
This commit is contained in:
parent
6876b02c87
commit
233a904f2e
@ -24,8 +24,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import requests\n",
|
||||
"from langchain.llms import GradientLLM\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
@ -46,7 +44,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\",None):\n",
|
||||
" # Access token under https://auth.gradient.ai/select-workspace\n",
|
||||
@ -61,7 +59,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Optional: Validate your Environment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models."
|
||||
"Optional: Validate your Enviroment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models. Using the `gradientai` Python package."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -73,25 +71,64 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Credentials valid.\n",
|
||||
"Possible values for `model_id` are:\n",
|
||||
" {'models': [{'id': '99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model', 'name': 'bloom-560m', 'slug': 'bloom-560m', 'type': 'baseModel'}, {'id': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'name': 'llama2-7b-chat', 'slug': 'llama2-7b-chat', 'type': 'baseModel'}, {'id': 'cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model', 'name': 'nous-hermes2', 'slug': 'nous-hermes2', 'type': 'baseModel'}, {'baseModelId': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'id': 'bb7b9865-0ce3-41a8-8e2b-5cbcbe1262eb_model_adapter', 'name': 'optical-transmitting-sensor', 'type': 'modelAdapter'}]}\n"
|
||||
"Requirement already satisfied: gradientai in /home/michi/.venv/lib/python3.10/site-packages (1.0.0)\n",
|
||||
"Requirement already satisfied: aenum>=3.1.11 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (3.1.15)\n",
|
||||
"Requirement already satisfied: pydantic<2.0.0,>=1.10.5 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.10.12)\n",
|
||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (2.8.2)\n",
|
||||
"Requirement already satisfied: urllib3>=1.25.3 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.26.16)\n",
|
||||
"Requirement already satisfied: typing-extensions>=4.2.0 in /home/michi/.venv/lib/python3.10/site-packages (from pydantic<2.0.0,>=1.10.5->gradientai) (4.5.0)\n",
|
||||
"Requirement already satisfied: six>=1.5 in /home/michi/.venv/lib/python3.10/site-packages (from python-dateutil>=2.8.2->gradientai) (1.16.0)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import requests\n",
|
||||
"!pip install gradientai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\n",
|
||||
"f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model\n",
|
||||
"cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import gradientai\n",
|
||||
"\n",
|
||||
"resp = requests.get(f'https://api.gradient.ai/api/models', headers={\n",
|
||||
" \"authorization\": f\"Bearer {os.environ['GRADIENT_ACCESS_TOKEN']}\",\n",
|
||||
" \"x-gradient-workspace-id\": f\"{os.environ['GRADIENT_WORKSPACE_ID']}\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
"if resp.status_code == 200:\n",
|
||||
" models = resp.json()\n",
|
||||
" print(\"Credentials valid.\\nPossible values for `model_id` are:\\n\", models)\n",
|
||||
"else:\n",
|
||||
" print(\"Error when listing models. Are your credentials valid?\", resp.text)"
|
||||
"client = gradientai.Gradient()\n",
|
||||
"\n",
|
||||
"models = client.list_models(only_base=True)\n",
|
||||
"for model in models:\n",
|
||||
" print(model.id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"('674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter', 'my_model_adapter')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_model = models[-1].create_model_adapter(name=\"my_model_adapter\")\n",
|
||||
"new_model.id, new_model.name"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -99,21 +136,24 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create the Gradient instance\n",
|
||||
"You can specify different parameters such as the model name, max tokens generated, temperature, etc."
|
||||
"You can specify different parameters such as the model, max_tokens generated, temperature, etc.\n",
|
||||
"\n",
|
||||
"As we later want to fine-tune out model, we select the model_adapter with the id `674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter`, but you can use any base or fine-tunable model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = GradientLLM(\n",
|
||||
" # `ID` listed in `$ gradient model list`\n",
|
||||
" model_id=\"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\",\n",
|
||||
" model=\"674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter\",\n",
|
||||
" # # optional: set new credentials, they default to environment variables\n",
|
||||
" # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
|
||||
" # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
|
||||
" model_kwargs=dict(max_generated_token_count=128)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -127,13 +167,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"Answer: \"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
@ -147,7 +187,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -164,16 +204,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' The first team to win the Super Bowl was the New England Patriots. The Patriots won the'"
|
||||
"'\\nThe San Francisco 49ers won the Super Bowl in 1994.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -185,6 +225,88 @@
|
||||
" question=question\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Improve the results by fine-tuning (optional)\n",
|
||||
"Well - that is wrong - the San Francisco 49ers did not win.\n",
|
||||
"The correct answer to the question would be `The Dallas Cowboys!`.\n",
|
||||
"\n",
|
||||
"Let's increase the odds for the correct answer, by fine-tuning on the correct answer using the PromptTemplate."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'inputs': 'Question: What NFL team won the Super Bowl in 1994?\\n\\nAnswer: The Dallas Cowboys!'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataset = [{\"inputs\": template.format(question=\"What NFL team won the Super Bowl in 1994?\") + \" The Dallas Cowboys!\"}]\n",
|
||||
"dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"FineTuneResponse(number_of_trainable_tokens=27, sum_loss=78.17996)"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"new_model.fine_tune(\n",
|
||||
" samples=dataset\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The Dallas Cowboys'"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# we can keep the llm_chain, as the registered model just got refreshed on the gradient.ai servers.\n",
|
||||
"llm_chain.run(\n",
|
||||
" question=question\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -203,7 +325,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
27
docs/docs/integrations/providers/gradient.mdx
Normal file
27
docs/docs/integrations/providers/gradient.mdx
Normal file
@ -0,0 +1,27 @@
|
||||
# Gradient
|
||||
|
||||
>[Gradient](https://gradient.ai/) allows to fine tune and get completions on LLMs with a simple web API.
|
||||
|
||||
## Installation and Setup
|
||||
- Install the Python SDK :
|
||||
```bash
|
||||
pip install gradientai
|
||||
```
|
||||
Get a [Gradient access token and workspace](https://gradient.ai/) and set it as an environment variable (`Gradient_ACCESS_TOKEN`) and (`GRADIENT_WORKSPACE_ID`)
|
||||
|
||||
## LLM
|
||||
|
||||
There exists an Gradient LLM wrapper, which you can access with
|
||||
See a [usage example](/docs/integrations/llms/gradient).
|
||||
|
||||
```python
|
||||
from langchain.llms import GradientLLM
|
||||
```
|
||||
|
||||
## Text Embedding Model
|
||||
|
||||
There exists an Gradient Embedding model, which you can access with
|
||||
```python
|
||||
from langchain.embeddings import GradientEmbeddings
|
||||
```
|
||||
For a more detailed walkthrough of this, see [this notebook](/docs/integrations/text_embedding/gradient.html)
|
@ -12,6 +12,8 @@ from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
__all__ = ["GradientEmbeddings"]
|
||||
|
||||
|
||||
class GradientEmbeddings(BaseModel, Embeddings):
|
||||
"""Gradient.ai Embedding models.
|
||||
@ -48,7 +50,7 @@ class GradientEmbeddings(BaseModel, Embeddings):
|
||||
gradient_api_url: str = "https://api.gradient.ai/api"
|
||||
"""Endpoint URL to use."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
client: Any = None #: :meta private:
|
||||
"""Gradient client."""
|
||||
|
||||
# LLM call kwargs
|
||||
@ -143,8 +145,9 @@ class GradientEmbeddings(BaseModel, Embeddings):
|
||||
return embeddings[0]
|
||||
|
||||
|
||||
class TinyAsyncGradientEmbeddingClient:
|
||||
"""A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API.
|
||||
class TinyAsyncGradientEmbeddingClient: #: :meta private:
|
||||
"""A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API,
|
||||
direct use discouraged.
|
||||
|
||||
To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
|
||||
API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
|
||||
|
@ -1,4 +1,7 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@ -7,9 +10,10 @@ from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
@ -17,7 +21,7 @@ class TrainResult(TypedDict):
|
||||
loss: float
|
||||
|
||||
|
||||
class GradientLLM(LLM):
|
||||
class GradientLLM(BaseLLM):
|
||||
"""Gradient.ai LLM Endpoints.
|
||||
|
||||
GradientLLM is a class to interact with LLMs on gradient.ai
|
||||
@ -29,11 +33,11 @@ class GradientLLM(LLM):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.gradientai_endpoint import GradientAIEndpoint
|
||||
from langchain.llms import GradientLLM
|
||||
GradientLLM(
|
||||
model_id="cad6644_base_ml_model",
|
||||
model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
|
||||
model_kwargs={
|
||||
"max_generated_token_count": 200,
|
||||
"max_generated_token_count": 128,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"top_k": 20,
|
||||
@ -45,7 +49,7 @@ class GradientLLM(LLM):
|
||||
|
||||
"""
|
||||
|
||||
model_id: str
|
||||
model_id: str = Field(alias="model", min_length=2)
|
||||
"Underlying gradient.ai model id (base or fine-tuned)."
|
||||
|
||||
gradient_workspace_id: Optional[str] = None
|
||||
@ -63,13 +67,14 @@ class GradientLLM(LLM):
|
||||
gradient_api_url: str = "https://api.gradient.ai/api"
|
||||
"""Endpoint URL to use."""
|
||||
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
"""ClientSession, in case we want to reuse connection for better performance."""
|
||||
aiosession: Optional[aiohttp.ClientSession] = None #: :meta private:
|
||||
"""ClientSession, private, subject to change in upcoming releases."""
|
||||
|
||||
# LLM call kwargs
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
@ -113,6 +118,16 @@ class GradientLLM(LLM):
|
||||
values, "gradient_api_url", "GRADIENT_API_URL"
|
||||
)
|
||||
|
||||
try:
|
||||
import gradientai # noqa
|
||||
except ImportError:
|
||||
logging.warning(
|
||||
"DeprecationWarning: `GradientLLM` will use "
|
||||
"`pip install gradientai` in future releases of langchain."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
@ -243,8 +258,8 @@ class GradientLLM(LLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Union[List[str], None] = None,
|
||||
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Async Call to Gradients API `model/{id}/complete`.
|
||||
@ -284,6 +299,49 @@ class GradientLLM(LLM):
|
||||
|
||||
return text
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
|
||||
# same thing with threading
|
||||
def _inner_generate(prompt: str) -> List[Generation]:
|
||||
return [
|
||||
Generation(
|
||||
text=self._call(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
if len(prompts) <= 1:
|
||||
generations = list(map(_inner_generate, prompts))
|
||||
else:
|
||||
with ThreadPoolExecutor(min(8, len(prompts))) as p:
|
||||
generations = list(p.map(_inner_generate, prompts))
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
generations = []
|
||||
for generation in asyncio.gather(
|
||||
[self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)]
|
||||
for prompt in prompts
|
||||
):
|
||||
generations.append([Generation(text=generation)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def train_unsupervised(
|
||||
self,
|
||||
inputs: Sequence[str],
|
||||
|
@ -6,7 +6,7 @@ You can get it by registering for free at https://gradient.ai/.
|
||||
You'll then need to set:
|
||||
- `GRADIENT_ACCESS_TOKEN` environment variable to your api key.
|
||||
- `GRADIENT_WORKSPACE_ID` environment variable to your workspace id.
|
||||
- `GRADIENT_MODEL_ID` environment variable to your workspace id.
|
||||
- `GRADIENT_MODEL` environment variable to your workspace id.
|
||||
"""
|
||||
import os
|
||||
|
||||
@ -15,8 +15,14 @@ from langchain.llms import GradientLLM
|
||||
|
||||
def test_gradient_acall() -> None:
|
||||
"""Test simple call to gradient.ai."""
|
||||
model_id = os.environ["GRADIENT_MODEL_ID"]
|
||||
llm = GradientLLM(model_id=model_id)
|
||||
model = os.environ["GRADIENT_MODEL"]
|
||||
gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"]
|
||||
gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"]
|
||||
llm = GradientLLM(
|
||||
model=model,
|
||||
gradient_access_token=gradient_access_token,
|
||||
gradient_workspace_id=gradient_workspace_id,
|
||||
)
|
||||
output = llm("Say hello:", temperature=0.2, max_tokens=250)
|
||||
|
||||
assert llm._llm_type == "gradient"
|
||||
@ -27,8 +33,14 @@ def test_gradient_acall() -> None:
|
||||
|
||||
async def test_gradientai_acall() -> None:
|
||||
"""Test async call to gradient.ai."""
|
||||
model_id = os.environ["GRADIENT_MODEL_ID"]
|
||||
llm = GradientLLM(model_id=model_id)
|
||||
model = os.environ["GRADIENT_MODEL"]
|
||||
gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"]
|
||||
gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"]
|
||||
llm = GradientLLM(
|
||||
model=model,
|
||||
gradient_access_token=gradient_access_token,
|
||||
gradient_workspace_id=gradient_workspace_id,
|
||||
)
|
||||
output = await llm.agenerate(["Say hello:"], temperature=0.2, max_tokens=250)
|
||||
assert llm._llm_type == "gradient"
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.llms import GradientLLM
|
||||
@ -19,32 +20,45 @@ class MockResponse:
|
||||
return self.json_data
|
||||
|
||||
|
||||
def mocked_requests_post(
|
||||
url: str,
|
||||
headers: dict,
|
||||
json: dict,
|
||||
) -> MockResponse:
|
||||
def mocked_requests_post(url: str, headers: dict, json: dict) -> MockResponse:
|
||||
assert url.startswith(_GRADIENT_BASE_URL)
|
||||
assert headers
|
||||
assert _MODEL_ID in url
|
||||
assert json
|
||||
assert headers
|
||||
|
||||
assert headers.get("authorization") == f"Bearer {_GRADIENT_SECRET}"
|
||||
assert headers.get("x-gradient-workspace-id") == f"{_GRADIENT_WORKSPACE_ID}"
|
||||
query = json.get("query")
|
||||
assert query and isinstance(query, str)
|
||||
output = "bar" if "foo" in query else "baz"
|
||||
|
||||
return MockResponse(
|
||||
json_data={"generatedOutput": "bar"},
|
||||
json_data={"generatedOutput": output},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
def test_gradient_llm_sync(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("requests.post", side_effect=mocked_requests_post)
|
||||
|
||||
llm = GradientLLM(
|
||||
@pytest.mark.parametrize(
|
||||
"setup",
|
||||
[
|
||||
dict(
|
||||
gradient_api_url=_GRADIENT_BASE_URL,
|
||||
gradient_access_token=_GRADIENT_SECRET,
|
||||
gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
|
||||
model=_MODEL_ID,
|
||||
),
|
||||
dict(
|
||||
gradient_api_url=_GRADIENT_BASE_URL,
|
||||
gradient_access_token=_GRADIENT_SECRET,
|
||||
gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
|
||||
model_id=_MODEL_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_gradient_llm_sync(mocker: MockerFixture, setup: dict) -> None:
|
||||
mocker.patch("requests.post", side_effect=mocked_requests_post)
|
||||
|
||||
llm = GradientLLM(**setup)
|
||||
assert llm.gradient_access_token == _GRADIENT_SECRET
|
||||
assert llm.gradient_api_url == _GRADIENT_BASE_URL
|
||||
assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
|
||||
@ -54,3 +68,32 @@ def test_gradient_llm_sync(
|
||||
want = "bar"
|
||||
|
||||
assert response == want
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setup",
|
||||
[
|
||||
dict(
|
||||
gradient_api_url=_GRADIENT_BASE_URL,
|
||||
gradient_access_token=_GRADIENT_SECRET,
|
||||
gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
|
||||
model=_MODEL_ID,
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_gradient_llm_sync_batch(mocker: MockerFixture, setup: dict) -> None:
|
||||
mocker.patch("requests.post", side_effect=mocked_requests_post)
|
||||
|
||||
llm = GradientLLM(**setup)
|
||||
assert llm.gradient_access_token == _GRADIENT_SECRET
|
||||
assert llm.gradient_api_url == _GRADIENT_BASE_URL
|
||||
assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
|
||||
assert llm.model_id == _MODEL_ID
|
||||
|
||||
inputs = ["Say foo:", "Say baz:", "Say foo again"]
|
||||
response = llm._generate(inputs)
|
||||
|
||||
want = ["bar", "baz", "bar"]
|
||||
assert len(response.generations) == len(inputs)
|
||||
for i, gen in enumerate(response.generations):
|
||||
assert gen[0].text == want[i]
|
||||
|
Loading…
Reference in New Issue
Block a user