[chat] fix bugs and add unit tests (#4213)

* style: rename replay buffer

Experience replay is typically used for off-policy algorithms.
Using this name in PPO may be misleading; a short sketch of the distinction follows the commit metadata below.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
Author: Wenhao Chen
Date: 2023-08-02 10:17:36 +08:00
Committed by: GitHub
Parent: 16bf4c0221
Commit: da4f7b855f
62 changed files with 1404 additions and 1202 deletions
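The rename in the first item of the commit list reflects a real conceptual difference. Below is a minimal, hedged sketch of that difference; it is illustrative only, not the repository's implementation, and the class and method names are made up for the example.

```python
import random
from collections import deque


class ReplayBuffer:
    """Off-policy style: transitions persist across updates and are re-sampled."""

    def __init__(self, capacity: int = 10_000):
        self.items = deque(maxlen=capacity)

    def append(self, item):
        self.items.append(item)

    def sample(self, batch_size: int):
        # Old experience may be replayed many times (DQN/SAC-style usage).
        return random.sample(list(self.items), batch_size)


class ExperienceBuffer:
    """On-policy (PPO) style: only the latest rollout is used, then discarded."""

    def __init__(self):
        self.items = []

    def append(self, item):
        self.items.append(item)

    def get_all_and_clear(self):
        # PPO consumes the freshly collected batch and must not replay stale
        # data from an older policy, so the buffer is emptied after each update.
        items, self.items = self.items, []
        return items
```

Since PPO only reuses a rollout within a single optimization phase, "experience buffer" describes the second pattern better than "replay buffer".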


@@ -115,12 +115,12 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
         avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
         print_rank_0(
-            'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
-            f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
-            f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
-            f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-            +
-            f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+            + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+            + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+            + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
         )
@@ -204,9 +204,9 @@ class TrainerPerformanceEvaluator(TrainerCallback):
         avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
         print_rank_0(
-            'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
-            f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
-            f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
-            +
-            f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+            + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+            + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+            + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
         )


@@ -6,9 +6,9 @@ from typing import Any, List
 import ray
 import torch
+from coati.experience_buffer import ExperienceBuffer
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.experience_maker.base import Experience
-from coati.replay_buffer import ReplayBuffer
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 # from torch.multiprocessing import Queue
 from ray.util.queue import Queue


@@ -4,8 +4,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 import ray
 import torch
+from coati.experience_buffer.utils import BufferItem
 from coati.experience_maker import Experience
-from coati.replay_buffer.utils import BufferItem
 from torch.utils.data import DataLoader
 from tqdm import tqdm


@@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import ray
 import torch
 import torch.nn as nn
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
 from coati.models.base import Actor, Critic, RewardModel
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.trainer.callbacks import Callback
 from coati.trainer.strategies import Strategy
 from coati.trainer.strategies.sampler import DistributedSampler
@@ -19,13 +19,9 @@ from torch import Tensor
 from tqdm import tqdm
 from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
-from .utils import (get_model_numel,
-                    get_rank,
-                    get_world_size,
-                    is_rank_0,
-                    set_dist_env,
-                    state_dict_to)
 from .lora_constructor import LoRAConstructor
+from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to

 @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
 class ExperienceMakerHolder:
@@ -41,7 +37,7 @@ class ExperienceMakerHolder:
         self,
         detached_trainer_name_list: List[str],
         strategy_fn: Callable[[], Strategy],
-        # a function returns (actor, critic, reward_model, initial_model)
+        # a function returns (actor, critic, reward_model, initial_model)
         model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
         env_info: Dict[str, str] = None,
         sync_models_from_trainers: bool = False,
@@ -205,15 +201,19 @@ class ExperienceMakerHolder:
                     self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
                 else:
                     new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
-                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
-                    self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
+                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
+                        new_actor_state_dict, new_actor_lora_config_dict)
+                    self.actor_lora_constructor.load_state_dict_increase(
+                        self.experience_maker.actor.model, state_dict_increase)
             if new_critic_state_dict is not None:
                 if not self._update_lora_weights or fully_update:
                     self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
                 else:
                     new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
-                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
-                    self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
+                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
+                        new_critic_state_dict, new_critic_lora_config_dict)
+                    self.critic_lora_constructor.load_state_dict_increase(
+                        self.experience_maker.critic, state_dict_increase)

         # the lock must be released after both actor and critic being updated
         if chunk_end:


@@ -1,11 +1,11 @@
-from typing import Any, Callable, Dict, List, Optional
 from collections import OrderedDict
 from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.nn as nn
-from loralib.layers import LoRALayer
 from coati.models.lora import LoraLinear
+from loralib.layers import LoRALayer

 @dataclass
@@ -23,19 +23,19 @@ class LoRAConstructor:
     Usage:
         Step 1 (Sender):
             filter_state_dict_lora()

         Step 2 (Sender, Optional):
             extract_lora_config()

         Step 3 (Sender):
             send state_dict_lora and lora_config_dict

         Step 4 (Receiver):
             reconstruct_increase()

         Step 5 (Receiver):
             load_state_dict_increase()
     '''

     def __init__(self):
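To make the five steps in the LoRAConstructor docstring concrete, here is a hedged sketch of the sender/receiver flow. Only the reconstruct_increase() and load_state_dict_increase() calls mirror the diff above; the signatures of filter_state_dict_lora() and extract_lora_config(), the helper function names, and the transport between sender and receiver are assumptions made for illustration.

```python
import torch


def sender_side(constructor, actor_model):
    # Step 1: keep only the LoRA factors from the full state dict
    # (assumed signature; the real method may accept different arguments).
    state_dict_lora = constructor.filter_state_dict_lora(actor_model.state_dict())
    # Step 2 (optional): collect per-layer LoRA hyperparameters such as rank/alpha
    # (assumed signature).
    lora_config_dict = constructor.extract_lora_config(actor_model)
    # Step 3: both payloads are sent to the experience maker; the transport
    # (e.g. a remote call) is outside this sketch.
    return state_dict_lora, lora_config_dict


def receiver_side(constructor, target_model, state_dict_lora, lora_config_dict):
    # Mirror the diff above: move received tensors to the local device first.
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    state_dict_lora = {k: v.to(device) for k, v in state_dict_lora.items()}
    # Step 4: rebuild the dense weight increase from the low-rank LoRA factors.
    state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config_dict)
    # Step 5: apply the increase to the receiver's copy of the model.
    constructor.load_state_dict_increase(target_model, state_dict_increase)
```

Sending only the LoRA factors plus their configuration keeps the synchronization payload small, which is the point of reconstructing the weight increase on the receiver rather than shipping full state dicts.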