GradientLLM Docs update and model_id renaming. (#10963)

Related to #10800

- Fixed errors in the docstring of GradientLLM / Gradient.ai LLM.
- Renamed `model_id` to `model` and adapted all tests accordingly, to
stay in sync with `GradientEmbeddings` and other LLMs (see the sketch below).
- Improved the tests so they check the headers of the sent request.
- Made `aiosession` a private attribute in the docs, since
`pip install gradientai` will replace it in future releases.
- Added an example of how to fine-tune on the prompt template, as
suggested in #10800.
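A minimal sketch of the rename (the model id is a placeholder; `model` is a pydantic alias
for the `model_id` field and population by the field name stays allowed, so existing code
keeps working):

```python
from langchain.llms import GradientLLM

# Assumes GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID are set in the environment;
# the model id below is a placeholder.
MODEL = "<your-model-id>"

# New spelling, in sync with GradientEmbeddings:
llm = GradientLLM(model=MODEL)

# Old spelling still constructs the same object, because the field allows
# population by its name:
legacy_llm = GradientLLM(model_id=MODEL)

assert llm.model_id == legacy_llm.model_id
```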
Michael Feil 2023-10-13 22:57:58 +02:00 committed by GitHub
parent 6876b02c87
commit 233a904f2e
6 changed files with 329 additions and 64 deletions

View File

@ -24,8 +24,6 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"import requests\n",
"from langchain.llms import GradientLLM\n", "from langchain.llms import GradientLLM\n",
"from langchain.prompts import PromptTemplate\n", "from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain" "from langchain.chains import LLMChain"
@ -46,7 +44,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from getpass import getpass\n", "from getpass import getpass\n",
"\n", "import os\n",
"\n", "\n",
"if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\",None):\n", "if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\",None):\n",
" # Access token under https://auth.gradient.ai/select-workspace\n", " # Access token under https://auth.gradient.ai/select-workspace\n",
@ -61,7 +59,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Optional: Validate your Environment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models." "Optional: Validate your Enviroment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models. Using the `gradientai` Python package."
] ]
}, },
{ {
@ -73,25 +71,64 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Credentials valid.\n", "Requirement already satisfied: gradientai in /home/michi/.venv/lib/python3.10/site-packages (1.0.0)\n",
"Possible values for `model_id` are:\n", "Requirement already satisfied: aenum>=3.1.11 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (3.1.15)\n",
" {'models': [{'id': '99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model', 'name': 'bloom-560m', 'slug': 'bloom-560m', 'type': 'baseModel'}, {'id': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'name': 'llama2-7b-chat', 'slug': 'llama2-7b-chat', 'type': 'baseModel'}, {'id': 'cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model', 'name': 'nous-hermes2', 'slug': 'nous-hermes2', 'type': 'baseModel'}, {'baseModelId': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'id': 'bb7b9865-0ce3-41a8-8e2b-5cbcbe1262eb_model_adapter', 'name': 'optical-transmitting-sensor', 'type': 'modelAdapter'}]}\n" "Requirement already satisfied: pydantic<2.0.0,>=1.10.5 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.10.12)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (2.8.2)\n",
"Requirement already satisfied: urllib3>=1.25.3 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.26.16)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /home/michi/.venv/lib/python3.10/site-packages (from pydantic<2.0.0,>=1.10.5->gradientai) (4.5.0)\n",
"Requirement already satisfied: six>=1.5 in /home/michi/.venv/lib/python3.10/site-packages (from python-dateutil>=2.8.2->gradientai) (1.16.0)\n"
] ]
} }
], ],
"source": [ "source": [
"import requests\n", "!pip install gradientai"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\n",
"f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model\n",
"cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model\n"
]
}
],
"source": [
"import gradientai\n",
"\n", "\n",
"resp = requests.get(f'https://api.gradient.ai/api/models', headers={\n", "client = gradientai.Gradient()\n",
" \"authorization\": f\"Bearer {os.environ['GRADIENT_ACCESS_TOKEN']}\",\n", "\n",
" \"x-gradient-workspace-id\": f\"{os.environ['GRADIENT_WORKSPACE_ID']}\",\n", "models = client.list_models(only_base=True)\n",
" },\n", "for model in models:\n",
" )\n", " print(model.id)"
"if resp.status_code == 200:\n", ]
" models = resp.json()\n", },
" print(\"Credentials valid.\\nPossible values for `model_id` are:\\n\", models)\n", {
"else:\n", "cell_type": "code",
" print(\"Error when listing models. Are your credentials valid?\", resp.text)" "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter', 'my_model_adapter')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_model = models[-1].create_model_adapter(name=\"my_model_adapter\")\n",
"new_model.id, new_model.name"
] ]
}, },
{ {
@ -99,21 +136,24 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Create the Gradient instance\n", "## Create the Gradient instance\n",
"You can specify different parameters such as the model name, max tokens generated, temperature, etc." "You can specify different parameters such as the model, max_tokens generated, temperature, etc.\n",
"\n",
"As we later want to fine-tune out model, we select the model_adapter with the id `674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter`, but you can use any base or fine-tunable model."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"llm = GradientLLM(\n", "llm = GradientLLM(\n",
" # `ID` listed in `$ gradient model list`\n", " # `ID` listed in `$ gradient model list`\n",
" model_id=\"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\",\n", " model=\"674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter\",\n",
" # # optional: set new credentials, they default to environment variables\n", " # # optional: set new credentials, they default to environment variables\n",
" # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n", " # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
" # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n", " # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
" model_kwargs=dict(max_generated_token_count=128)\n",
")" ")"
] ]
}, },
@ -127,13 +167,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"template = \"\"\"Question: {question}\n", "template = \"\"\"Question: {question}\n",
"\n", "\n",
"Answer: Let's think step by step.\"\"\"\n", "Answer: \"\"\"\n",
"\n", "\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])" "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
] ]
@ -147,7 +187,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -164,16 +204,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"' The first team to win the Super Bowl was the New England Patriots. The Patriots won the'" "'\\nThe San Francisco 49ers won the Super Bowl in 1994.'"
] ]
}, },
"execution_count": 7, "execution_count": 9,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -185,6 +225,88 @@
" question=question\n", " question=question\n",
")" ")"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Improve the results by fine-tuning (optional)\n",
"Well - that is wrong - the San Francisco 49ers did not win.\n",
"The correct answer to the question would be `The Dallas Cowboys!`.\n",
"\n",
"Let's increase the odds for the correct answer, by fine-tuning on the correct answer using the PromptTemplate."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'inputs': 'Question: What NFL team won the Super Bowl in 1994?\\n\\nAnswer: The Dallas Cowboys!'}]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset = [{\"inputs\": template.format(question=\"What NFL team won the Super Bowl in 1994?\") + \" The Dallas Cowboys!\"}]\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"FineTuneResponse(number_of_trainable_tokens=27, sum_loss=78.17996)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_model.fine_tune(\n",
" samples=dataset\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The Dallas Cowboys'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we can keep the llm_chain, as the registered model just got refreshed on the gradient.ai servers.\n",
"llm_chain.run(\n",
" question=question\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
} }
], ],
"metadata": { "metadata": {
@ -203,7 +325,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.13" "version": "3.10.6"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -0,0 +1,27 @@
# Gradient
>[Gradient](https://gradient.ai/) lets you fine-tune and get completions on LLMs with a simple web API.
## Installation and Setup
- Install the Python SDK:
```bash
pip install gradientai
```
Get a [Gradient access token and workspace](https://gradient.ai/) and set them as environment variables (`GRADIENT_ACCESS_TOKEN` and `GRADIENT_WORKSPACE_ID`).
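For example, a minimal sketch of setting these from Python before creating any Gradient objects (the values are placeholders):
```python
import os

# Placeholders: get your credentials at https://auth.gradient.ai/select-workspace
os.environ["GRADIENT_ACCESS_TOKEN"] = "<your-access-token>"
os.environ["GRADIENT_WORKSPACE_ID"] = "<your-workspace-id>"
```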
## LLM
There exists a Gradient LLM wrapper, which you can access with the import below.
See a [usage example](/docs/integrations/llms/gradient).
```python
from langchain.llms import GradientLLM
```
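A minimal construction sketch (the model id is a placeholder for any base or fine-tunable model id from your workspace; credentials default to the environment variables above):
```python
from langchain.llms import GradientLLM

llm = GradientLLM(
    # placeholder: any `ID` listed by `$ gradient model list`
    model="<your-model-id>",
    model_kwargs={"max_generated_token_count": 128},
)
print(llm("Say hello:"))
```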
## Text Embedding Model
There exists a Gradient Embeddings model, which you can access with
```python
from langchain.embeddings import GradientEmbeddings
```
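A minimal usage sketch (the model name is illustrative; credentials are read from the environment variables above):
```python
from langchain.embeddings import GradientEmbeddings

# illustrative model name; pick an embeddings model available in your workspace
embeddings = GradientEmbeddings(model="bge-large")

doc_vectors = embeddings.embed_documents(["Gradient makes fine-tuning LLMs simple."])
query_vector = embeddings.embed_query("What does Gradient do?")
```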
For a more detailed walkthrough of this, see [this notebook](/docs/integrations/text_embedding/gradient.html)

View File

@ -12,6 +12,8 @@ from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema.embeddings import Embeddings from langchain.schema.embeddings import Embeddings
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
__all__ = ["GradientEmbeddings"]
class GradientEmbeddings(BaseModel, Embeddings): class GradientEmbeddings(BaseModel, Embeddings):
"""Gradient.ai Embedding models. """Gradient.ai Embedding models.
@ -48,7 +50,7 @@ class GradientEmbeddings(BaseModel, Embeddings):
gradient_api_url: str = "https://api.gradient.ai/api" gradient_api_url: str = "https://api.gradient.ai/api"
"""Endpoint URL to use.""" """Endpoint URL to use."""
client: Any #: :meta private: client: Any = None #: :meta private:
"""Gradient client.""" """Gradient client."""
# LLM call kwargs # LLM call kwargs
@ -143,8 +145,9 @@ class GradientEmbeddings(BaseModel, Embeddings):
return embeddings[0] return embeddings[0]
class TinyAsyncGradientEmbeddingClient: class TinyAsyncGradientEmbeddingClient: #: :meta private:
"""A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API. """A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API,
direct use discouraged.
To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace, API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,

View File

@ -1,4 +1,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
import aiohttp import aiohttp
import requests import requests
@ -7,9 +10,10 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms.base import LLM from langchain.llms.base import BaseLLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -17,7 +21,7 @@ class TrainResult(TypedDict):
loss: float loss: float
class GradientLLM(LLM): class GradientLLM(BaseLLM):
"""Gradient.ai LLM Endpoints. """Gradient.ai LLM Endpoints.
GradientLLM is a class to interact with LLMs on gradient.ai GradientLLM is a class to interact with LLMs on gradient.ai
@ -29,11 +33,11 @@ class GradientLLM(LLM):
Example: Example:
.. code-block:: python .. code-block:: python
from langchain.llms.gradientai_endpoint import GradientAIEndpoint from langchain.llms import GradientLLM
GradientLLM( GradientLLM(
model_id="cad6644_base_ml_model", model="99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model",
model_kwargs={ model_kwargs={
"max_generated_token_count": 200, "max_generated_token_count": 128,
"temperature": 0.75, "temperature": 0.75,
"top_p": 0.95, "top_p": 0.95,
"top_k": 20, "top_k": 20,
@ -45,7 +49,7 @@ class GradientLLM(LLM):
""" """
model_id: str model_id: str = Field(alias="model", min_length=2)
"Underlying gradient.ai model id (base or fine-tuned)." "Underlying gradient.ai model id (base or fine-tuned)."
gradient_workspace_id: Optional[str] = None gradient_workspace_id: Optional[str] = None
@ -63,13 +67,14 @@ class GradientLLM(LLM):
gradient_api_url: str = "https://api.gradient.ai/api" gradient_api_url: str = "https://api.gradient.ai/api"
"""Endpoint URL to use.""" """Endpoint URL to use."""
aiosession: Optional[aiohttp.ClientSession] = None aiosession: Optional[aiohttp.ClientSession] = None #: :meta private:
"""ClientSession, in case we want to reuse connection for better performance.""" """ClientSession, private, subject to change in upcoming releases."""
# LLM call kwargs # LLM call kwargs
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
allow_population_by_field_name = True
extra = Extra.forbid extra = Extra.forbid
@root_validator(allow_reuse=True) @root_validator(allow_reuse=True)
@ -113,6 +118,16 @@ class GradientLLM(LLM):
values, "gradient_api_url", "GRADIENT_API_URL" values, "gradient_api_url", "GRADIENT_API_URL"
) )
try:
import gradientai # noqa
except ImportError:
logging.warning(
"DeprecationWarning: `GradientLLM` will use "
"`pip install gradientai` in future releases of langchain."
)
except Exception:
pass
return values return values
@property @property
@ -243,8 +258,8 @@ class GradientLLM(LLM):
async def _acall( async def _acall(
self, self,
prompt: str, prompt: str,
stop: Union[List[str], None] = None, stop: Optional[List[str]] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Async Call to Gradients API `model/{id}/complete`. """Async Call to Gradients API `model/{id}/complete`.
@ -284,6 +299,49 @@ class GradientLLM(LLM):
return text return text
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# same thing with threading
def _inner_generate(prompt: str) -> List[Generation]:
return [
Generation(
text=self._call(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
)
]
if len(prompts) <= 1:
generations = list(map(_inner_generate, prompts))
else:
with ThreadPoolExecutor(min(8, len(prompts))) as p:
generations = list(p.map(_inner_generate, prompts))
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
generations = []
for generation in asyncio.gather(
[self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)]
for prompt in prompts
):
generations.append([Generation(text=generation)])
return LLMResult(generations=generations)
def train_unsupervised( def train_unsupervised(
self, self,
inputs: Sequence[str], inputs: Sequence[str],

View File

@ -6,7 +6,7 @@ You can get it by registering for free at https://gradient.ai/.
You'll then need to set: You'll then need to set:
- `GRADIENT_ACCESS_TOKEN` environment variable to your api key. - `GRADIENT_ACCESS_TOKEN` environment variable to your api key.
- `GRADIENT_WORKSPACE_ID` environment variable to your workspace id. - `GRADIENT_WORKSPACE_ID` environment variable to your workspace id.
- `GRADIENT_MODEL_ID` environment variable to your workspace id. - `GRADIENT_MODEL` environment variable to your model id.
""" """
import os import os
@ -15,8 +15,14 @@ from langchain.llms import GradientLLM
def test_gradient_acall() -> None: def test_gradient_acall() -> None:
"""Test simple call to gradient.ai.""" """Test simple call to gradient.ai."""
model_id = os.environ["GRADIENT_MODEL_ID"] model = os.environ["GRADIENT_MODEL"]
llm = GradientLLM(model_id=model_id) gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"]
gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"]
llm = GradientLLM(
model=model,
gradient_access_token=gradient_access_token,
gradient_workspace_id=gradient_workspace_id,
)
output = llm("Say hello:", temperature=0.2, max_tokens=250) output = llm("Say hello:", temperature=0.2, max_tokens=250)
assert llm._llm_type == "gradient" assert llm._llm_type == "gradient"
@ -27,8 +33,14 @@ def test_gradient_acall() -> None:
async def test_gradientai_acall() -> None: async def test_gradientai_acall() -> None:
"""Test async call to gradient.ai.""" """Test async call to gradient.ai."""
model_id = os.environ["GRADIENT_MODEL_ID"] model = os.environ["GRADIENT_MODEL"]
llm = GradientLLM(model_id=model_id) gradient_access_token = os.environ["GRADIENT_ACCESS_TOKEN"]
gradient_workspace_id = os.environ["GRADIENT_WORKSPACE_ID"]
llm = GradientLLM(
model=model,
gradient_access_token=gradient_access_token,
gradient_workspace_id=gradient_workspace_id,
)
output = await llm.agenerate(["Say hello:"], temperature=0.2, max_tokens=250) output = await llm.agenerate(["Say hello:"], temperature=0.2, max_tokens=250)
assert llm._llm_type == "gradient" assert llm._llm_type == "gradient"

View File

@ -1,5 +1,6 @@
from typing import Dict from typing import Dict
import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from langchain.llms import GradientLLM from langchain.llms import GradientLLM
@ -19,32 +20,45 @@ class MockResponse:
return self.json_data return self.json_data
def mocked_requests_post( def mocked_requests_post(url: str, headers: dict, json: dict) -> MockResponse:
url: str,
headers: dict,
json: dict,
) -> MockResponse:
assert url.startswith(_GRADIENT_BASE_URL) assert url.startswith(_GRADIENT_BASE_URL)
assert headers assert _MODEL_ID in url
assert json assert json
assert headers
assert headers.get("authorization") == f"Bearer {_GRADIENT_SECRET}"
assert headers.get("x-gradient-workspace-id") == f"{_GRADIENT_WORKSPACE_ID}"
query = json.get("query")
assert query and isinstance(query, str)
output = "bar" if "foo" in query else "baz"
return MockResponse( return MockResponse(
json_data={"generatedOutput": "bar"}, json_data={"generatedOutput": output},
status_code=200, status_code=200,
) )
def test_gradient_llm_sync( @pytest.mark.parametrize(
mocker: MockerFixture, "setup",
) -> None: [
mocker.patch("requests.post", side_effect=mocked_requests_post) dict(
gradient_api_url=_GRADIENT_BASE_URL,
llm = GradientLLM( gradient_access_token=_GRADIENT_SECRET,
gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
model=_MODEL_ID,
),
dict(
gradient_api_url=_GRADIENT_BASE_URL, gradient_api_url=_GRADIENT_BASE_URL,
gradient_access_token=_GRADIENT_SECRET, gradient_access_token=_GRADIENT_SECRET,
gradient_workspace_id=_GRADIENT_WORKSPACE_ID, gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
model_id=_MODEL_ID, model_id=_MODEL_ID,
) ),
],
)
def test_gradient_llm_sync(mocker: MockerFixture, setup: dict) -> None:
mocker.patch("requests.post", side_effect=mocked_requests_post)
llm = GradientLLM(**setup)
assert llm.gradient_access_token == _GRADIENT_SECRET assert llm.gradient_access_token == _GRADIENT_SECRET
assert llm.gradient_api_url == _GRADIENT_BASE_URL assert llm.gradient_api_url == _GRADIENT_BASE_URL
assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
@ -54,3 +68,32 @@ def test_gradient_llm_sync(
want = "bar" want = "bar"
assert response == want assert response == want
@pytest.mark.parametrize(
"setup",
[
dict(
gradient_api_url=_GRADIENT_BASE_URL,
gradient_access_token=_GRADIENT_SECRET,
gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
model=_MODEL_ID,
)
],
)
def test_gradient_llm_sync_batch(mocker: MockerFixture, setup: dict) -> None:
mocker.patch("requests.post", side_effect=mocked_requests_post)
llm = GradientLLM(**setup)
assert llm.gradient_access_token == _GRADIENT_SECRET
assert llm.gradient_api_url == _GRADIENT_BASE_URL
assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
assert llm.model_id == _MODEL_ID
inputs = ["Say foo:", "Say baz:", "Say foo again"]
response = llm._generate(inputs)
want = ["bar", "baz", "bar"]
assert len(response.generations) == len(inputs)
for i, gen in enumerate(response.generations):
assert gen[0].text == want[i]