Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2026-04-26 09:42:27 +00:00
* style: rename replay buffer (experience replay is typically for off-policy algorithms; using this name in PPO may be misleading)
* fix: fix wrong zero2 default arg
* test: update experience tests
* style: rename zero_pad fn
* fix: defer init in CycledDataLoader
* test: add benchmark test
* style: rename internal fn of generation
* style: rename internal fn of lora
* fix: remove unused loss fn
* fix: remove unused utils fn
* refactor: remove generate_with_actor fn
* fix: fix type annotation
* test: add models tests
* fix: skip llama due to long execution time
* style: modify dataset
* style: apply formatter
* perf: update reward dataset
* fix: fix wrong IGNORE_INDEX in sft dataset
* fix: remove DataCollatorForSupervisedDataset
* test: add dataset tests
* style: apply formatter
* style: rename test_ci to test_train
* feat: add llama in inference
* test: add inference tests
* test: change test scripts directory
* fix: update ci
* fix: fix typo
* fix: skip llama due to oom
* fix: fix file mod
* style: apply formatter
* refactor: remove duplicated llama_gptq
* style: apply formatter
* to: update rm test
* feat: add tokenizer arg
* feat: add download model script
* test: update train tests
* fix: modify gemini load and save pretrained
* test: update checkpoint io test
* to: modify nproc_per_node
* fix: do not remove existing dir
* fix: modify save path
* test: add random choice
* fix: fix sft path
* fix: enlarge nproc_per_node to avoid oom
* fix: add num_retry
* fix: make lora config of rm and critic consistent
* fix: add warning about lora weights
* fix: skip some gpt2 tests
* fix: remove grad ckpt in rm and critic due to errors
* refactor: directly use Actor in train_sft
* test: add more arguments
* fix: disable grad ckpt when using lora
* fix: fix save_pretrained and related tests
* test: enable zero2 tests
* revert: remove useless fn
* style: polish code
* test: modify test args
76 lines
2.3 KiB
Python
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn.functional as F

from coati.experience_maker.base import Experience


@dataclass
class BufferItem:
    """BufferItem is a single item of experience data.

    Shapes of each tensor:
        sequences: (S)
        action_log_probs: (A)
        values: (1)
        reward: (1)
        advantages: (1)
        attention_mask: (S)
        action_mask: (A)

    "S" is the sequence length and "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
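
# Illustrative sketch (not part of the original file): a BufferItem for a
# sequence of total length S=8 whose last A=3 tokens are generated actions.
# All values below are dummies.
#
#   item = BufferItem(sequences=torch.randint(0, 100, (8,)),
#                     action_log_probs=torch.randn(3),
#                     values=torch.zeros(1),
#                     reward=torch.zeros(1),
#                     advantages=torch.zeros(1),
#                     attention_mask=torch.ones(8, dtype=torch.long),
#                     action_mask=torch.ones(3, dtype=torch.bool))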


def split_experience_batch(experience: Experience) -> List[BufferItem]:
    """Split a batched Experience into a list of per-sample BufferItems."""
    batch_size = experience.sequences.size(0)
    batch_kwargs = [{} for _ in range(batch_size)]
    keys = ('sequences', 'action_log_probs', 'values',
            'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        value = getattr(experience, key)
        if isinstance(value, torch.Tensor):
            # Unbind along dim 0: one tensor per sample, batch dim removed.
            vals = torch.unbind(value)
        else:
            # Optional fields may be None; replicate None for every sample.
            vals = [value for _ in range(batch_size)]
        assert batch_size == len(vals)
        for i, v in enumerate(vals):
            batch_kwargs[i][key] = v
    items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
    return items
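
# Sketch (not in the original file) of the unbind step that drives the split:
#   torch.unbind(torch.tensor([[1, 2], [3, 4]]))
#   -> (tensor([1, 2]), tensor([3, 4]))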


def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
    """Zero-pad 1-D tensors to the longest length, then stack into a batch."""
    assert side in ('left', 'right')
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        # F.pad takes (left, right) pad widths for the last dimension.
        padding = (pad_len, 0) if side == 'left' else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)
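
# Example (sketch, not in the original file): sequences of length 2 and 3
# are left-padded with zeros before stacking:
#   _zero_pad_sequences([torch.tensor([1, 2]), torch.tensor([3, 4, 5])])
#   -> tensor([[0, 1, 2],
#              [3, 4, 5]])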


def make_experience_batch(items: List[BufferItem]) -> Experience:
    """Inverse of split_experience_batch: re-batch a list of BufferItems.

    Action-level tensors may differ in length across items, so they are
    zero-padded to a common length before stacking; all other tensors are
    stacked directly and must therefore already share the same shape.
    """
    kwargs = {}
    to_pad_keys = {'action_log_probs', 'action_mask'}
    keys = ('sequences', 'action_log_probs', 'values',
            'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if vals[0] is None:
            # Optional fields may be None (split_experience_batch preserves
            # them as None); propagate None rather than trying to stack.
            batch_data = None
        elif key in to_pad_keys:
            batch_data = _zero_pad_sequences(vals)
        else:
            batch_data = torch.stack(vals, dim=0)
        kwargs[key] = batch_data
    return Experience(**kwargs)
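

# --- Usage sketch (not part of the original file) ---
# A minimal round-trip smoke test, under the assumption that Experience is
# constructible from the seven keyword arguments used above (as
# make_experience_batch itself assumes). Shapes follow the BufferItem
# docstring: B samples, sequence length S, A actions.
if __name__ == '__main__':
    B, S, A = 2, 8, 3
    exp = Experience(
        sequences=torch.randint(0, 100, (B, S)),
        action_log_probs=torch.randn(B, A),
        values=torch.randn(B, 1),
        reward=torch.randn(B, 1),
        advantages=torch.randn(B, 1),
        attention_mask=torch.ones(B, S, dtype=torch.long),
        action_mask=torch.ones(B, A, dtype=torch.bool),
    )
    items = split_experience_batch(exp)
    assert len(items) == B and items[0].sequences.shape == (S,)
    rebatched = make_experience_batch(items)
    assert torch.equal(rebatched.sequences, exp.sequences)
    assert torch.equal(rebatched.action_log_probs, exp.action_log_probs)
    print('round-trip OK')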