Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-07 13:40:46 +00:00)
Add TrainableLLM (#11721)
- **Description:** Add `TrainableLLM` for LLMs that support fine-tuning.
- **Tag maintainer:** @hwchase17

This PR adds training methods to `GradientLLM`.

Co-authored-by: Bagatur <baskaryan@gmail.com>
Parent: 63e516c2b0
Commit: 9e1e0f54d2
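A minimal usage sketch of what the new API enables, assuming `GradientLLM` is importable from `langchain.llms`; every id and token below is an illustrative placeholder, not a value taken from this commit:

```python
# Hypothetical example: all ids and tokens below are placeholders.
from langchain.llms import GradientLLM

llm = GradientLLM(
    model_id="my-adapter-id",                # placeholder fine-tunable adapter
    gradient_access_token="my-access-token", # placeholder credential
    gradient_workspace_id="my-workspace-id", # placeholder workspace
)

# New in this PR: unsupervised fine-tuning on raw text samples.
result = llm.train_unsupervised(["Hello, world!", "LangChain wraps LLM providers."])
print(result["loss"])  # mean loss per trainable token
```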
The first hunk, in the LLM base module, collapses the multi-line `langchain.schema` import into its single-line form:

```diff
@@ -50,12 +50,7 @@ from langchain.load.dump import dumpd
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
 from langchain.pydantic_v1 import Field, root_validator, validator
-from langchain.schema import (
-    Generation,
-    LLMResult,
-    PromptValue,
-    RunInfo,
-)
+from langchain.schema import Generation, LLMResult, PromptValue, RunInfo
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 from langchain.schema.output import GenerationChunk
```
In `gradient_ai.py`, the typing imports gain `Sequence` and `TypedDict`:

```diff
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union
 
 import aiohttp
 import requests
```
A `TrainResult` TypedDict is introduced as the return type of the new training methods:

```diff
@@ -13,6 +13,10 @@ from langchain.pydantic_v1 import Extra, root_validator
 from langchain.utils import get_from_dict_or_env
 
 
+class TrainResult(TypedDict):
+    loss: float
+
+
 class GradientLLM(LLM):
     """Gradient.ai LLM Endpoints.
 
```
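Since `TrainResult` is a `TypedDict`, the training methods return a plain dict whose shape a static type checker can verify; a small standalone illustration:

```python
from typing import TypedDict


class TrainResult(TypedDict):
    loss: float


def report(result: TrainResult) -> str:
    # A TypedDict is an ordinary dict at runtime; the annotation only
    # guides static type checkers such as mypy.
    return f"mean loss per trainable token: {result['loss']:.4f}"


print(report(TrainResult(loss=0.1234)))
```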
`_kwargs_post_fine_tune_request` builds the POST request against the fine-tune endpoint; optional per-sample loss `multipliers` are passed through `kwargs` (or `model_kwargs`) and zipped with the inputs:

```diff
@@ -125,6 +129,51 @@ class GradientLLM(LLM):
         """Return type of llm."""
         return "gradient"
 
+    def _kwargs_post_fine_tune_request(
+        self, inputs: Sequence[str], kwargs: Mapping[str, Any]
+    ) -> Mapping[str, Any]:
+        """Build the kwargs for the POST request, used by the sync method.
+
+        Args:
+            inputs (Sequence[str]): text samples to fine-tune on
+            kwargs (Mapping[str, Any]): model kwargs in payload
+
+        Returns:
+            Mapping[str, Any]: kwargs for requests.post / session.post
+        """
+        _model_kwargs = self.model_kwargs or {}
+        _params = {**_model_kwargs, **kwargs}
+
+        multipliers = _params.get("multipliers", None)
+
+        return dict(
+            url=f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
+            headers={
+                "authorization": f"Bearer {self.gradient_access_token}",
+                "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
+                "accept": "application/json",
+                "content-type": "application/json",
+            },
+            json=dict(
+                samples=tuple(
+                    {
+                        "inputs": input,
+                    }
+                    for input in inputs
+                )
+                if multipliers is None
+                else tuple(
+                    {
+                        "inputs": input,
+                        "fineTuningParameters": {
+                            "multiplier": multiplier,
+                        },
+                    }
+                    for input, multiplier in zip(inputs, multipliers)
+                ),
+            ),
+        )
+
     def _kwargs_post_request(
         self, prompt: str, kwargs: Mapping[str, Any]
     ) -> Mapping[str, Any]:
```
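For concreteness, a sketch of the request kwargs this helper would produce for two samples with per-sample multipliers; the URL, ids, and token are illustrative placeholders (the default of `gradient_api_url` is not shown in this diff):

```python
# Illustrative shape only; every concrete value below is a placeholder.
expected_kwargs = {
    "url": "https://example-gradient-api/models/my-adapter-id/fine-tune",
    "headers": {
        "authorization": "Bearer my-access-token",
        "x-gradient-workspace-id": "my-workspace-id",
        "accept": "application/json",
        "content-type": "application/json",
    },
    "json": {
        "samples": (
            {"inputs": "sample one", "fineTuningParameters": {"multiplier": 1.0}},
            {"inputs": "sample two", "fineTuningParameters": {"multiplier": 0.5}},
        ),
    },
}
```

One design consequence worth noting: `zip(inputs, multipliers)` silently truncates to the shorter sequence, so a multipliers list shorter than `inputs` drops samples rather than raising.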
Finally, the new synchronous and asynchronous training entry points; the async variant reuses `self.aiosession` when one is configured, and otherwise opens a temporary session:

```diff
@@ -234,3 +283,60 @@ class GradientLLM(LLM):
             text = enforce_stop_tokens(text, stop)
 
         return text
+
+    def train_unsupervised(
+        self,
+        inputs: Sequence[str],
+        **kwargs: Any,
+    ) -> TrainResult:
+        try:
+            response = requests.post(
+                **self._kwargs_post_fine_tune_request(inputs, kwargs)
+            )
+            if response.status_code != 200:
+                raise Exception(
+                    f"Gradient returned an unexpected response with status "
+                    f"{response.status_code}: {response.text}"
+                )
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"RequestException while calling Gradient Endpoint: {e}")
+
+        response_json = response.json()
+        loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+        return TrainResult(loss=loss)
+
+    async def atrain_unsupervised(
+        self,
+        inputs: Sequence[str],
+        **kwargs: Any,
+    ) -> TrainResult:
+        if not self.aiosession:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    **self._kwargs_post_fine_tune_request(inputs, kwargs)
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Gradient returned an unexpected response with status "
+                            f"{response.status}: {await response.text()}"
+                        )
+                    response_json = await response.json()
+                    loss = (
+                        response_json["sumLoss"]
+                        / response_json["numberOfTrainableTokens"]
+                    )
+        else:
+            async with self.aiosession.post(
+                **self._kwargs_post_fine_tune_request(inputs, kwargs)
+            ) as response:
+                if response.status != 200:
+                    raise Exception(
+                        f"Gradient returned an unexpected response with status "
+                        f"{response.status}: {await response.text()}"
+                    )
+                response_json = await response.json()
+                loss = (
+                    response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+                )
+
+        return TrainResult(loss=loss)
```
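And a hedged sketch of driving the async path, assuming an `llm` instance configured as in the earlier sketch:

```python
import asyncio


async def main() -> None:
    # atrain_unsupervised opens a temporary aiohttp.ClientSession unless
    # llm.aiosession was provided at construction time.
    result = await llm.atrain_unsupervised(["a sample document to train on"])
    print(f"loss: {result['loss']:.4f}")


asyncio.run(main())
```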