From 9e1e0f54d2c57588e0f9cc4c2e3eb772c2a2b1cf Mon Sep 17 00:00:00 2001
From: "Yang, Bo"
Date: Thu, 12 Oct 2023 17:38:33 -0700
Subject: [PATCH] Add `TrainableLLM` (#11721)

- **Description:** Add `TrainableLLM` for those LLM support fine-tuning
- **Tag maintainer:** @hwchase17

This PR add training methods to `GradientLLM`

---------

Co-authored-by: Bagatur
---
 libs/langchain/langchain/llms/base.py        |   7 +-
 libs/langchain/langchain/llms/gradient_ai.py | 108 ++++++++++++++++++-
 2 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 9bbc69a3e9e..6cb830973ce 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -50,12 +50,7 @@ from langchain.load.dump import dumpd
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
 from langchain.pydantic_v1 import Field, root_validator, validator
-from langchain.schema import (
-    Generation,
-    LLMResult,
-    PromptValue,
-    RunInfo,
-)
+from langchain.schema import Generation, LLMResult, PromptValue, RunInfo
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 from langchain.schema.output import GenerationChunk
diff --git a/libs/langchain/langchain/llms/gradient_ai.py b/libs/langchain/langchain/llms/gradient_ai.py
index 4a2c4f12dcc..cc142a2e546 100644
--- a/libs/langchain/langchain/llms/gradient_ai.py
+++ b/libs/langchain/langchain/llms/gradient_ai.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union
 
 import aiohttp
 import requests
@@ -13,6 +13,10 @@ from langchain.pydantic_v1 import Extra, root_validator
 from langchain.utils import get_from_dict_or_env
 
 
+class TrainResult(TypedDict):
+    loss: float
+
+
 class GradientLLM(LLM):
     """Gradient.ai LLM Endpoints.
 
@@ -125,6 +129,51 @@ class GradientLLM(LLM):
         """Return type of llm."""
         return "gradient"
 
+    def _kwargs_post_fine_tune_request(
+        self, inputs: Sequence[str], kwargs: Mapping[str, Any]
+    ) -> Mapping[str, Any]:
+        """Build the kwargs for the fine-tune POST request (sync and async).
+
+        Args:
+            inputs (Sequence[str]): training samples to fine-tune on
+            kwargs (dict): model kwargs in payload
+
+        Returns:
+            Mapping[str, Any]: keyword arguments for the HTTP POST call
+        """
+        _model_kwargs = self.model_kwargs or {}
+        _params = {**_model_kwargs, **kwargs}
+
+        multipliers = _params.get("multipliers", None)
+
+        return dict(
+            url=f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
+            headers={
+                "authorization": f"Bearer {self.gradient_access_token}",
+                "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
+                "accept": "application/json",
+                "content-type": "application/json",
+            },
+            json=dict(
+                samples=tuple(
+                    {
+                        "inputs": sample,
+                    }
+                    for sample in inputs
+                )
+                if multipliers is None
+                else tuple(
+                    {
+                        "inputs": sample,
+                        "fineTuningParameters": {
+                            "multiplier": multiplier,
+                        },
+                    }
+                    for sample, multiplier in zip(inputs, multipliers)
+                ),
+            ),
+        )
+
     def _kwargs_post_request(
         self, prompt: str, kwargs: Mapping[str, Any]
     ) -> Mapping[str, Any]:
@@ -234,3 +283,60 @@ class GradientLLM(LLM):
             text = enforce_stop_tokens(text, stop)
 
         return text
+
+    def train_unsupervised(
+        self,
+        inputs: Sequence[str],
+        **kwargs: Any,
+    ) -> TrainResult:
+        try:
+            response = requests.post(
+                **self._kwargs_post_fine_tune_request(inputs, kwargs)
+            )
+            if response.status_code != 200:
+                raise Exception(
+                    f"Gradient returned an unexpected response with status "
+                    f"{response.status_code}: {response.text}"
+                )
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"RequestException while calling Gradient Endpoint: {e}")
+
+        response_json = response.json()
+        loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+        return TrainResult(loss=loss)
+
+    async def atrain_unsupervised(
+        self,
+        inputs: Sequence[str],
+        **kwargs: Any,
+    ) -> TrainResult:
+        if not self.aiosession:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    **self._kwargs_post_fine_tune_request(inputs, kwargs)
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Gradient returned an unexpected response with status "
+                            f"{response.status}: {await response.text()}"
+                        )
+                    response_json = await response.json()
+                    loss = (
+                        response_json["sumLoss"]
+                        / response_json["numberOfTrainableTokens"]
+                    )
+        else:
+            async with self.aiosession.post(
+                **self._kwargs_post_fine_tune_request(inputs, kwargs)
+            ) as response:
+                if response.status != 200:
+                    raise Exception(
+                        f"Gradient returned an unexpected response with status "
+                        f"{response.status}: {await response.text()}"
+                    )
+                response_json = await response.json()
+                loss = (
+                    response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+                )
+
+        return TrainResult(loss=loss)