Mirror of https://github.com/hpcaitech/ColossalAI.git
[chat] fix bugs and add unit tests (#4213)
* style: rename replay buffer. Experience replay is typically used for off-policy algorithms; using this name in PPO may be misleading.
* fix: fix wrong zero2 default arg
* test: update experience tests
* style: rename zero_pad fn
* fix: defer init in CycledDataLoader
* test: add benchmark test
* style: rename internal fn of generation
* style: rename internal fn of lora
* fix: remove unused loss fn
* fix: remove unused utils fn
* refactor: remove generate_with_actor fn
* fix: fix type annotation
* test: add models tests
* fix: skip llama due to long execution time
* style: modify dataset
* style: apply formatter
* perf: update reward dataset
* fix: fix wrong IGNORE_INDEX in sft dataset
* fix: remove DataCollatorForSupervisedDataset
* test: add dataset tests
* style: apply formatter
* style: rename test_ci to test_train
* feat: add llama in inference
* test: add inference tests
* test: change test scripts directory
* fix: update ci
* fix: fix typo
* fix: skip llama due to oom
* fix: fix file mod
* style: apply formatter
* refactor: remove duplicated llama_gptq
* style: apply formatter
* to: update rm test
* feat: add tokenizer arg
* feat: add download model script
* test: update train tests
* fix: modify gemini load and save pretrained
* test: update checkpoint io test
* to: modify nproc_per_node
* fix: do not remove existing dir
* fix: modify save path
* test: add random choice
* fix: fix sft path
* fix: enlarge nproc_per_node to avoid oom
* fix: add num_retry
* fix: make lora config of rm and critic consistent
* fix: add warning about lora weights
* fix: skip some gpt2 tests
* fix: remove grad ckpt in rm and critic due to errors
* refactor: directly use Actor in train_sft
* test: add more arguments
* fix: disable grad ckpt when using lora
* fix: fix save_pretrained and related tests
* test: enable zero2 tests
* revert: remove useless fn
* style: polish code
* test: modify test args
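
The replay-buffer rename is the most visible change in the diff below. As a rough migration note, the old and new import paths look like this (module and symbol names are taken from the diff itself; the before/after comments just restate the rationale from the first bullet):

    # Before this commit: off-policy flavoured naming.
    from coati.replay_buffer import ReplayBuffer
    from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch

    # After this commit: naming that fits the on-policy PPO pipeline.
    from coati.experience_buffer import ExperienceBuffer
    from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch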
@@ -115,12 +115,12 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
         avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
 
         print_rank_0(
-            'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
-            f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
-            f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
-            f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-            +
-            f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+            + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+            + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+            + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
+            + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
         )
@@ -204,9 +204,9 @@ class TrainerPerformanceEvaluator(TrainerCallback):
         avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
 
         print_rank_0(
-            'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
-            f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
-            f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-            +
-            f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+            + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+            + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
+            + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
         )
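
In both summaries each phase is reported as seconds per sample plus its share of the overall per-sample time. A minimal standalone sketch of that arithmetic, with made-up durations and counts; only the send/update formula appears verbatim above, and the overall figure is assumed to follow the same pattern, including the 1e-12 guard against division by zero:

    # Made-up example values; only the formula pattern is taken from the diff above.
    total_samples = 1024
    world_size = 4
    overall_duration = 20.0          # wall-clock seconds for the whole phase
    make_experience_duration = 15.0  # seconds spent generating experience
    send_duration = 5.0              # seconds spent sending experience to trainers

    denom = total_samples * world_size
    avg_time_per_sample = (overall_duration + 1e-12) / denom
    avg_make_experience_time_per_sample = (make_experience_duration + 1e-12) / denom
    avg_send_time_per_sample = (send_duration + 1e-12) / denom

    # Each component is printed as its per-sample time and its percentage of the overall time.
    print(f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, '
          f'{avg_make_experience_time_per_sample / avg_time_per_sample * 100:.2f}%')
    print(f'Sample time (send): {avg_send_time_per_sample:.3f} s, '
          f'{avg_send_time_per_sample / avg_time_per_sample * 100:.2f}%')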
@@ -6,9 +6,9 @@ from typing import Any, List
 import ray
 import torch
+from coati.experience_buffer import ExperienceBuffer
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.experience_maker.base import Experience
-from coati.replay_buffer import ReplayBuffer
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 # from torch.multiprocessing import Queue
 from ray.util.queue import Queue
@@ -4,8 +4,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 import ray
 import torch
+from coati.experience_buffer.utils import BufferItem
 from coati.experience_maker import Experience
-from coati.replay_buffer.utils import BufferItem
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import ray
 import torch
 import torch.nn as nn
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
 from coati.models.base import Actor, Critic, RewardModel
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.trainer.callbacks import Callback
 from coati.trainer.strategies import Strategy
 from coati.trainer.strategies.sampler import DistributedSampler
@@ -19,13 +19,9 @@ from torch import Tensor
 from tqdm import tqdm
 
 from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
-from .utils import (get_model_numel,
-                    get_rank,
-                    get_world_size,
-                    is_rank_0,
-                    set_dist_env,
-                    state_dict_to)
+from .lora_constructor import LoRAConstructor
+from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
 
 
 @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
 class ExperienceMakerHolder:
@@ -41,7 +37,7 @@ class ExperienceMakerHolder:
         self,
         detached_trainer_name_list: List[str],
         strategy_fn: Callable[[], Strategy],
-        # a function returns (actor, critic, reward_model, initial_model)
+        # a function returns (actor, critic, reward_model, initial_model)
         model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
         env_info: Dict[str, str] = None,
         sync_models_from_trainers: bool = False,
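
The two callables in this signature do the real wiring: strategy_fn builds the distributed strategy inside the remote actor, and model_fn must return the (actor, critic, reward_model, initial_model) tuple noted in the comment. A hedged sketch of how a caller might instantiate the holder, assuming Ray is initialized and the coati classes are importable; build_strategy, build_models, and the actor/trainer names are hypothetical placeholders, not part of this diff:

    # Hypothetical builders -- the real ones depend on your models and chosen strategy.
    def build_strategy():
        ...  # return a coati Strategy instance

    def build_models():
        ...  # return (actor, critic, reward_model, initial_model), i.e. Tuple[Actor, Critic, RewardModel, Actor]

    # ExperienceMakerHolder is a @ray.remote class, so it is created via .options()/.remote().
    maker = ExperienceMakerHolder.options(name="maker0").remote(
        detached_trainer_name_list=["trainer0"],
        strategy_fn=build_strategy,
        model_fn=build_models,
    )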
@@ -205,15 +201,19 @@ class ExperienceMakerHolder:
                 self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
             else:
                 new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
-                state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
-                self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
+                state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
+                    new_actor_state_dict, new_actor_lora_config_dict)
+                self.actor_lora_constructor.load_state_dict_increase(
+                    self.experience_maker.actor.model, state_dict_increase)
         if new_critic_state_dict is not None:
             if not self._update_lora_weights or fully_update:
                 self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
             else:
                 new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
-                state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
-                self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
+                state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
+                    new_critic_state_dict, new_critic_lora_config_dict)
+                self.critic_lora_constructor.load_state_dict_increase(
+                    self.experience_maker.critic, state_dict_increase)
 
         # the lock must be released after both actor and critic being updated
         if chunk_end:
@@ -1,11 +1,11 @@
-from typing import Any, Callable, Dict, List, Optional
 from collections import OrderedDict
 from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.nn as nn
-from loralib.layers import LoRALayer
+from coati.models.lora import LoraLinear
+from loralib.layers import LoRALayer
 
 
 @dataclass
@@ -23,19 +23,19 @@ class LoRAConstructor:
     Usage:
         Step 1 (Sender):
             filter_state_dict_lora()
 
         Step 2 (Sender, Optional):
             extract_lora_config()
 
         Step 3 (Sender):
             send state_dict_lora and lora_config_dict
 
         Step 4 (Receiver):
             reconstruct_increase()
 
         Step 5 (Receiver):
             load_state_dict_increase()
 
     '''
 
     def __init__(self):
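
The five steps in the docstring describe a sender/receiver protocol for shipping only the LoRA weight deltas between trainer and experience maker. A hedged end-to-end sketch using the method names from the docstring; the model object and the sender-side argument shapes are assumptions, while the receiver-side calls mirror the update path shown earlier in this diff:

    # Sketch only: argument shapes for the sender-side calls are assumed.
    constructor = LoRAConstructor()

    # Steps 1-2 (sender): keep only the LoRA tensors and, optionally, their configs.
    state_dict_lora = constructor.filter_state_dict_lora(model.state_dict())   # argument assumed
    lora_config_dict = constructor.extract_lora_config(model)                  # argument assumed

    # Step 3 (sender): transmit state_dict_lora and lora_config_dict to the receiver (transport not shown).

    # Steps 4-5 (receiver): rebuild the weight increase from the deltas and load it into the model,
    # as reconstruct_increase()/load_state_dict_increase() are called in the update code above.
    state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config_dict)
    constructor.load_state_dict_increase(model, state_dict_increase)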