[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed by GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class TrainerCallback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
@@ -40,7 +40,6 @@ class TrainerCallback(ABC):
class MakerCallback(ABC):
def on_loop_start(self) -> None:
pass

View File

@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@@ -42,13 +41,13 @@ class Timer:
self.duration += time() - self.start_time
def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
class ExperienceMakerPerformanceEvaluator(MakerCallback):
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
reward_model_num_params: int) -> None:
def __init__(
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
self.make_experience_flop: int = 0
print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_make_experience_start(self) -> None:
@@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
(self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
self.total_samples * self.world_size
)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Making Experience Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
class TrainerPerformanceEvaluator(TrainerCallback):
def __init__(self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
self.learn_flop: int = 0
print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_episode_start(self, episodes: int) -> None:
@@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
def on_fit_end(self) -> None:
if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation')
print_rank_0("No samples are collected, skip trainer performance evaluation")
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
@@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Learning Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
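
The two evaluators above share one pattern: accumulate wall-clock time with Timer, all-reduce the mean duration across ranks, then derive throughput and TFLOPS from the sample count and an estimated FLOP total. Below is a minimal single-process sketch of that pattern — illustrative only; throughput_summary is not part of coati and the all-reduce step is omitted.

from time import time

class Timer:
    """Accumulate wall-clock time over repeated start/stop cycles, as the callbacks above do."""

    def __init__(self) -> None:
        self.start_time = None
        self.duration = 0.0

    def start(self) -> None:
        self.start_time = time()

    def stop(self) -> None:
        self.duration += time() - self.start_time

def throughput_summary(total_samples: int, world_size: int, flop: float, duration: float) -> dict:
    # Mirror the evaluator's guard against a zero duration by adding 1e-12 before dividing.
    samples_per_sec = total_samples * world_size / (duration + 1e-12)
    tflops_per_gpu = flop / 1e12 / (duration + 1e-12)
    return {"samples/sec": samples_per_sec, "TFLOPS per GPU": tflops_per_gpu}

timer = Timer()
timer.start()
_ = sum(range(100_000))  # stand-in for make_experience work
timer.stop()
print(throughput_summary(total_samples=512, world_size=8, flop=3.2e15, duration=timer.duration))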

View File

@@ -1,20 +1,15 @@
import asyncio
import copy
import random
from threading import Lock
from typing import Any, List
from typing import List
import ray
import torch
from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
class DetachedReplayBuffer:
'''
"""
Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
@@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
'''
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size
@@ -34,23 +29,23 @@ class DetachedReplayBuffer:
@torch.no_grad()
def append(self, experience: Experience) -> None:
'''
"""
Expected to be called remotely.
'''
"""
items = split_experience_batch(experience)
self.extend(items)
@torch.no_grad()
def extend(self, items: List[BufferItem]) -> None:
'''
"""
Expected to be called remotely.
'''
"""
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
items = self.batch_collector[:self.sample_batch_size]
items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items)
self.items.put(experience, block=True)
self.batch_collector = self.batch_collector[self.sample_batch_size:]
self.batch_collector = self.batch_collector[self.sample_batch_size :]
def clear(self) -> None:
# self.items.close()
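
The batching logic in DetachedReplayBuffer.extend is the interesting part of this file: items accumulate in batch_collector until a full sample_batch_size slice can be cut and pushed into the queue. Here is a single-process sketch under simplified assumptions (plain queue.Queue instead of ray.util.queue.Queue, arbitrary objects instead of BufferItem, no torch.no_grad or experience splitting).

from queue import Queue
from typing import Any, List

class SimpleBatchBuffer:
    """Collect single items and emit fixed-size batches through a blocking queue."""

    def __init__(self, sample_batch_size: int) -> None:
        self.sample_batch_size = sample_batch_size
        self.batch_collector: List[Any] = []
        self.items: Queue = Queue()

    def extend(self, items: List[Any]) -> None:
        self.batch_collector.extend(items)
        while len(self.batch_collector) >= self.sample_batch_size:
            batch = self.batch_collector[: self.sample_batch_size]
            self.items.put(batch, block=True)
            self.batch_collector = self.batch_collector[self.sample_batch_size :]

    def sample(self) -> List[Any]:
        return self.items.get(block=True)

buffer = SimpleBatchBuffer(sample_batch_size=4)
buffer.extend(list(range(10)))  # queues two full batches, keeps 2 leftover items
print(buffer.sample())          # [0, 1, 2, 3]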

View File

@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List
import ray
import torch
@@ -15,7 +15,7 @@ from .utils import is_rank_0
class DetachedTrainer(ABC):
'''
"""
Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init:
@@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""
def __init__(self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False) -> None:
def __init__(
self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False,
) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
@@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
dataloader = DataLoader(data,
batch_size=1,
shuffle=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
dataloader = DataLoader(
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
)
for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
@@ -104,7 +104,7 @@ class DetachedTrainer(ABC):
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
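
The reformatted _learn above relies on a small trick: each element of data is already a complete batch, so the DataLoader runs with batch_size=1 and a collate_fn that unwraps the singleton list rather than re-collating. A short standalone illustration of that pattern (tensors stand in for Experience objects):

import torch
from torch.utils.data import DataLoader

# Each element of `data` is already a full batch (here a tensor of 8 "samples"),
# so the loader shuffles whole batches and the collate_fn unwraps the singleton list.
data = [torch.randn(8, 4) for _ in range(3)]
loader = DataLoader(data, batch_size=1, shuffle=True, pin_memory=False, collate_fn=lambda x: x[0])

for batch in loader:
    print(batch.shape)  # torch.Size([8, 4]) -- the pre-built batch, not re-collated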

View File

@@ -1,12 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
@@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
@ray.remote(
concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
)
@ray.remote(concurrency_groups={
"buffer_length": 1,
"buffer_append": 1,
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
'''
"""
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training
@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""
def __init__(
self,
@@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
(self.actor, self.actor_optim), (self.critic, self.critic_optim)
)
# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug)
super().__init__(
experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug,
)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
self._update_lora_weights = update_lora_weights
@@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
config['requires_grad_only'] = True
config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
@@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update))
fully_update=fully_update,
)
)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
@@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update))
fully_update=fully_update,
)
)
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
@@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)
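
training_step computes a clipped PPO policy loss for the actor and a clipped value loss for the critic. As a reference for the actor side, here is a generic version of the clipped surrogate objective that a PolicyLoss of this kind typically implements — a sketch, not the coati implementation; shapes are assumed to be (batch, num_actions), with advantages broadcastable against them.

import torch

def clipped_policy_loss(log_probs, old_log_probs, advantages, eps_clip=0.2, action_mask=None):
    """Generic PPO clipped surrogate loss with an optional mask over generated actions."""
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * advantages
    loss = -torch.min(surr1, surr2)
    if action_mask is not None:
        return (loss * action_mask).sum() / action_mask.sum()
    return loss.mean()

# Toy example: 2 sequences, 5 generated tokens each, one advantage per sequence.
log_probs, old_log_probs = torch.randn(2, 5), torch.randn(2, 5)
advantages = torch.randn(2, 1)          # broadcast over the action dimension
action_mask = torch.ones(2, 5)
print(clipped_policy_loss(log_probs, old_log_probs, advantages, action_mask=action_mask))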

View File

@@ -1,53 +1,49 @@
import os
import time
import tracemalloc
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray
import torch
import torch.nn as nn
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
"""
Args:
detached_trainer_name_list: str list to get ray actor handles
strategy:
kl_coef: the coefficient of kl divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
'''
"""
def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs):
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs,
):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
@@ -66,8 +62,9 @@ class ExperienceMakerHolder:
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
reward_model_numel)
evaluator = ExperienceMakerPerformanceEvaluator(
actor_numel, critic_numel, initial_model_numel, reward_model_numel
)
callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
@@ -89,9 +86,9 @@ class ExperienceMakerHolder:
self._target_idx = 0
if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')
print(f"[maker{get_rank()}] Waiting for INIT")
def _get_ready(self):
while not self._fully_initialized():
@@ -136,7 +133,7 @@ class ExperienceMakerHolder:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
experience.to_device('cpu')
experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
@@ -155,7 +152,7 @@ class ExperienceMakerHolder:
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
@@ -163,7 +160,7 @@ class ExperienceMakerHolder:
batch = next(it)
self._inference_step(batch)
else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
@@ -171,22 +168,24 @@ class ExperienceMakerHolder:
self._on_loop_end()
@ray.method(concurrency_group="model_io")
def update_experience_maker(self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None):
'''
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
def update_experience_maker(
self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None,
):
"""
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
TODO: load_state_dict integrate with model-sharding strategy
'''
TODO: load_state_dict integrate with model-sharding strategy
"""
_watch_memory = self._debug
if chunk_start:
if self._debug:
@@ -202,18 +201,22 @@ class ExperienceMakerHolder:
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict)
new_actor_state_dict, new_actor_lora_config_dict
)
self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase)
self.experience_maker.actor.model, state_dict_increase
)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict)
new_critic_state_dict, new_critic_lora_config_dict
)
self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase)
self.experience_maker.critic, state_dict_increase
)
# the lock must be released after both actor and critic being updated
if chunk_end:
@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
return new_kwargs
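
update_experience_maker implements a chunked hand-off: the trainer marks chunk_start, streams state-dict shards, then marks chunk_end so the holder can keep its update lock across the whole transfer. The sketch below mimics that call sequence locally — FakeHolder, iter_shards, and push_actor_update are illustrative stand-ins; the real calls are Ray remote calls, and whether the start/end flags ride on dedicated calls or on the first/last shard is an implementation detail.

from typing import Any, Dict, Iterator

def iter_shards(state_dict: Dict[str, Any], shard_size: int) -> Iterator[Dict[str, Any]]:
    """Illustrative helper: yield the state dict in chunks of `shard_size` keys."""
    keys = list(state_dict)
    for i in range(0, len(keys), shard_size):
        yield {k: state_dict[k] for k in keys[i : i + shard_size]}

class FakeHolder:
    """Stand-in for a remote ExperienceMakerHolder; it only records the calls it receives."""

    def __init__(self) -> None:
        self.calls = []

    def update_experience_maker(self, **kwargs) -> None:
        self.calls.append(kwargs)

def push_actor_update(holder: FakeHolder, actor_state: Dict[str, Any], fully_update: bool = False) -> None:
    holder.update_experience_maker(fully_update=fully_update, chunk_start=True)  # mark start
    for shard in iter_shards(actor_state, shard_size=2):                         # stream shards
        holder.update_experience_maker(new_actor_state_dict=shard, fully_update=fully_update)
    holder.update_experience_maker(fully_update=fully_update, chunk_end=True)    # mark end

holder = FakeHolder()
push_actor_update(holder, {f"layer{i}.weight": i for i in range(5)})
print(len(holder.calls))  # 1 start call + 3 shard calls + 1 end call = 5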

View File

@@ -1,11 +1,9 @@
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict
import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer
@dataclass
@@ -17,7 +15,7 @@ class LoRAConfig:
class LoRAConstructor:
'''
"""
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:
@@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()
'''
"""
def __init__(self):
self.lora_config_dict = None
@@ -45,10 +43,10 @@ class LoRAConstructor:
self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
'''
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
'''
"""
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
"""
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)
@@ -56,24 +54,25 @@ class LoRAConstructor:
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A':
if k.rpartition(".")[-1] == "lora_A":
lora_A = v
layer_prefix = k.rpartition('.')[0]
elif k.rpartition('.')[-1] == 'lora_B':
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
layer_prefix = k.rpartition(".")[0]
elif k.rpartition(".")[-1] == "lora_B":
assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
raise ValueError('unexpected key')
raise ValueError("unexpected key")
return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w
if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
@@ -81,21 +80,21 @@ class LoRAConstructor:
return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
'''
"""
The final reconstruction step
'''
"""
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
'''
"""
if keep_non_lora, also return non_lora state_dict
'''
"""
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k:
if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
@@ -106,17 +105,19 @@ class LoRAConstructor:
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
'''
"""
extract LoraLinear model.
return OrderedDict(): name -> LoRAConfig
'''
"""
lora_config_dict = OrderedDict()
for name, child in model.named_modules():
if isinstance(child, LoraLinear):
lora_config_dict[name] = LoRAConfig(r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out)
lora_config_dict[name] = LoRAConfig(
r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out,
)
return lora_config_dict
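
LoRAConstructor._compute reconstructs the weight increment encoded by a (lora_A, lora_B) pair as scaling * lora_B @ lora_A with scaling = lora_alpha / r, optionally transposed for fan-in/fan-out layers. A self-contained sketch of that reconstruction with a shape check (assumes r > 0; illustrative, not the library code):

import torch

def lora_increment(lora_A: torch.Tensor, lora_B: torch.Tensor, r: int, lora_alpha: int,
                   fan_in_fan_out: bool = False) -> torch.Tensor:
    """Weight delta encoded by a (lora_A, lora_B) pair: (lora_alpha / r) * lora_B @ lora_A."""
    delta = (lora_B @ lora_A) * (lora_alpha / r)
    return delta.T if fan_in_fan_out else delta

# Example: rank-4 adapters for a linear layer with in_features=32, out_features=16.
lora_A = torch.randn(4, 32)   # (r, in_features)
lora_B = torch.randn(16, 4)   # (out_features, r)
print(lora_increment(lora_A, lora_B, r=4, lora_alpha=8).shape)  # torch.Size([16, 32])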

View File

@@ -1,6 +1,6 @@
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict
import torch
import torch.distributed as dist
@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool:
@@ -26,13 +26,13 @@ def get_world_size() -> int:
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom':
elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt':
elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama':
elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
@@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom':
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt':
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama':
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2':
if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom':
elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt':
elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama':
elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
if strategy == 'ddp':
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
elif model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif model == 'opt':
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama':
elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
os.environ["RANK"] = env_info["rank"]
os.environ["LOCAL_RANK"] = env_info["local_rank"]
os.environ["WORLD_SIZE"] = env_info["world_size"]
os.environ["MASTER_PORT"] = env_info["master_port"]
os.environ["MASTER_ADDR"] = env_info["master_addr"]
def get_model_numel(model: nn.Module) -> int:
@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers
def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
keep state_dict intact
'''
def state_dict_to(
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
"""
keep state_dict intact
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
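
state_dict_to promises to "keep state_dict intact": it builds a converted copy instead of mutating the input tensors. A standalone equivalent with a small check of that property (illustrative sketch, hypothetical name state_dict_to_copy):

from collections import OrderedDict

import torch

def state_dict_to_copy(state_dict, dtype=torch.float16, device=torch.device("cpu")):
    """Return a converted copy of the state dict; the original tensors are untouched."""
    return OrderedDict((k, v.to(dtype=dtype, device=device)) for k, v in state_dict.items())

sd = {"w": torch.randn(2, 2, dtype=torch.float32)}
sd_fp16 = state_dict_to_copy(sd)
print(sd["w"].dtype, sd_fp16["w"].dtype)  # torch.float32 torch.float16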