mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 21:50:25 +00:00
Add TrainableLLM
(#11721)
- **Description:** Add `TrainableLLM` for those LLMs that support fine-tuning - **Tag maintainer:** @hwchase17 This PR adds training methods to `GradientLLM` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
63e516c2b0
commit
9e1e0f54d2
@ -50,12 +50,7 @@ from langchain.load.dump import dumpd
|
|||||||
from langchain.prompts.base import StringPromptValue
|
from langchain.prompts.base import StringPromptValue
|
||||||
from langchain.prompts.chat import ChatPromptValue
|
from langchain.prompts.chat import ChatPromptValue
|
||||||
from langchain.pydantic_v1 import Field, root_validator, validator
|
from langchain.pydantic_v1 import Field, root_validator, validator
|
||||||
from langchain.schema import (
|
from langchain.schema import Generation, LLMResult, PromptValue, RunInfo
|
||||||
Generation,
|
|
||||||
LLMResult,
|
|
||||||
PromptValue,
|
|
||||||
RunInfo,
|
|
||||||
)
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
||||||
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
|
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
|
||||||
from langchain.schema.output import GenerationChunk
|
from langchain.schema.output import GenerationChunk
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@ -13,6 +13,10 @@ from langchain.pydantic_v1 import Extra, root_validator
|
|||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
class TrainResult(TypedDict):
    """Result payload of a fine-tuning call."""

    # Mean loss per trainable token for the training batch.
    loss: float
class GradientLLM(LLM):
|
class GradientLLM(LLM):
|
||||||
"""Gradient.ai LLM Endpoints.
|
"""Gradient.ai LLM Endpoints.
|
||||||
|
|
||||||
@ -125,6 +129,51 @@ class GradientLLM(LLM):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "gradient"
|
return "gradient"
|
||||||
|
|
||||||
|
def _kwargs_post_fine_tune_request(
|
||||||
|
self, inputs: Sequence[str], kwargs: Mapping[str, Any]
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""Build the kwargs for the Post request, used by sync
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): prompt used in query
|
||||||
|
kwargs (dict): model kwargs in payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Union[str,dict]]: _description_
|
||||||
|
"""
|
||||||
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
_params = {**_model_kwargs, **kwargs}
|
||||||
|
|
||||||
|
multipliers = _params.get("multipliers", None)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
url=f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
|
||||||
|
headers={
|
||||||
|
"authorization": f"Bearer {self.gradient_access_token}",
|
||||||
|
"x-gradient-workspace-id": f"{self.gradient_workspace_id}",
|
||||||
|
"accept": "application/json",
|
||||||
|
"content-type": "application/json",
|
||||||
|
},
|
||||||
|
json=dict(
|
||||||
|
samples=tuple(
|
||||||
|
{
|
||||||
|
"inputs": input,
|
||||||
|
}
|
||||||
|
for input in inputs
|
||||||
|
)
|
||||||
|
if multipliers is None
|
||||||
|
else tuple(
|
||||||
|
{
|
||||||
|
"inputs": input,
|
||||||
|
"fineTuningParameters": {
|
||||||
|
"multiplier": multiplier,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for input, multiplier in zip(inputs, multipliers)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _kwargs_post_request(
|
def _kwargs_post_request(
|
||||||
self, prompt: str, kwargs: Mapping[str, Any]
|
self, prompt: str, kwargs: Mapping[str, Any]
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
@ -234,3 +283,60 @@ class GradientLLM(LLM):
|
|||||||
text = enforce_stop_tokens(text, stop)
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def train_unsupervised(
    self,
    inputs: Sequence[str],
    **kwargs: Any,
) -> TrainResult:
    """Fine-tune the model synchronously on unsupervised text samples.

    Args:
        inputs: Raw text samples to train on.
        **kwargs: Extra payload kwargs (e.g. ``multipliers``) forwarded to
            the fine-tune request builder.

    Returns:
        TrainResult with the mean loss per trainable token.

    Raises:
        Exception: If the endpoint is unreachable or returns a non-200
            status.
    """
    # Keep the try-body minimal: only the call that can raise a
    # RequestException lives inside it.
    try:
        response = requests.post(
            **self._kwargs_post_fine_tune_request(inputs, kwargs)
        )
    except requests.exceptions.RequestException as e:
        # Chain the cause so callers can inspect the underlying error.
        raise Exception(
            f"RequestException while calling Gradient Endpoint: {e}"
        ) from e

    if response.status_code != 200:
        raise Exception(
            f"Gradient returned an unexpected response with status "
            f"{response.status_code}: {response.text}"
        )

    response_json = response.json()
    loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
    return TrainResult(loss=loss)
|
||||||
|
async def atrain_unsupervised(
    self,
    inputs: Sequence[str],
    **kwargs: Any,
) -> TrainResult:
    """Fine-tune the model asynchronously on unsupervised text samples.

    Uses ``self.aiosession`` when set, otherwise a temporary
    ``aiohttp.ClientSession``.

    Args:
        inputs: Raw text samples to train on.
        **kwargs: Extra payload kwargs (e.g. ``multipliers``) forwarded to
            the fine-tune request builder.

    Returns:
        TrainResult with the mean loss per trainable token.

    Raises:
        Exception: If the endpoint returns a non-200 status.
    """
    if not self.aiosession:
        async with aiohttp.ClientSession() as session:
            loss = await self._apost_fine_tune(session, inputs, kwargs)
    else:
        loss = await self._apost_fine_tune(self.aiosession, inputs, kwargs)
    return TrainResult(loss=loss)

async def _apost_fine_tune(
    self,
    session: Any,
    inputs: Sequence[str],
    kwargs: Mapping[str, Any],
) -> float:
    """POST one fine-tune request on *session* and return the mean loss.

    Shared by both branches of :meth:`atrain_unsupervised` (temporary vs.
    caller-owned session) to avoid duplicating the request/parse logic.
    """
    async with session.post(
        **self._kwargs_post_fine_tune_request(inputs, kwargs)
    ) as response:
        if response.status != 200:
            # aiohttp's ``response.text`` is a coroutine method; the
            # original f-string interpolated the bound method repr
            # instead of the response body — await it to get the text.
            body = await response.text()
            raise Exception(
                f"Gradient returned an unexpected response with status "
                f"{response.status}: {body}"
            )
        response_json = await response.json()
        return (
            response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
        )
|
Loading…
Reference in New Issue
Block a user