Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 01:48:07 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
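The hunks below are mechanical reformatting produced by the updated pre-commit hooks (black-style quoting and call wrapping); no runtime behavior changes. As a minimal sketch of the kind of rewrite applied throughout (hypothetical function, for illustration only):

# before: yapf-style wrapping and single quotes
def make_config(name,
                enabled=True):
    return {'name': name, 'enabled': enabled}

# after: black-style signature and double quotes
def make_config(name, enabled=True):
    return {"name": name, "enabled": enabled}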
@@ -5,7 +5,7 @@ from coati.experience_maker import Experience

class TrainerCallback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""

def on_fit_start(self) -> None:

@@ -40,7 +40,6 @@ class TrainerCallback(ABC):

class MakerCallback(ABC):

def on_loop_start(self) -> None:
pass

@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:

class Timer:

def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0

def start(self) -> None:
self.start_time = time()

@@ -42,13 +41,13 @@ class Timer:
self.duration += time() - self.start_time

def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
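For orientation, a condensed standalone sketch of how this Timer accumulates durations across repeated sections. The end() method is an assumption inferred from the unchanged `self.duration += time() - self.start_time` line above; it is not shown in this hunk.

from time import time
from typing import Optional

class Timer:
    def __init__(self) -> None:
        self.start_time: Optional[float] = None
        self.duration: float = 0.0

    def start(self) -> None:
        self.start_time = time()

    def end(self) -> None:  # assumed counterpart to start()
        assert self.start_time is not None
        self.duration += time() - self.start_time
        self.start_time = None

    def reset(self) -> None:
        self.duration = 0.0

timer = Timer()
for _ in range(3):
    timer.start()
    sum(range(100000))  # timed work
    timer.end()
print(f"total: {timer.duration:.3f}s")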
class ExperienceMakerPerformanceEvaluator(MakerCallback):

def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
reward_model_num_params: int) -> None:
def __init__(
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params

@@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
self.make_experience_flop: int = 0

print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
)

def on_make_experience_start(self) -> None:

@@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
(self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
self.total_samples * self.world_size
)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)

print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Making Experience Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)

class TrainerPerformanceEvaluator(TrainerCallback):

def __init__(self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params

@@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
self.learn_flop: int = 0

print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
)

def on_episode_start(self, episodes: int) -> None:

@@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):

def on_fit_end(self) -> None:
if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation')
print_rank_0("No samples are collected, skip trainer performance evaluation")
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)

@@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)

print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Learning Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
@@ -1,20 +1,15 @@
import asyncio
import copy
import random
from threading import Lock
from typing import Any, List
from typing import List

import ray
import torch
from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience

# from torch.multiprocessing import Queue
from ray.util.queue import Queue


class DetachedReplayBuffer:
'''
"""
Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.

@@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
'''
"""

def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size

@@ -34,23 +29,23 @@ class DetachedReplayBuffer:

@torch.no_grad()
def append(self, experience: Experience) -> None:
'''
"""
Expected to be called remotely.
'''
"""
items = split_experience_batch(experience)
self.extend(items)

@torch.no_grad()
def extend(self, items: List[BufferItem]) -> None:
'''
"""
Expected to be called remotely.
'''
"""
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
items = self.batch_collector[:self.sample_batch_size]
items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items)
self.items.put(experience, block=True)
self.batch_collector = self.batch_collector[self.sample_batch_size:]
self.batch_collector = self.batch_collector[self.sample_batch_size :]

def clear(self) -> None:
# self.items.close()
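The buffer's flow control comes from the Ray queue used above: extend() packs incoming BufferItems into fixed-size batches, and put(block=True) stalls once `limit` batches are pending, which back-pressures the experience makers. A standalone sketch of that pattern with plain integers standing in for BufferItems (illustrative sizes, not the project's defaults):

import ray
from ray.util.queue import Queue

ray.init()

sample_batch_size = 8
batches = Queue(maxsize=4)   # like DetachedReplayBuffer(limit=4): at most 4 ready batches
collector = []

def extend(items):
    # collect items until a full batch is available, then enqueue it;
    # put() blocks when the queue is full
    collector.extend(items)
    while len(collector) >= sample_batch_size:
        batch = collector[:sample_batch_size]
        del collector[:sample_batch_size]
        batches.put(batch, block=True)

def sample():
    # trainer side: wait until a full batch is ready
    return batches.get(block=True)

extend(list(range(20)))                   # two batches of 8 queued, 4 items still collecting
print(batches.qsize(), len(collector))    # 2 4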
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List

import ray
import torch

@@ -15,7 +15,7 @@ from .utils import is_rank_0

class DetachedTrainer(ABC):
'''
"""
Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init:

@@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating

'''
"""

def __init__(self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False) -> None:
def __init__(
self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False,
) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory

@@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
dataloader = DataLoader(data,
batch_size=1,
shuffle=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
dataloader = DataLoader(
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
)
for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)

@@ -104,7 +104,7 @@ class DetachedTrainer(ABC):

def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
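A note on the DataLoader in _learn() above: each element of `data` is already a collated batch, so the loader uses batch_size=1 with collate_fn=lambda x: x[0] purely to shuffle and re-yield the stored batches unchanged. A minimal sketch of that trick with dummy tensors:

import torch
from torch.utils.data import DataLoader

data = [torch.full((4, 2), i) for i in range(3)]  # each entry is already a "batch"
loader = DataLoader(data, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
for batch in loader:
    assert batch.shape == (4, 2)  # batches come back exactly as stored, just reordered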
@@ -1,12 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple

import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam

from colossalai.nn.optimizer import HybridAdam

@@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to


@ray.remote(
concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
)
@ray.remote(concurrency_groups={
"buffer_length": 1,
"buffer_append": 1,
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
'''
"""
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training

@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""

def __init__(
self,

@@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
(self.actor, self.actor_optim), (self.critic, self.critic_optim)
)

# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)

super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug)
super().__init__(
experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug,
)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")

self._update_lora_weights = update_lora_weights

@@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
config['requires_grad_only'] = True
config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []

@@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update))
fully_update=fully_update,
)
)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:

@@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update))
fully_update=fully_update,
)
)
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:

@@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):

num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()

values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)

self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}

def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)
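The @ray.remote(concurrency_groups=...) decorator above declares named concurrency slots for the trainer actor so that buffer appends, model I/O, and compute can proceed independently; individual methods opt in with @ray.method(concurrency_group=...), as update_experience_maker does in the maker below. A minimal sketch of the pattern with a toy actor (not the project's class), noting that concurrency groups apply to Ray async actors:

import ray

ray.init()

@ray.remote(concurrency_groups={"io": 1, "compute": 1})
class Worker:
    @ray.method(concurrency_group="io")
    async def load(self, x):
        return x          # scheduled on the "io" group

    @ray.method(concurrency_group="compute")
    async def step(self, x):
        return x + 1      # scheduled on the "compute" group, independent of load()

w = Worker.remote()
print(ray.get([w.load.remote(1), w.step.remote(1)]))   # [1, 2]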
@@ -1,53 +1,49 @@
import os
import time
import tracemalloc
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import ray
import torch
import torch.nn as nn
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm

from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
"""
Args:
detached_trainer_name_list: str list to get ray actor handles
strategy:
kl_coef: the coefficient of kl divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
'''
"""

def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs):
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs,
):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)

@@ -66,8 +62,9 @@ class ExperienceMakerHolder:
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
reward_model_numel)
evaluator = ExperienceMakerPerformanceEvaluator(
actor_numel, critic_numel, initial_model_numel, reward_model_numel
)
callbacks = callbacks + [evaluator]

actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)

@@ -89,9 +86,9 @@ class ExperienceMakerHolder:
self._target_idx = 0

if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')
print(f"[maker{get_rank()}] Waiting for INIT")

def _get_ready(self):
while not self._fully_initialized():

@@ -136,7 +133,7 @@ class ExperienceMakerHolder:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
experience.to_device('cpu')
experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()

@@ -155,7 +152,7 @@ class ExperienceMakerHolder:
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:

@@ -163,7 +160,7 @@ class ExperienceMakerHolder:
batch = next(it)
self._inference_step(batch)
else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)

@@ -171,22 +168,24 @@ class ExperienceMakerHolder:
self._on_loop_end()

@ray.method(concurrency_group="model_io")
def update_experience_maker(self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None):
'''
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
def update_experience_maker(
self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None,
):
"""
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing

TODO: load_state_dict integrate with model-sharding strategy
'''
TODO: load_state_dict integrate with model-sharding strategy
"""
_watch_memory = self._debug
if chunk_start:
if self._debug:

@@ -202,18 +201,22 @@ class ExperienceMakerHolder:
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict)
new_actor_state_dict, new_actor_lora_config_dict
)
self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase)
self.experience_maker.actor.model, state_dict_increase
)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict)
new_critic_state_dict, new_critic_lora_config_dict
)
self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase)
self.experience_maker.critic, state_dict_increase
)

# the lock must be released after both actor and critic being updated
if chunk_end:

@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation

if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation

return new_kwargs
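The chunk_start / chunk_end flags in update_experience_maker() frame a sequence of shard uploads: the trainer marks the start, streams actor and critic state-dict shards, then marks the end so the maker can release its update lock. A hedged sketch of how a caller might drive that protocol over one maker handle (the method name and flags come from the diff; the shard iterators are assumed helpers, and the real trainer additionally collects the returned object refs and waits on them with ray.get):

# hypothetical driver, for illustration only
holder.update_experience_maker.remote(chunk_start=True, fully_update=True)
for shard in actor_state_dict_shards:        # assumed iterator over per-shard dicts
    holder.update_experience_maker.remote(new_actor_state_dict=shard, fully_update=True)
for shard in critic_state_dict_shards:       # assumed iterator over per-shard dicts
    holder.update_experience_maker.remote(new_critic_state_dict=shard, fully_update=True)
holder.update_experience_maker.remote(chunk_end=True, fully_update=True)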
@@ -1,11 +1,9 @@
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict

import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer


@dataclass

@@ -17,7 +15,7 @@ class LoRAConfig:

class LoRAConstructor:
'''
"""
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:

@@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()

'''
"""

def __init__(self):
self.lora_config_dict = None

@@ -45,10 +43,10 @@ class LoRAConstructor:
self.lora_config_dict = lora_config_dict

def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
'''
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
'''
"""
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
"""
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)

@@ -56,24 +54,25 @@ class LoRAConstructor:
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A':
if k.rpartition(".")[-1] == "lora_A":
lora_A = v
layer_prefix = k.rpartition('.')[0]
elif k.rpartition('.')[-1] == 'lora_B':
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
layer_prefix = k.rpartition(".")[0]
elif k.rpartition(".")[-1] == "lora_B":
assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
raise ValueError('unexpected key')
raise ValueError("unexpected key")
return state_dict_increase

def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w

if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling

@@ -81,21 +80,21 @@ class LoRAConstructor:
return 0

def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
'''
"""
The final reconstruction step
'''
"""
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)

@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
'''
"""
if keep_non_lora, also return non_lora state_dict
'''
"""
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k:
if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v

@@ -106,17 +105,19 @@ class LoRAConstructor:

@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
'''
"""
extract LoraLinear model.
return OrderedDict(): name -> LoRAConfig
'''
"""
lora_config_dict = OrderedDict()

for name, child in model.named_modules():
if isinstance(child, LoraLinear):
lora_config_dict[name] = LoRAConfig(r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out)
lora_config_dict[name] = LoRAConfig(
r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out,
)

return lora_config_dict
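The reconstruction in _compute() above is the standard LoRA update: the transferred low-rank factors are folded back into a dense increment delta_W = (lora_alpha / r) * (lora_B @ lora_A), optionally transposed for fan_in_fan_out layers, and the receiver then adds that increment onto its existing weights instead of receiving the full weight matrix. A small numeric sketch of that computation:

import torch

r, lora_alpha = 4, 16
out_features, in_features = 8, 8
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)

scaling = lora_alpha / r
delta_w = (lora_B @ lora_A) * scaling      # dense increment, shape (out_features, in_features)

w = torch.zeros(out_features, in_features)
w_updated = w + delta_w                    # receiver adds the increment to its current weight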
@@ -1,6 +1,6 @@
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict

import torch
import torch.distributed as dist

@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer


def is_rank_0() -> bool:

@@ -26,13 +26,13 @@ def get_world_size() -> int:


def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom':
elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt':
elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama':
elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')

@@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra


def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom':
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt':
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama':
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{model}"')

@@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r


def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2':
if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom':
elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt':
elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama':
elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')

@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):


def get_strategy_from_args(strategy: str):
if strategy == 'ddp':
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_


def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
elif model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif model == 'opt':
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama':
elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:

@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):


def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
os.environ["RANK"] = env_info["rank"]
os.environ["LOCAL_RANK"] = env_info["local_rank"]
os.environ["WORLD_SIZE"] = env_info["world_size"]
os.environ["MASTER_PORT"] = env_info["master_port"]
os.environ["MASTER_ADDR"] = env_info["master_addr"]


def get_model_numel(model: nn.Module) -> int:

@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers


def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
keep state_dict intact
'''
def state_dict_to(
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
"""
keep state_dict intact
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
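state_dict_to() above returns a new OrderedDict with every tensor cast and moved, leaving the source state dict untouched, which is useful before shipping shards between processes. A brief usage sketch:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
cpu_fp16_sd = state_dict_to(model.state_dict(), dtype=torch.float16, device=torch.device("cpu"))
assert all(v.dtype == torch.float16 for v in cpu_fp16_sd.values())
assert next(iter(model.state_dict().values())).dtype == torch.float32  # original stays intact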