[chat] fix bugs and add unit tests (#4213)

* style: rename replay buffer

Experience replay is typically used for off-policy algorithms.
Using this name in PPO may be misleading.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
Authored by Wenhao Chen on 2023-08-02 10:17:36 +08:00, committed by GitHub
parent 16bf4c0221, commit da4f7b855f
62 changed files with 1404 additions and 1202 deletions

View File

@@ -1,9 +1,10 @@
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0
__all__ = [
'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
'DataCollatorForSupervisedDataset', 'PromptDataset'
'RmStaticDataset', 'HhRlhfDataset',
'SFTDataset', 'SupervisedDataset',
'PromptDataset', 'is_rank_0',
]

View File

@@ -1,20 +1,13 @@
import copy
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
from typing import Dict
import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
logger = get_dist_logger()
from .utils import jload
class PromptDataset(Dataset):
@@ -27,12 +20,13 @@ class PromptDataset(Dataset):
max_length: int = 96):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
logger.info("Loading data...")
self.logger = get_dist_logger()
self.logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
self.logger.info(f"Loaded {len(list_data_dict)} examples.")
if max_datasets_size is not None:
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict]

View File

@@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt']
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
chosen = [
data["prompt"] + data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
reject = [
data["prompt"] + data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
def __len__(self):
length = len(self.chosen)
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
# Anthropic/hh-rlhf
@@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
chosen = data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
reject = data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
chosen = [
data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
reject = [
data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
def __len__(self):
length = len(self.chosen)
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
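For orientation, a minimal standalone sketch of the batch-tokenization pattern the reward datasets now follow: one tokenizer call over the whole list instead of one call per sample, with per-sample access done by tensor indexing. The GPT-2 tokenizer and the toy preference pair below are illustrative assumptions, not part of this diff.

from transformers import AutoTokenizer

# toy preference pairs; real data comes from rm-static or Anthropic/hh-rlhf
pairs = [{"prompt": "Q: hi\n", "chosen": "Hello!", "rejected": "Go away."}]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
end_token = tokenizer.eos_token

chosen = [d["prompt"] + d["chosen"] + end_token for d in pairs]
# single batched tokenizer call, padded/truncated to a fixed length
chosen_token = tokenizer(chosen,
                         max_length=32,
                         padding="max_length",
                         truncation=True,
                         return_tensors="pt")
# per-sample access is then plain tensor indexing
input_ids_0 = chosen_token["input_ids"][0]
attention_mask_0 = chosen_token["attention_mask"][0]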

View File

@@ -13,44 +13,64 @@
# limitations under the License.
import copy
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence, Tuple
from typing import Dict, Sequence, Tuple
import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .conversation import default_conversation
from .utils import is_rank_0, jload
# The following is a template prompt for a 4-round conversation.
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""
# Please note that we only calculate loss on assistant's answer tokens.
logger = get_dist_logger()
IGNORE_INDEX = -100
DEFAULT_EOS_TOKEN = "</s>"
PROMPT_DICT = {
"prompt_input":
("Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
"prompt_no_input": ("Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"),
}
def _preprocess(sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
sequences_token = tokenizer(sequences,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
sources_token = tokenizer(sources,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()
pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
class SFTDataset(Dataset):
"""
Dataset for sft model
@@ -61,115 +81,31 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
def __init__(self,
dataset: Dict,
tokenizer: PreTrainedTokenizer,
max_length: int = 512
) -> None:
super().__init__()
self.input_ids = []
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
prompt_token = tokenizer(prompt,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
sources = [data["prompt"] for data in dataset]
targets = [
data["completion"] + tokenizer.eos_token
for data in tqdm(dataset, disable=not is_rank_0())
]
self.input_ids.append(prompt_token['input_ids'][0])
self.labels = copy.deepcopy(self.input_ids)
self.input_ids, self.labels, self.attention_mask = \
_preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = len(self.input_ids)
length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
max_length: int) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
max_length: int) -> Dict:
"""Preprocess the conversation data by tokenizing."""
conversations = []
intermediates = []
for source in sources:
header = f"{default_conversation.system}"
conversation, intermediate = _add_speaker_and_signal(header, source)
conversations.append(conversation)
intermediates.append(intermediate)
conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
assert len(targets) == len(intermediates)
for target, inters in zip(targets, intermediates):
mask = torch.zeros_like(target, dtype=torch.bool)
for inter in inters:
tokenized = _tokenize_fn(inter, tokenizer, max_length)
start_idx = tokenized["input_ids"][0].size(0) - 1
end_idx = tokenized["input_ids"][1].size(0)
mask[start_idx:end_idx] = True
target[~mask] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def _add_speaker_and_signal(header: str,
source: List[Dict],
get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
END_SIGNAL = DEFAULT_EOS_TOKEN
conversation = header
intermediate = []
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = default_conversation.roles[1]
else:
from_str = 'unknown'
value = from_str + ": " + sentence["value"] + END_SIGNAL
if sentence["from"].lower() == "gpt":
start = conversation + from_str + ": "
end = conversation + value
intermediate.append([start, end])
if get_conversation:
conversation += value
return conversation, intermediate
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
class SupervisedDataset(Dataset):
@@ -177,10 +113,10 @@ class SupervisedDataset(Dataset):
def __init__(self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
tokenizer: PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 512):
super(SupervisedDataset, self).__init__()
super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
@@ -190,52 +126,25 @@ class SupervisedDataset(Dataset):
list_data_dict = list_data_dict[:max_datasets_size]
logger.info("Formatting inputs...")
if "conversations" not in list_data_dict[0]:
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example)
if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [
example['output'] + tokenizer.eos_token
for example in list_data_dict
]
if is_rank_0():
logger.info("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer, max_length)
else:
if is_rank_0():
logger.info("Tokenizing inputs... This may take some time...")
sources = [conv["conversations"] for conv in list_data_dict]
data_dict = preprocess_conversation(sources, tokenizer, max_length)
if is_rank_0():
logger.info("Tokenizing finish.")
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
logger.info("Tokenizing inputs... This may take some time...")
self.input_ids, self.labels, self.attention_mask = \
_preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
return len(self.input_ids)
length = self.input_ids.shape[0]
return length
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
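A minimal sketch of the prompt-masking rule that _preprocess applies with IGNORE_INDEX, using toy tensors and hypothetical lengths; only the masking step is shown, not the tokenization.

import torch

IGNORE_INDEX = -100
max_length = 8

# suppose a padded sequence of prompt (3 tokens) + completion (3 tokens) + 2 pad tokens
input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 0, 0]])
source_len = 3    # number of prompt tokens (from the sources' attention mask)
pad_len = 2       # max_length minus the number of non-pad tokens

labels = input_ids.clone()
padding_side = "right"
if padding_side == "right":
    # |prompt|completion|eos|pad| -> mask the prompt at the front
    labels[0, :source_len] = IGNORE_INDEX
elif padding_side == "left":
    # |pad|prompt|completion|eos| -> mask the prompt right after the padding
    labels[0, pad_len:pad_len + source_len] = IGNORE_INDEX
# the loss is then computed only on completion tokens (labels != IGNORE_INDEX)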

View File

@@ -0,0 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']

View File

@@ -4,8 +4,8 @@ from typing import Any
from coati.experience_maker.base import Experience
class ReplayBuffer(ABC):
"""Replay buffer base class. It stores experience.
class ExperienceBuffer(ABC):
"""Experience buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.

View File

@@ -4,12 +4,12 @@ from typing import List
import torch
from coati.experience_maker.base import Experience
from .base import ReplayBuffer
from .base import ExperienceBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveReplayBuffer(ReplayBuffer):
"""Naive replay buffer class. It stores experience.
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.

View File

@@ -33,7 +33,8 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
keys = ('sequences', 'action_log_probs', 'values',
'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
@@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
return items
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
@@ -62,11 +63,12 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> tor
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
keys = ('sequences', 'action_log_probs', 'values',
'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
batch_data = zero_pad_sequences(vals)
batch_data = _zero_pad_sequences(vals)
else:
batch_data = torch.stack(vals, dim=0)
kwargs[key] = batch_data
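The body of _zero_pad_sequences is elided in this hunk; a plausible left-padding implementation, written from scratch purely for illustration (not the repository's actual body), would look like:

import torch
import torch.nn.functional as F
from typing import List

def zero_pad_left(sequences: List[torch.Tensor]) -> torch.Tensor:
    # pad every 1-D tensor with zeros on the left up to the longest length,
    # then stack the results into a single batch tensor
    max_len = max(seq.size(0) for seq in sequences)
    padded = [F.pad(seq, (max_len - seq.size(0), 0)) for seq in sequences]
    return torch.stack(padded, dim=0)

batch = zero_pad_left([torch.tensor([1, 2]), torch.tensor([3, 4, 5])])
# tensor([[0, 1, 2],
#         [3, 4, 5]])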

View File

@@ -1,6 +1,7 @@
import torch
from coati.models.generation import generate_with_actor
from coati.models.utils import calc_action_log_probs, compute_reward, normalize
import torch.nn.functional as F
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from .base import Experience, ExperienceMaker
@@ -17,10 +18,27 @@ class NaiveExperienceMaker(ExperienceMaker):
self.initial_model.eval()
self.reward_model.eval()
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
input_ids,
return_action_mask=True,
**generate_kwargs)
# generate sequences
sequences = generate(self.actor, input_ids, **generate_kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = generate_kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id)\
.to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
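A toy walk-through of the two auxiliary masks built above, with illustrative token ids (prompt length, eos id and pad id below are assumptions for the example only):

import torch

# toy generated batch: prompt length 2, continuation [7, 9, 0, 0], eos id 9, pad id 0
sequences = torch.tensor([[5, 6, 7, 9, 0, 0]])
input_len, eos_token_id, pad_token_id = 2, 9, 0

# attention mask marks every non-pad position
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long)

# core of the action mask: a continuation token counts as an "action" until the
# first eos; cumsum counts eos tokens seen so far, so "== 0" keeps the pre-eos part
pre_eos = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
print(pre_eos)   # tensor([[ True, False, False, False]])
# the version above additionally shifts this mask with F.pad so that the eos
# token itself is counted as the last action before trimming back to num_actions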

View File

@@ -1,8 +1,8 @@
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
]

View File

@@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -32,7 +30,6 @@ class BLOOMCritic(Critic):
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)

View File

@@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -30,8 +28,7 @@ class BLOOMRM(RewardModel):
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)

View File

@@ -1,9 +1,9 @@
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .base import Actor
try:
from transformers.generation_logits_process import (
@@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
def _prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
def sample(model: nn.Module,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def _sample(model: Actor,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
@@ -89,7 +89,8 @@ def sample(model: nn.Module,
return input_ids
def generate(model: nn.Module,
@torch.no_grad()
def generate(model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
@@ -128,51 +129,19 @@ def generate(model: nn.Module,
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
return sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
return _sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
elif is_beam_gen_mode:
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")
@torch.no_grad()
def generate_with_actor(
actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
sequences = generate(actor_model, input_ids, **kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

View File

@@ -14,7 +14,6 @@ class GPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the LoRA decomposition.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class GPTCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -32,7 +30,6 @@ class GPTCritic(Critic):
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)

View File

@@ -14,7 +14,6 @@ class GPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class GPTRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -31,8 +29,6 @@ class GPTRM(RewardModel):
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))

View File

@@ -13,7 +13,6 @@ class LlamaCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class LlamaCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -33,9 +31,5 @@ class LlamaCritic(Critic):
else:
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)

View File

@@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
@@ -32,8 +30,6 @@ class LlamaRM(RewardModel):
else:
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))

View File

@@ -98,18 +98,18 @@ class LoraLinear(lora.LoRALayer, nn.Module):
return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, lora_linear_wrapper(child, lora_rank))
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
else:
convert_to_lora_recursively(child, lora_rank)
_convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
@@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
"""
if lora_rank <= 0:
return module
convert_to_lora_recursively(module, lora_rank)
_convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module

View File

@@ -68,31 +68,6 @@ class ValueLoss(nn.Module):
return 0.5 * loss
class PPOPtxActorLoss(nn.Module):
"""
To Do:
PPO-ptx Actor Loss
"""
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
super().__init__()
self.pretrain_coef = pretrain_coef
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
self.pretrain_loss_fn = pretrain_loss_fn
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
lm_logits: torch.Tensor,
lm_input_ids: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
return policy_loss + self.pretrain_coef * lm_loss
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model

View File

@@ -14,7 +14,6 @@ class OPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -22,7 +21,6 @@ class OPTCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
@@ -32,7 +30,6 @@ class OPTCritic(Critic):
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)

View File

@@ -13,7 +13,6 @@ class OPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
@@ -21,7 +20,6 @@ class OPTRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
@@ -30,8 +28,6 @@ class OPTRM(RewardModel):
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))

View File

@@ -1,14 +1,12 @@
from typing import Optional, Union
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def _compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@@ -35,12 +33,12 @@ def compute_reward(r: Union[torch.Tensor, float],
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
reward = r - kl_coef * kl
return reward
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
@@ -58,7 +56,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
torch.Tensor: Action log probs.
"""
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
@@ -68,41 +66,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
tensor = tensor * mask
mean = masked_mean(tensor, mask, dim=dim)
mean_centered = tensor - mean
var = masked_mean(mean_centered**2, mask, dim=dim)
return mean_centered * var.clamp(min=eps).rsqrt()
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
mean = tensor.mean(dim)
mean_centered = tensor - mean
var = (mean_centered**2).mean(dim)
norm = mean_centered * var.clamp(min=eps).rsqrt()
return norm
def convert_to_lora(model: nn.Module,
input_size: int,
output_size: int,
lora_rank: int = 16,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
merge_weights: bool = True):
if lora_rank > min(input_size, output_size):
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
module._modules[name] = lora.Linear(input_size,
output_size,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
fan_in_fan_out=fan_in_fan_out,
merge_weights=merge_weights)
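For orientation, a toy rendering of how the KL-shaped reward in compute_reward is assembled. The per-token k1 estimate below (log_probs - log_probs_base, averaged per sequence) is an assumption standing in for the elided _compute_approx_kl body; values are made up.

import torch

log_probs = torch.tensor([[-1.0, -0.5, -2.0]])        # actor log-probs of taken actions
log_probs_base = torch.tensor([[-1.2, -0.7, -1.5]])   # initial-model log-probs
r, kl_coef = torch.tensor([1.0]), 0.1

# simplest (k1) KL estimate, see http://joschu.net/blog/kl-approx.html
approx_kl = (log_probs - log_probs_base).mean(dim=-1)
reward = r - kl_coef * approx_kl   # KL penalty pulls the reward toward the initial policy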

View File

@@ -115,12 +115,12 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
@@ -204,9 +204,9 @@ class TrainerPerformanceEvaluator(TrainerCallback):
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)

View File

@@ -6,9 +6,9 @@ from typing import Any, List
import ray
import torch
from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
from coati.replay_buffer import ReplayBuffer
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
# from torch.multiprocessing import Queue
from ray.util.queue import Queue

View File

@@ -4,8 +4,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import ray
import torch
from coati.experience_buffer.utils import BufferItem
from coati.experience_maker import Experience
from coati.replay_buffer.utils import BufferItem
from torch.utils.data import DataLoader
from tqdm import tqdm

View File

@@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import ray
import torch
import torch.nn as nn
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
@@ -19,13 +19,9 @@ from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .utils import (get_model_numel,
get_rank,
get_world_size,
is_rank_0,
set_dist_env,
state_dict_to)
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
@@ -41,7 +37,7 @@ class ExperienceMakerHolder:
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
# a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
@@ -205,15 +201,19 @@ class ExperienceMakerHolder:
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict)
self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict)
self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase)
# the lock must be released after both actor and critic being updated
if chunk_end:

View File

@@ -1,11 +1,11 @@
from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn as nn
from loralib.layers import LoRALayer
from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer
@dataclass
@@ -23,19 +23,19 @@ class LoRAConstructor:
Usage:
Step 1 (Sender):
filter_state_dict_lora()
Step 2 (Sender, Optional):
extract_lora_config()
Step 3 (Sender):
send state_dict_lora and lora_config_dict
Step 4 (Receiver):
reconstruct_increase()
Step 5 (Receiver):
load_state_dict_increase()
'''
def __init__(self):

View File

@@ -1,4 +0,0 @@
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']

View File

@@ -4,8 +4,8 @@ from typing import List
import torch.nn as nn
import tqdm
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from coati.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -62,7 +62,7 @@ class OnPolicyTrainer(ABC):
Args:
strategy (Strategy):the strategy to use for training
buffer (NaiveReplayBuffer): the buffer to collect experiences
data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
@@ -70,13 +70,13 @@ class OnPolicyTrainer(ABC):
def __init__(self,
strategy: Strategy,
buffer: NaiveReplayBuffer,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []) -> None:
super().__init__()
self.strategy = strategy
self.buffer = buffer
self.data_buffer = data_buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
@@ -144,7 +144,7 @@ class OnPolicyTrainer(ABC):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
self._on_make_experience_end(experience)
self.buffer.append(experience)
self.data_buffer.append(experience)
def _update_phase(self, update_step: int):
self._on_learn_epoch_start(update_step)
@@ -181,8 +181,8 @@ class OnPolicyTrainer(ABC):
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
# I only call strategy.setup_dataloader() to setup dataloader.
self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.buffer.clear()
self.data_buffer.clear()

View File

@@ -171,13 +171,13 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
f'Performance summary:\n' +
f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+
f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+
f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
f'Performance summary:\n'
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
)

View File

@@ -1,11 +1,11 @@
from typing import Dict, List
import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
@@ -86,9 +86,9 @@ class PPOTrainer(OnPolicyTrainer):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"
buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
strategy, buffer,
strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
@@ -170,7 +170,7 @@ class PPOTrainer(OnPolicyTrainer):
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.buffer.sample()
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)

View File

@@ -4,7 +4,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -45,7 +45,7 @@ class Strategy(ABC):
pass
@abstractmethod
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
pass
def model_init_context(self):

View File

@@ -4,7 +4,6 @@ from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
@@ -44,7 +43,7 @@ class LowLevelZeroStrategy(DDPStrategy):
"""
def __init__(self,
stage: int = 3,
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
@@ -214,14 +213,3 @@ class GeminiStrategy(DDPStrategy):
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(self.plugin, GeminiPlugin)
yield from super().get_model_state_dict_shard(model, **config)

View File

@@ -7,7 +7,8 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -71,13 +72,13 @@ class DDPStrategy(Strategy):
np.random.seed(seed)
torch.manual_seed(seed)
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(data_buffer,
batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
collate_fn=data_buffer.collate_fn)
def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
@@ -92,13 +93,33 @@ class DDPStrategy(Strategy):
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if not only_rank0 or dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
assert isinstance(pretrained_model, PreTrainedModel)
# HACK: only use hf save_pretrained to save config
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model,
model_path,
only_rank0=only_rank0)
def _replace_keys(model_path: str,
replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {
replace_fn(k): v
for k, v in state_dict.items()
}
torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy

View File

@@ -27,7 +27,6 @@ class DistributedSampler:
assert len(indices) == self.num_samples
self.indices = indices
def sample(self, batch_size: int) -> list:
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
return [self.dataset[idx] for idx in sampled_indices]

View File

@@ -21,9 +21,13 @@ class CycledDataLoader:
self.dataloader = dataloader
self.count = 0
self.dataloader_iter = iter(dataloader)
self.dataloader_iter = None
def next(self):
# defer initialization
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)
self.count += 1
try:
return next(self.dataloader_iter)
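The deferred-initialization pattern above, as a small self-contained sketch. The class name, the toy dataset, and the restart-on-StopIteration behaviour are illustrative assumptions, not the repository's full CycledDataLoader.

from torch.utils.data import DataLoader

class CycledLoaderSketch:
    def __init__(self, dataloader: DataLoader) -> None:
        self.dataloader = dataloader
        self.count = 0
        # defer creating the iterator until first use, so constructing the
        # wrapper does not yet touch the data or spawn any workers
        self.dataloader_iter = None

    def next(self):
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)
        self.count += 1
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            # assumed behaviour: restart from the beginning when exhausted
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)

loader = CycledLoaderSketch(DataLoader(list(range(4)), batch_size=2))
first = loader.next()   # tensor([0, 1]) with the default collate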