mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 19:17:30 +00:00
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com>
This commit is contained in:
175
applications/ColossalChat/coati/ray/README.md
Executable file
175
applications/ColossalChat/coati/ray/README.md
Executable file
@@ -0,0 +1,175 @@
|
||||
:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.**
|
||||
|
||||
# Distributed PPO Training on Stage 3
|
||||
|
||||
## Detach Experience Makers and Trainers
|
||||
|
||||
We can completely separate the trainers and makers.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/basic_structure.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
- The experience maker performs inference, produces experience, and remotely delivers it to the trainer (1).
|
||||
- The trainer consumes experience to train models, and periodically transmits new model parameters to the maker (2.1, 2.2).
|
||||
- Using an experience buffer to overlap transmission and computing.
|
||||
|
||||
In this manner, each node will work continuously without model idle time, and different optimization strategies can be applied for inference and training to meet the needs of speed or storage. It is also helpful for scalability.
|
||||
|
||||
`DetachedPPOTrainer` and `ExperienceMakerHolder` are Ray Actors (distinguished from Actor Model), representing Trainer and Experience Maker on the graph above, respectively.
|
||||
|
||||
[More about Ray Core](https://docs.ray.io/en/latest/ray-core/walkthrough.html)
|
||||
|
||||
## Usage
|
||||
|
||||
See examples at `ColossalAI/application/Chat/examples/ray`
|
||||
|
||||
### Setup Makers
|
||||
|
||||
- define makers' environment variables :
|
||||
|
||||
```python
|
||||
env_info_makers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(num_makers),
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(num_makers)]
|
||||
|
||||
```
|
||||
|
||||
- define maker models :
|
||||
|
||||
```python
|
||||
def model_fn():
|
||||
actor = get_actor_from_args(...)
|
||||
critic = get_critic_from_args(...)
|
||||
reward_model = get_reward_model_from_args(...)
|
||||
initial_model = get_actor_from_args(...)
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
```
|
||||
|
||||
- set experience_holder_refs :
|
||||
|
||||
```python
|
||||
experience_holder_refs = [
|
||||
ExperienceMakerHolder.options(
|
||||
name=f"maker_{i}",
|
||||
num_gpus=1,
|
||||
max_concurrency=2
|
||||
).remote(
|
||||
detached_trainer_name_list=[f"trainer_{x}" for x in target_trainers(...)],
|
||||
model_fn=model_fn,
|
||||
...)
|
||||
for i, env_info_maker in enumerate(env_info_makers)
|
||||
]
|
||||
```
|
||||
|
||||
The names in the `detached_trainer_name_list` refer to the target trainers that the maker should send experience to.
|
||||
We set a trainer's name the same as a maker, by `.options(name="str")`. See below.
|
||||
|
||||
### Setup Trainers
|
||||
|
||||
- define trainers' environment variables :
|
||||
```python
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(num_trainers)]
|
||||
```
|
||||
- define trainer models :
|
||||
|
||||
```python
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(...)
|
||||
critic = get_critic_from_args(...)
|
||||
return actor, critic
|
||||
```
|
||||
|
||||
- set trainer_refs :
|
||||
```python
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(
|
||||
name=f"trainer{i}",
|
||||
num_gpus=1,
|
||||
max_concurrency=2
|
||||
).remote(
|
||||
experience_maker_holder_name_list=[f"maker{x}" for x in target_makers(...)],
|
||||
model_fn = trainer_model_fn(),
|
||||
...)
|
||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
```
|
||||
The names in `experience_maker_holder_name_list` refer to the target makers that the trainer should send updated models to.
|
||||
By setting `detached_trainer_name_list` and `experience_maker_holder_name_list`, we can customize the transmission graph.
|
||||
|
||||
### Launch Jobs
|
||||
|
||||
- define data_loader :
|
||||
|
||||
```python
|
||||
def data_loader_fn():
|
||||
return = torch.utils.data.DataLoader(dataset=dataset)
|
||||
|
||||
```
|
||||
|
||||
- launch makers :
|
||||
|
||||
```python
|
||||
wait_tasks = []
|
||||
for experience_holder_ref in experience_holder_refs:
|
||||
wait_tasks.append(
|
||||
experience_holder_ref.workingloop.remote(data_loader_fn(),
|
||||
num_steps=experience_steps))
|
||||
|
||||
```
|
||||
|
||||
- launch trainers :
|
||||
|
||||
```python
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, update_steps, train_epochs))
|
||||
```
|
||||
|
||||
- wait for done :
|
||||
```python
|
||||
ray.get(wait_tasks)
|
||||
```
|
||||
|
||||
## Flexible Structure
|
||||
|
||||
We can deploy different strategies to makers and trainers. Here are some notions.
|
||||
|
||||
### 2 Makers 1 Trainer
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m1t.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### 2 Makers 2 Trainer
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m2t.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### Maker Inference Quantization
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/2m2t_quantize.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
### Tensor Parallel
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/hpcaitech/public_assets/blob/main/applications/chat/tp_ddp_hybrid.png?raw=true" width=600/>
|
||||
</p>
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Support LoRA
|
||||
- [ ] Support TP & PP
|
0
applications/ColossalChat/coati/ray/__init__.py
Executable file
0
applications/ColossalChat/coati/ray/__init__.py
Executable file
9
applications/ColossalChat/coati/ray/callbacks/__init__.py
Executable file
9
applications/ColossalChat/coati/ray/callbacks/__init__.py
Executable file
@@ -0,0 +1,9 @@
|
||||
from .base import MakerCallback, TrainerCallback
|
||||
from .performance_evaluator import ExperienceMakerPerformanceEvaluator, TrainerPerformanceEvaluator
|
||||
|
||||
__all__ = [
|
||||
"TrainerCallback",
|
||||
"MakerCallback",
|
||||
"ExperienceMakerPerformanceEvaluator",
|
||||
"TrainerPerformanceEvaluator",
|
||||
]
|
65
applications/ColossalChat/coati/ray/callbacks/base.py
Executable file
65
applications/ColossalChat/coati/ray/callbacks/base.py
Executable file
@@ -0,0 +1,65 @@
|
||||
from abc import ABC
|
||||
|
||||
from coati.experience_maker import Experience
|
||||
|
||||
|
||||
class TrainerCallback(ABC):
|
||||
"""
|
||||
Base callback class. It defines the interface for callbacks.
|
||||
"""
|
||||
|
||||
def on_fit_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_start(self, episode: int) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_end(self, episode: int) -> None:
|
||||
pass
|
||||
|
||||
def on_epoch_start(self, epoch: int) -> None:
|
||||
pass
|
||||
|
||||
def on_epoch_end(self, epoch: int) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
pass
|
||||
|
||||
def on_update_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_update_end(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MakerCallback(ABC):
|
||||
def on_loop_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_loop_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
pass
|
||||
|
||||
def on_send_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_send_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_batch_end(self) -> None:
|
||||
pass
|
214
applications/ColossalChat/coati/ray/callbacks/performance_evaluator.py
Executable file
214
applications/ColossalChat/coati/ray/callbacks/performance_evaluator.py
Executable file
@@ -0,0 +1,214 @@
|
||||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.experience_maker import Experience
|
||||
|
||||
from .base import MakerCallback, TrainerCallback
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
return 1
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.0
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
|
||||
def end(self) -> None:
|
||||
self.duration += time() - self.start_time
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.0
|
||||
|
||||
|
||||
class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||
def __init__(
|
||||
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
self.critic_num_params = critic_num_params
|
||||
self.initial_model_num_params = initial_model_num_params
|
||||
self.reward_model_num_params = reward_model_num_params
|
||||
|
||||
self.batch_timer = Timer()
|
||||
self.send_timer = Timer()
|
||||
self.make_experience_timer = Timer()
|
||||
self.total_samples: int = 0
|
||||
self.make_experience_flop: int = 0
|
||||
|
||||
print_rank_0(
|
||||
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
|
||||
)
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
self.make_experience_timer.start()
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
self.make_experience_timer.end()
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.total_samples += batch_size
|
||||
|
||||
# actor generate
|
||||
num_actions = experience.action_mask.size(1)
|
||||
input_len = seq_len - num_actions
|
||||
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
|
||||
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
|
||||
# actor forward
|
||||
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
|
||||
# critic forward
|
||||
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
|
||||
# initial model forward
|
||||
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
|
||||
# reward model forward
|
||||
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
|
||||
|
||||
def on_send_start(self) -> None:
|
||||
self.send_timer.start()
|
||||
|
||||
def on_send_end(self) -> None:
|
||||
self.send_timer.end()
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
self.batch_timer.start()
|
||||
|
||||
def on_batch_end(self) -> None:
|
||||
self.batch_timer.end()
|
||||
|
||||
def on_loop_end(self) -> None:
|
||||
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
|
||||
avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
||||
avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)
|
||||
|
||||
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
|
||||
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
||||
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
|
||||
self.total_samples * self.world_size
|
||||
)
|
||||
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
"Making Experience Performance Summary:\n"
|
||||
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
|
||||
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
|
||||
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
|
||||
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
)
|
||||
|
||||
|
||||
class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
actor_num_params: int,
|
||||
critic_num_params: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_first_episodes: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
self.critic_num_params = critic_num_params
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_first_episodes = ignore_first_episodes
|
||||
self.ignore_this_episode = False
|
||||
|
||||
self.episode_timer = Timer()
|
||||
self.batch_timer = Timer()
|
||||
self.update_timer = Timer()
|
||||
self.total_samples: int = 0
|
||||
self.learn_flop: int = 0
|
||||
|
||||
print_rank_0(
|
||||
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
|
||||
)
|
||||
|
||||
def on_episode_start(self, episodes: int) -> None:
|
||||
self.ignore_this_episode = episodes < self.ignore_first_episodes
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.episode_timer.start()
|
||||
|
||||
def on_episode_end(self, episodes: int) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.episode_timer.end()
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.batch_timer.start()
|
||||
|
||||
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.batch_timer.end()
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.total_samples += batch_size
|
||||
|
||||
# actor forward-backward, 3 means forward(1) + backward(2)
|
||||
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
# critic forward-backward
|
||||
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
|
||||
def on_update_start(self) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.update_timer.start()
|
||||
|
||||
def on_update_end(self) -> None:
|
||||
if self.ignore_this_episode:
|
||||
return
|
||||
self.update_timer.end()
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
if self.total_samples == 0:
|
||||
print_rank_0("No samples are collected, skip trainer performance evaluation")
|
||||
return
|
||||
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
||||
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
|
||||
avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)
|
||||
|
||||
avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
|
||||
avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
|
||||
avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
"Learning Performance Summary:\n"
|
||||
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
|
||||
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
|
||||
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
|
||||
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
)
|
70
applications/ColossalChat/coati/ray/detached_replay_buffer.py
Executable file
70
applications/ColossalChat/coati/ray/detached_replay_buffer.py
Executable file
@@ -0,0 +1,70 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
# from torch.multiprocessing import Queue
|
||||
from ray.util.queue import Queue
|
||||
|
||||
|
||||
class DetachedReplayBuffer:
|
||||
"""
|
||||
Detached replay buffer. Share Experience across workers on the same node.
|
||||
Therefore, a trainer node is expected to have only one instance.
|
||||
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
|
||||
|
||||
Args:
|
||||
sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch.
|
||||
tp_world_size: Number of workers in the same tp group
|
||||
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
|
||||
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.limit = limit
|
||||
self.items = Queue(self.limit, actor_options={"num_cpus": 1})
|
||||
self.batch_collector: List[BufferItem] = []
|
||||
|
||||
@torch.no_grad()
|
||||
def append(self, experience: Experience) -> None:
|
||||
"""
|
||||
Expected to be called remotely.
|
||||
"""
|
||||
items = split_experience_batch(experience)
|
||||
self.extend(items)
|
||||
|
||||
@torch.no_grad()
|
||||
def extend(self, items: List[BufferItem]) -> None:
|
||||
"""
|
||||
Expected to be called remotely.
|
||||
"""
|
||||
self.batch_collector.extend(items)
|
||||
while len(self.batch_collector) >= self.sample_batch_size:
|
||||
items = self.batch_collector[: self.sample_batch_size]
|
||||
experience = make_experience_batch(items)
|
||||
self.items.put(experience, block=True)
|
||||
self.batch_collector = self.batch_collector[self.sample_batch_size :]
|
||||
|
||||
def clear(self) -> None:
|
||||
# self.items.close()
|
||||
self.items.shutdown()
|
||||
self.items = Queue(self.limit)
|
||||
self.worker_state = [False] * self.tp_world_size
|
||||
self.batch_collector = []
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, worker_rank=0, to_device="cpu") -> Experience:
|
||||
ret = self._sample_and_erase()
|
||||
ret.to_device(to_device)
|
||||
return ret
|
||||
|
||||
@torch.no_grad()
|
||||
def _sample_and_erase(self) -> Experience:
|
||||
ret = self.items.get(block=True)
|
||||
return ret
|
||||
|
||||
def get_length(self) -> int:
|
||||
ret = self.items.qsize()
|
||||
return ret
|
179
applications/ColossalChat/coati/ray/detached_trainer_base.py
Executable file
179
applications/ColossalChat/coati/ray/detached_trainer_base.py
Executable file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer.utils import BufferItem
|
||||
from coati.experience_maker import Experience
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import TrainerCallback
|
||||
from .detached_replay_buffer import DetachedReplayBuffer
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
class DetachedTrainer(ABC):
|
||||
"""
|
||||
Base class for detached rlhf trainers.
|
||||
'detach' means that the experience maker is detached compared to a normal Trainer.
|
||||
Please set name attribute during init:
|
||||
>>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
|
||||
So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
|
||||
Args:
|
||||
detached_strategy (DetachedStrategy): the strategy to use for training
|
||||
detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
|
||||
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[TrainerCallback] = [],
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
self.target_holder_name_list = experience_maker_holder_name_list
|
||||
self.target_holder_list = []
|
||||
self._is_target_holder_initialized = False
|
||||
self._debug = debug
|
||||
|
||||
def update_target_holder_list(self):
|
||||
# as the length of target_holder_list may be zero, we need to check it by a bool flag
|
||||
if not self._is_target_holder_initialized:
|
||||
for name in self.target_holder_name_list:
|
||||
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
self._is_target_holder_initialized = True
|
||||
|
||||
@abstractmethod
|
||||
def _update_remote_makers(self, fully_update: bool = False, **kwargs):
|
||||
pass
|
||||
|
||||
def sync_models_to_remote_makers(self, **kwargs):
|
||||
self._update_remote_makers(fully_update=True, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, experience: Experience) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def _learn(self, update_steps: int, train_epochs: int) -> None:
|
||||
data = []
|
||||
# warmup
|
||||
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
|
||||
self._on_epoch_start(0)
|
||||
self._learn_epoch(pbar, data)
|
||||
self._on_epoch_end(0)
|
||||
# item is already a batch
|
||||
dataloader = DataLoader(
|
||||
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
|
||||
)
|
||||
for epoch in range(1, train_epochs):
|
||||
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
|
||||
self._on_epoch_start(epoch)
|
||||
self._learn_epoch(pbar, data)
|
||||
self._on_epoch_end(epoch)
|
||||
|
||||
def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
|
||||
is_warmup = len(data) == 0
|
||||
for x in pbar:
|
||||
if self._debug:
|
||||
print("[trainer] training step")
|
||||
# sample a batch and then train to avoid waiting
|
||||
experience = x if not is_warmup else self._buffer_sample()
|
||||
experience.to_device(torch.cuda.current_device())
|
||||
self._on_batch_start()
|
||||
metrics = self.training_step(experience)
|
||||
self._on_batch_end(metrics, experience)
|
||||
|
||||
if self._debug:
|
||||
print("[trainer] step over")
|
||||
experience.to_device("cpu")
|
||||
if is_warmup:
|
||||
data.append(experience)
|
||||
pbar.set_postfix(metrics)
|
||||
|
||||
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
|
||||
self._on_fit_start()
|
||||
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
|
||||
self._on_episode_start(i)
|
||||
self._learn(update_steps, train_epochs)
|
||||
self._on_update_start()
|
||||
self._update_remote_makers()
|
||||
self._on_update_end()
|
||||
self._on_episode_end(i)
|
||||
self._on_fit_end()
|
||||
|
||||
@ray.method(concurrency_group="buffer_length")
|
||||
def buffer_get_length(self):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print("[trainer] telling length")
|
||||
return self.detached_replay_buffer.get_length()
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_append(self, experience: Experience):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.append(experience)
|
||||
|
||||
@ray.method(concurrency_group="buffer_append")
|
||||
def buffer_extend(self, items: List[BufferItem]):
|
||||
# called by ExperienceMakerHolder
|
||||
if self._debug:
|
||||
print(f"[trainer] receiving exp.")
|
||||
self.detached_replay_buffer.extend(items)
|
||||
|
||||
@ray.method(concurrency_group="buffer_sample")
|
||||
def _buffer_sample(self):
|
||||
return self.detached_replay_buffer.sample()
|
||||
|
||||
def _on_fit_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_start()
|
||||
|
||||
def _on_fit_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_end()
|
||||
|
||||
def _on_episode_start(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_start(episode)
|
||||
|
||||
def _on_episode_end(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_end(episode)
|
||||
|
||||
def _on_epoch_start(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_start(epoch)
|
||||
|
||||
def _on_epoch_end(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_epoch_end(epoch)
|
||||
|
||||
def _on_batch_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_start()
|
||||
|
||||
def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_end(metrics, experience)
|
||||
|
||||
def _on_update_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_update_start()
|
||||
|
||||
def _on_update_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_update_end()
|
191
applications/ColossalChat/coati/ray/detached_trainer_ppo.py
Executable file
191
applications/ColossalChat/coati/ray/detached_trainer_ppo.py
Executable file
@@ -0,0 +1,191 @@
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_maker import Experience
|
||||
from coati.models.base import Actor, Critic
|
||||
from coati.models.loss import PolicyLoss, ValueLoss
|
||||
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
|
||||
from .detached_trainer_base import DetachedTrainer
|
||||
from .lora_constructor import LoRAConstructor
|
||||
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
|
||||
|
||||
|
||||
@ray.remote(
|
||||
concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
|
||||
)
|
||||
class DetachedPPOTrainer(DetachedTrainer):
|
||||
"""
|
||||
Detached Trainer for PPO algorithm
|
||||
Args:
|
||||
strategy (Strategy): the strategy to use for training
|
||||
model (str) : for actor / critic init
|
||||
pretrained (str) : for actor / critic init
|
||||
lora_rank (int) : for actor / critic init
|
||||
train_batch_size (int, defaults to 8): the batch size to use for training
|
||||
train_batch_size (int, defaults to 8): the batch size to use for training
|
||||
buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
|
||||
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
|
||||
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
|
||||
value_clip (float, defaults to 0.4): the clip coefficient of value loss
|
||||
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
||||
max_epochs (int, defaults to 1): the number of epochs of training process
|
||||
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experience_maker_holder_name_list: List[str],
|
||||
strategy_fn: Callable[[], Strategy],
|
||||
model_fn: Callable[[], Tuple[Actor, Critic]],
|
||||
env_info: Dict[str, str] = None,
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
eps_clip: float = 0.2,
|
||||
value_clip: float = 0.4,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[TrainerCallback] = [],
|
||||
eval_performance: bool = False,
|
||||
debug: bool = False,
|
||||
update_lora_weights: bool = False,
|
||||
) -> None:
|
||||
# set environment variables
|
||||
if env_info:
|
||||
set_dist_env(env_info=env_info)
|
||||
# configure strategy
|
||||
self.strategy = strategy_fn()
|
||||
# configure models, loss and optimizers
|
||||
with self.strategy.model_init_context():
|
||||
self.actor, self.critic = model_fn()
|
||||
|
||||
if eval_performance:
|
||||
actor_numel = get_model_numel(self.actor)
|
||||
critic_numel = get_model_numel(self.critic)
|
||||
evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
|
||||
callbacks = callbacks + [evaluator]
|
||||
|
||||
if isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)):
|
||||
self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
|
||||
self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
|
||||
else:
|
||||
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
|
||||
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
|
||||
|
||||
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
|
||||
(self.actor, self.actor_optim), (self.critic, self.critic_optim)
|
||||
)
|
||||
|
||||
# configure trainer
|
||||
self.actor_loss_fn = PolicyLoss(eps_clip)
|
||||
self.critic_loss_fn = ValueLoss(value_clip)
|
||||
|
||||
super().__init__(
|
||||
experience_maker_holder_name_list,
|
||||
train_batch_size=train_batch_size,
|
||||
buffer_limit=buffer_limit,
|
||||
dataloader_pin_memory=dataloader_pin_memory,
|
||||
callbacks=callbacks,
|
||||
debug=debug,
|
||||
)
|
||||
if self._debug:
|
||||
print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
|
||||
|
||||
self._update_lora_weights = update_lora_weights
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
@torch.no_grad()
|
||||
def _update_remote_makers(self, fully_update: bool = False, **config):
|
||||
# TODO: balance duties
|
||||
if not fully_update:
|
||||
config["requires_grad_only"] = True
|
||||
self.update_target_holder_list()
|
||||
# mark start, ensure order
|
||||
tasks = []
|
||||
for target_holder in self.target_holder_list:
|
||||
tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
|
||||
ray.get(tasks)
|
||||
# sending loop
|
||||
tasks = []
|
||||
|
||||
for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
|
||||
for target_holder in self.target_holder_list:
|
||||
tasks.append(
|
||||
target_holder.update_experience_maker.remote(
|
||||
new_actor_state_dict=state_dict_shard,
|
||||
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
|
||||
fully_update=fully_update,
|
||||
)
|
||||
)
|
||||
# sending loop
|
||||
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
|
||||
for target_holder in self.target_holder_list:
|
||||
tasks.append(
|
||||
target_holder.update_experience_maker.remote(
|
||||
new_critic_state_dict=state_dict_shard,
|
||||
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
|
||||
fully_update=fully_update,
|
||||
)
|
||||
)
|
||||
ray.get(tasks)
|
||||
# mark end
|
||||
for target_holder in self.target_holder_list:
|
||||
target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
|
||||
|
||||
@ray.method(concurrency_group="compute")
|
||||
def training_step(self, experience: Experience) -> Dict[str, float]:
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
|
||||
num_actions = experience.action_mask.size(1)
|
||||
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
|
||||
actor_loss = self.actor_loss_fn(
|
||||
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
|
||||
)
|
||||
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
|
||||
self.strategy.optimizer_step(self.actor_optim)
|
||||
self.actor_optim.zero_grad()
|
||||
|
||||
values = self.critic(
|
||||
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
|
||||
)
|
||||
critic_loss = self.critic_loss_fn(
|
||||
values, experience.values, experience.reward, action_mask=experience.action_mask
|
||||
)
|
||||
|
||||
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
|
||||
self.strategy.optimizer_step(self.critic_optim)
|
||||
self.critic_optim.zero_grad()
|
||||
return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
|
||||
|
||||
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
|
||||
self.strategy.save_model(self.actor, path, only_rank0)
|
||||
|
||||
def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
|
||||
self.strategy.save_model(self.critic, path, only_rank0)
|
||||
|
||||
def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
|
||||
self.strategy.save_optimizer(self.actor_optim, path, only_rank0)
|
||||
|
||||
def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
|
||||
self.strategy.save_optimizer(self.critic_optim, path, only_rank0)
|
||||
|
||||
def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
|
||||
for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
|
||||
if not self._update_lora_weights or fully_update:
|
||||
yield state_dict_to(state_dict)
|
||||
else:
|
||||
state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
|
||||
yield state_dict_to(state_dict_lora)
|
||||
|
||||
def _get_model_lora_config_dict(self, model: torch.nn.Module):
|
||||
if not self._update_lora_weights:
|
||||
return None
|
||||
unwrapped_model = self.strategy.unwrap_model(model)
|
||||
return LoRAConstructor.extract_lora_config(unwrapped_model)
|
274
applications/ColossalChat/coati/ray/experience_maker_holder.py
Executable file
274
applications/ColossalChat/coati/ray/experience_maker_holder.py
Executable file
@@ -0,0 +1,274 @@
|
||||
import os
|
||||
import time
|
||||
import tracemalloc
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer.utils import split_experience_batch
|
||||
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic, RewardModel
|
||||
from coati.trainer.strategies import Strategy
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
|
||||
from .lora_constructor import LoRAConstructor
|
||||
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
|
||||
|
||||
|
||||
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
|
||||
class ExperienceMakerHolder:
|
||||
"""
|
||||
Args:
|
||||
detached_trainer_name_list: str list to get ray actor handles
|
||||
strategy:
|
||||
kl_coef: the coefficient of kl divergence loss
|
||||
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detached_trainer_name_list: List[str],
|
||||
strategy_fn: Callable[[], Strategy],
|
||||
# a function returns (actor, critic, reward_model, initial_model)
|
||||
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
|
||||
env_info: Dict[str, str] = None,
|
||||
sync_models_from_trainers: bool = False,
|
||||
buffer_cpu_offload: bool = True,
|
||||
kl_coef: float = 0.1,
|
||||
callbacks: List[MakerCallback] = [],
|
||||
eval_performance: bool = False,
|
||||
debug: bool = False,
|
||||
update_lora_weights: bool = False,
|
||||
**generate_kwargs,
|
||||
):
|
||||
# set environment variables
|
||||
if env_info:
|
||||
set_dist_env(env_info=env_info)
|
||||
self.target_trainer_list = []
|
||||
assert len(detached_trainer_name_list) > 0
|
||||
self._detached_trainer_name_list = detached_trainer_name_list
|
||||
self.strategy = strategy_fn()
|
||||
self.buffer_cpu_offload = buffer_cpu_offload
|
||||
self.kl_coef = kl_coef
|
||||
# init models
|
||||
with self.strategy.model_init_context():
|
||||
actor, critic, reward_model, initial_model = model_fn()
|
||||
self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
|
||||
if eval_performance:
|
||||
actor_numel = get_model_numel(actor)
|
||||
critic_numel = get_model_numel(critic)
|
||||
initial_model_numel = get_model_numel(initial_model)
|
||||
reward_model_numel = get_model_numel(reward_model)
|
||||
evaluator = ExperienceMakerPerformanceEvaluator(
|
||||
actor_numel, critic_numel, initial_model_numel, reward_model_numel
|
||||
)
|
||||
callbacks = callbacks + [evaluator]
|
||||
|
||||
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
|
||||
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
|
||||
self.callbacks = callbacks
|
||||
|
||||
self._model_visit_lock = Lock()
|
||||
|
||||
self._is_fully_initialized = not sync_models_from_trainers
|
||||
|
||||
self._debug = debug
|
||||
self._update_lora_weights = update_lora_weights
|
||||
if self._update_lora_weights:
|
||||
self.actor_lora_constructor = LoRAConstructor()
|
||||
self.critic_lora_constructor = LoRAConstructor()
|
||||
|
||||
self.target_auto_balance = False
|
||||
|
||||
self._target_idx = 0
|
||||
|
||||
if self._debug:
|
||||
print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
|
||||
if not self._is_fully_initialized:
|
||||
print(f"[maker{get_rank()}] Waiting for INIT")
|
||||
|
||||
def _get_ready(self):
|
||||
while not self._fully_initialized():
|
||||
time.sleep(1.0)
|
||||
|
||||
def _fully_initialized(self):
|
||||
return self._is_fully_initialized
|
||||
|
||||
def _init_target_trainer_list(self):
|
||||
if len(self.target_trainer_list) > 0:
|
||||
return
|
||||
for name in self._detached_trainer_name_list:
|
||||
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
|
||||
|
||||
# copy from ../trainer/base.py
|
||||
@ray.method(concurrency_group="compute")
|
||||
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
|
||||
if isinstance(inputs, Tensor):
|
||||
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
|
||||
elif isinstance(inputs, dict):
|
||||
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unsupported input type "{type(inputs)}"')
|
||||
|
||||
@ray.method(concurrency_group="experience_io")
|
||||
def _send_items(self, experience: Experience) -> None:
|
||||
self._init_target_trainer_list()
|
||||
items = split_experience_batch(experience)
|
||||
items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
|
||||
for item in items:
|
||||
items_per_trainer[self._target_idx].append(item)
|
||||
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
|
||||
for i, target_trainer in enumerate(self.target_trainer_list):
|
||||
if len(items_per_trainer[i]) > 0:
|
||||
target_trainer.buffer_extend.remote(items_per_trainer[i])
|
||||
|
||||
def _inference_step(self, batch) -> None:
|
||||
self._on_batch_start()
|
||||
with self._model_visit_lock:
|
||||
self._on_make_experience_start()
|
||||
experience = self._make_experience(batch)
|
||||
self._on_make_experience_end(experience)
|
||||
self._on_send_start()
|
||||
if self.buffer_cpu_offload:
|
||||
experience.to_device("cpu")
|
||||
self._send_items(experience)
|
||||
self._on_send_end()
|
||||
self._on_batch_end()
|
||||
|
||||
def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
|
||||
"""Working loop of the experience maker.
|
||||
|
||||
Args:
|
||||
dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
|
||||
num_epochs (int, optional): Iterate the dataloader for number of epochs. Defaults to 1.
|
||||
num_steps (int, optional): Iterate the dataloader for number if steps. If this value > 0, num_epochs will be ignored. Defaults to 0.
|
||||
"""
|
||||
self._get_ready()
|
||||
self._on_loop_start()
|
||||
dataloader = dataloader_fn()
|
||||
if num_steps > 0:
|
||||
# ignore num epochs
|
||||
it = iter(dataloader)
|
||||
for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
|
||||
try:
|
||||
batch = next(it)
|
||||
except StopIteration:
|
||||
it = iter(dataloader)
|
||||
batch = next(it)
|
||||
self._inference_step(batch)
|
||||
else:
|
||||
with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
|
||||
for _ in range(num_epochs):
|
||||
for batch in dataloader:
|
||||
self._inference_step(batch)
|
||||
pbar.update()
|
||||
self._on_loop_end()
|
||||
|
||||
@ray.method(concurrency_group="model_io")
|
||||
def update_experience_maker(
|
||||
self,
|
||||
new_actor_state_dict: Dict[str, Any] = None,
|
||||
new_actor_lora_config_dict: Dict[str, Any] = None,
|
||||
new_critic_state_dict: Dict[str, Any] = None,
|
||||
new_critic_lora_config_dict: Dict[str, Any] = None,
|
||||
fully_update: bool = False,
|
||||
chunk_start: bool = None,
|
||||
chunk_end: bool = None,
|
||||
):
|
||||
"""
|
||||
called by trainer
|
||||
chunk_start: Set True at the first call. Before sending state_dict calls
|
||||
chunk_end: Set True at the last call. After sending state_dict calls.
|
||||
fully_update: Set True if you want to sync models when initializing
|
||||
|
||||
TODO: load_state_dict integrate with model-sharding strategy
|
||||
"""
|
||||
_watch_memory = self._debug
|
||||
if chunk_start:
|
||||
if self._debug:
|
||||
print("[maker] UPDATE ")
|
||||
if _watch_memory:
|
||||
tracemalloc.start()
|
||||
self._model_visit_lock.acquire()
|
||||
|
||||
with torch.no_grad():
|
||||
if new_actor_state_dict is not None:
|
||||
if not self._update_lora_weights or fully_update:
|
||||
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
|
||||
else:
|
||||
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
|
||||
new_actor_state_dict, new_actor_lora_config_dict
|
||||
)
|
||||
self.actor_lora_constructor.load_state_dict_increase(
|
||||
self.experience_maker.actor.model, state_dict_increase
|
||||
)
|
||||
if new_critic_state_dict is not None:
|
||||
if not self._update_lora_weights or fully_update:
|
||||
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
|
||||
else:
|
||||
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
|
||||
new_critic_state_dict, new_critic_lora_config_dict
|
||||
)
|
||||
self.critic_lora_constructor.load_state_dict_increase(
|
||||
self.experience_maker.critic, state_dict_increase
|
||||
)
|
||||
|
||||
# the lock must be released after both actor and critic being updated
|
||||
if chunk_end:
|
||||
self._model_visit_lock.release()
|
||||
if _watch_memory:
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
|
||||
tracemalloc.stop()
|
||||
if fully_update:
|
||||
self._is_fully_initialized = True
|
||||
|
||||
def _on_make_experience_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_make_experience_start()
|
||||
|
||||
def _on_make_experience_end(self, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_make_experience_end(experience)
|
||||
|
||||
def _on_loop_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_loop_start()
|
||||
|
||||
def _on_loop_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_loop_end()
|
||||
|
||||
def _on_send_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_send_start()
|
||||
|
||||
def _on_send_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_send_end()
|
||||
|
||||
def _on_batch_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_start()
|
||||
|
||||
def _on_batch_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_batch_end()
|
||||
|
||||
|
||||
def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = actor.model
|
||||
new_kwargs = {**generate_kwargs}
|
||||
# use huggingface models method directly
|
||||
if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
|
||||
new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
|
||||
|
||||
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
|
||||
new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
|
||||
|
||||
return new_kwargs
|
123
applications/ColossalChat/coati/ray/lora_constructor.py
Executable file
123
applications/ColossalChat/coati/ray/lora_constructor.py
Executable file
@@ -0,0 +1,123 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch.nn as nn
|
||||
from coati.models.lora import LoraLinear
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
r: int = 0
|
||||
lora_alpha: int = 1
|
||||
lora_dropout: float = 0
|
||||
fan_in_fan_out: bool = False
|
||||
|
||||
|
||||
class LoRAConstructor:
|
||||
"""
|
||||
Tools for reconstructing a model from a remote LoRA model.
|
||||
(Transferring only LoRA data costs much less!)
|
||||
Usage:
|
||||
Step 1 (Sender):
|
||||
filter_state_dict_lora()
|
||||
|
||||
Step 2 (Sender, Optional):
|
||||
extract_lora_config()
|
||||
|
||||
Step 3 (Sender):
|
||||
send state_dict_lora and lora_config_dict
|
||||
|
||||
Step 4 (Receiver):
|
||||
reconstruct_increase()
|
||||
|
||||
Step 5 (Receiver):
|
||||
load_state_dict_increase()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.lora_config_dict = None
|
||||
|
||||
def register_lora_config(self, lora_config_dict: Dict[str, Any]):
|
||||
self.lora_config_dict = lora_config_dict
|
||||
|
||||
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
|
||||
"""
|
||||
xxx.lora_A, xxx.lora_B -->> xxx.weight
|
||||
Warning: the xxx.weight here is the increment actually.
|
||||
"""
|
||||
if lora_config_dict is not None:
|
||||
self.register_lora_config(lora_config_dict)
|
||||
|
||||
state_dict_increase = OrderedDict()
|
||||
config_iter = iter(self.lora_config_dict.items())
|
||||
lora_A, lora_B, layer_prefix = None, None, None
|
||||
for k, v in state_dict_lora.items():
|
||||
if k.rpartition(".")[-1] == "lora_A":
|
||||
lora_A = v
|
||||
layer_prefix = k.rpartition(".")[0]
|
||||
elif k.rpartition(".")[-1] == "lora_B":
|
||||
assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
|
||||
layer_prefix_2, config = next(config_iter)
|
||||
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
|
||||
lora_B = v
|
||||
weight_data_increase = self._compute(lora_A, lora_B, config)
|
||||
state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
|
||||
lora_A, lora_B, layer_prefix = None, None, None
|
||||
else:
|
||||
raise ValueError("unexpected key")
|
||||
return state_dict_increase
|
||||
|
||||
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
|
||||
def T(w):
|
||||
return w.T if config.fan_in_fan_out else w
|
||||
|
||||
if config.r > 0:
|
||||
scaling = config.lora_alpha / config.r
|
||||
weight_data_increase = T(lora_B @ lora_A) * scaling
|
||||
return weight_data_increase
|
||||
return 0
|
||||
|
||||
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
|
||||
"""
|
||||
The final reconstruction step
|
||||
"""
|
||||
# naive approach
|
||||
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
|
||||
|
||||
@staticmethod
|
||||
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
|
||||
"""
|
||||
if keep_non_lora, also return non_lora state_dict
|
||||
"""
|
||||
state_dict_lora = OrderedDict()
|
||||
state_dict_non_lora = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if "lora_A" in k or "lora_B" in k:
|
||||
state_dict_lora[k] = v
|
||||
elif keep_non_lora:
|
||||
state_dict_non_lora[k] = v
|
||||
if keep_non_lora:
|
||||
return state_dict_lora, state_dict_non_lora
|
||||
else:
|
||||
return state_dict_lora, None
|
||||
|
||||
@staticmethod
|
||||
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
|
||||
"""
|
||||
extract LoraLinear model.
|
||||
return OrderedDict(): name -> LoRAConfig
|
||||
"""
|
||||
lora_config_dict = OrderedDict()
|
||||
|
||||
for name, child in model.named_modules():
|
||||
if isinstance(child, LoraLinear):
|
||||
lora_config_dict[name] = LoRAConfig(
|
||||
r=child.r,
|
||||
lora_alpha=child.lora_alpha,
|
||||
lora_dropout=child.lora_dropout,
|
||||
fan_in_fan_out=child.fan_in_fan_out,
|
||||
)
|
||||
|
||||
return lora_config_dict
|
142
applications/ColossalChat/coati/ray/utils.py
Executable file
142
applications/ColossalChat/coati/ray/utils.py
Executable file
@@ -0,0 +1,142 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
|
||||
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
return dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
return dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
|
||||
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
|
||||
if model == "gpt2":
|
||||
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == "bloom":
|
||||
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == "opt":
|
||||
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
elif model == "llama":
|
||||
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
|
||||
else:
|
||||
raise ValueError(f'Unsupported actor model "{model}"')
|
||||
return actor
|
||||
|
||||
|
||||
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
|
||||
if model == "gpt2":
|
||||
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||
elif model == "bloom":
|
||||
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||
elif model == "opt":
|
||||
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||
elif model == "llama":
|
||||
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{model}"')
|
||||
return critic
|
||||
|
||||
|
||||
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
|
||||
if model == "gpt2":
|
||||
reward_model = GPTRM(pretrained=pretrained, config=config)
|
||||
elif model == "bloom":
|
||||
reward_model = BLOOMRM(pretrained=pretrained, config=config)
|
||||
elif model == "opt":
|
||||
reward_model = OPTRM(pretrained=pretrained, config=config)
|
||||
elif model == "llama":
|
||||
reward_model = LlamaRM(pretrained=pretrained, config=config)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{model}"')
|
||||
return reward_model
|
||||
|
||||
|
||||
def get_strategy_from_args(strategy: str):
|
||||
if strategy == "ddp":
|
||||
strategy_ = DDPStrategy()
|
||||
elif strategy == "colossalai_gemini":
|
||||
strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
|
||||
elif strategy == "colossalai_zero2":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif strategy == "colossalai_gemini_cpu":
|
||||
strategy_ = GeminiStrategy(
|
||||
placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
|
||||
)
|
||||
elif strategy == "colossalai_zero2_cpu":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
||||
return strategy_
|
||||
|
||||
|
||||
def get_tokenizer_from_args(model: str, **kwargs):
|
||||
if model == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
elif model == "bloom":
|
||||
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
|
||||
elif model == "opt":
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
elif model == "llama":
|
||||
pretrain_path = kwargs["pretrain"]
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{model}"')
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def set_dist_env(env_info: Dict[str, str]):
|
||||
os.environ["RANK"] = env_info["rank"]
|
||||
os.environ["LOCAL_RANK"] = env_info["local_rank"]
|
||||
os.environ["WORLD_SIZE"] = env_info["world_size"]
|
||||
os.environ["MASTER_PORT"] = env_info["master_port"]
|
||||
os.environ["MASTER_ADDR"] = env_info["master_addr"]
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module) -> int:
|
||||
numel = sum(p.numel() for p in model.parameters())
|
||||
return numel
|
||||
|
||||
|
||||
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
|
||||
target_receivers = []
|
||||
if num_senders <= num_receivers or allow_idle_sender:
|
||||
# a sender will send data to one or more receivers
|
||||
# a receiver only has one sender
|
||||
for i in range(num_receivers):
|
||||
if i % num_senders == sender_idx:
|
||||
target_receivers.append(i)
|
||||
else:
|
||||
# a sender will send data to one receiver
|
||||
# a receiver may have more than one sender
|
||||
target_receivers.append(sender_idx % num_receivers)
|
||||
return target_receivers
|
||||
|
||||
|
||||
def state_dict_to(
|
||||
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
|
||||
):
|
||||
"""
|
||||
keep state_dict intact
|
||||
"""
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict[k] = v.to(dtype=dtype, device=device)
|
||||
return new_state_dict
|
Reference in New Issue
Block a user