Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-21 09:29:47 +00:00
[chat] fix bugs and add unit tests (#4213)
* style: rename replay buffer

  Experience replay is typically used for off-policy algorithms, so using this name in PPO may be misleading.

* fix: fix wrong zero2 default arg
* test: update experience tests
* style: rename zero_pad fn
* fix: defer init in CycledDataLoader
* test: add benchmark test
* style: rename internal fn of generation
* style: rename internal fn of lora
* fix: remove unused loss fn
* fix: remove unused utils fn
* refactor: remove generate_with_actor fn
* fix: fix type annotation
* test: add models tests
* fix: skip llama due to long execution time
* style: modify dataset
* style: apply formatter
* perf: update reward dataset
* fix: fix wrong IGNORE_INDEX in sft dataset
* fix: remove DataCollatorForSupervisedDataset
* test: add dataset tests
* style: apply formatter
* style: rename test_ci to test_train
* feat: add llama in inference
* test: add inference tests
* test: change test scripts directory
* fix: update ci
* fix: fix typo
* fix: skip llama due to oom
* fix: fix file mod
* style: apply formatter
* refactor: remove duplicated llama_gptq
* style: apply formatter
* to: update rm test
* feat: add tokenizer arg
* feat: add download model script
* test: update train tests
* fix: modify gemini load and save pretrained
* test: update checkpoint io test
* to: modify nproc_per_node
* fix: do not remove existing dir
* fix: modify save path
* test: add random choice
* fix: fix sft path
* fix: enlarge nproc_per_node to avoid oom
* fix: add num_retry
* fix: make lora config of rm and critic consistent
* fix: add warning about lora weights
* fix: skip some gpt2 tests
* fix: remove grad ckpt in rm and critic due to errors
* refactor: directly use Actor in train_sft
* test: add more arguments
* fix: disable grad ckpt when using lora
* fix: fix save_pretrained and related tests
* test: enable zero2 tests
* revert: remove useless fn
* style: polish code
* test: modify test args
@@ -1,9 +1,10 @@
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0

__all__ = [
'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
'DataCollatorForSupervisedDataset', 'PromptDataset'
'RmStaticDataset', 'HhRlhfDataset',
'SFTDataset', 'SupervisedDataset',
'PromptDataset', 'is_rank_0',
]
@@ -1,20 +1,13 @@
|
||||
import copy
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, Sequence
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .utils import is_rank_0, jload
|
||||
|
||||
logger = get_dist_logger()
|
||||
from .utils import jload
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
@@ -27,12 +20,13 @@ class PromptDataset(Dataset):
|
||||
max_length: int = 96):
|
||||
super(PromptDataset, self).__init__()
|
||||
self.keyed_prompt = defaultdict(list)
|
||||
logger.info("Loading data...")
|
||||
self.logger = get_dist_logger()
|
||||
self.logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
self.logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
|
||||
if max_datasets_size is not None:
|
||||
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
|
||||
|
@@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||
super().__init__()
|
||||
self.chosen = []
|
||||
self.reject = []
|
||||
if special_token is None:
|
||||
self.end_token = tokenizer.eos_token
|
||||
else:
|
||||
self.end_token = special_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
prompt = data['prompt']
|
||||
self.end_token = tokenizer.eos_token \
|
||||
if special_token is None else special_token
|
||||
|
||||
chosen = prompt + data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
chosen = [
|
||||
data["prompt"] + data["chosen"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen = {
|
||||
"input_ids": chosen_token["input_ids"],
|
||||
"attention_mask": chosen_token["attention_mask"]
|
||||
}
|
||||
|
||||
reject = prompt + data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
reject = [
|
||||
data["prompt"] + data["rejected"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject = {
|
||||
"input_ids": reject_token["input_ids"],
|
||||
"attention_mask": reject_token["attention_mask"]
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
length = self.chosen["input_ids"].shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
||||
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
|
||||
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
|
||||
|
||||
|
||||
# Anthropic/hh-rlhf
|
||||
@@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||
super().__init__()
|
||||
self.chosen = []
|
||||
self.reject = []
|
||||
if special_token is None:
|
||||
self.end_token = tokenizer.eos_token
|
||||
else:
|
||||
self.end_token = special_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
chosen = data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
self.end_token = tokenizer.eos_token \
|
||||
if special_token is None else special_token
|
||||
|
||||
reject = data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
chosen = [
|
||||
data["chosen"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen = {
|
||||
"input_ids": chosen_token["input_ids"],
|
||||
"attention_mask": chosen_token["attention_mask"]
|
||||
}
|
||||
|
||||
reject = [
|
||||
data["rejected"] + self.end_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject = {
|
||||
"input_ids": reject_token["input_ids"],
|
||||
"attention_mask": reject_token["attention_mask"]
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
length = self.chosen["input_ids"].shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
||||
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
|
||||
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
|
||||
|
@@ -13,44 +13,64 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Sequence, Tuple
|
||||
from typing import Dict, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .conversation import default_conversation
|
||||
from .utils import is_rank_0, jload
|
||||
|
||||
# The following is a template prompt for a 4-round conversation.
|
||||
"""
|
||||
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
|
||||
|
||||
Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
|
||||
"""
|
||||
# Please note that we only calculate loss on assistant's answer tokens.
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
DEFAULT_EOS_TOKEN = "</s>"
|
||||
PROMPT_DICT = {
|
||||
"prompt_input":
|
||||
("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_no_input": ("Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Response:"),
|
||||
}
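As a quick illustration of how these templates are used, the sketch below fills one of them with `format_map`; the instruction/input values are made up and not part of the commit.

```python
# Minimal sketch: filling the Alpaca-style template with format_map.
prompt_input = ("Below is an instruction that describes a task, paired with an input that "
                "provides further context. Write a response that appropriately completes "
                "the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:")

example = {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}
print(prompt_input.format_map(example))
```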
|
||||
|
||||
|
||||
def _preprocess(sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
sequences = [s + t for s, t in zip(sources, targets)]
|
||||
sequences_token = tokenizer(sequences,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
sources_token = tokenizer(sources,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
|
||||
labels = copy.deepcopy(sequences_token["input_ids"])
|
||||
for i in range(labels.shape[0]):
|
||||
source_len = sources_token["attention_mask"][i].sum().item()
|
||||
pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
|
||||
if tokenizer.padding_side == "right":
|
||||
# |prompt|completion|eos|pad|
|
||||
labels[i][:source_len] = IGNORE_INDEX
|
||||
elif tokenizer.padding_side == "left":
|
||||
# |pad|prompt|completion|eos|
|
||||
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
|
||||
else:
|
||||
raise RuntimeError()
|
||||
|
||||
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
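The key point of `_preprocess` is that prompt tokens are overwritten with `IGNORE_INDEX` so the loss is only computed on completion tokens, with the offset depending on the tokenizer's padding side. A minimal self-contained sketch of the right-padding branch (toy token ids, not real tokenizer output):

```python
# Minimal sketch: masking prompt tokens with IGNORE_INDEX (right padding case).
import torch

IGNORE_INDEX = -100
input_ids = torch.tensor([[11, 12, 13, 21, 22, 2, 0, 0]])  # |prompt|completion|eos|pad|
source_len = 3                                             # prompt length in tokens

labels = input_ids.clone()
labels[0, :source_len] = IGNORE_INDEX  # ignore the prompt; for left padding the slice
                                       # would start at pad_len instead (see diff above)
print(labels)  # tensor([[-100, -100, -100, 21, 22, 2, 0, 0]])
```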
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
"""
|
||||
Dataset for sft model
|
||||
@@ -61,115 +81,31 @@ class SFTDataset(Dataset):
|
||||
max_length: max length of input
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
|
||||
def __init__(self,
|
||||
dataset: Dict,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int = 512
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.input_ids = []
|
||||
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
|
||||
prompt_token = tokenizer(prompt,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
sources = [data["prompt"] for data in dataset]
|
||||
targets = [
|
||||
data["completion"] + tokenizer.eos_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
|
||||
self.input_ids.append(prompt_token['input_ids'][0])
|
||||
self.labels = copy.deepcopy(self.input_ids)
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.input_ids)
|
||||
length = self.input_ids.shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
|
||||
|
||||
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int) -> Dict[str, torch.Tensor]:
|
||||
"""Tokenize a list of strings."""
|
||||
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
|
||||
input_ids = labels = tokenized_list["input_ids"]
|
||||
input_ids_lens = labels_lens = \
|
||||
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
input_ids_lens=input_ids_lens,
|
||||
labels_lens=labels_lens,
|
||||
)
|
||||
|
||||
|
||||
def preprocess(
|
||||
sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Dict:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
examples = [s + t for s, t in zip(sources, targets)]
|
||||
examples_tokenized, sources_tokenized = [
|
||||
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
|
||||
]
|
||||
input_ids = examples_tokenized["input_ids"]
|
||||
labels = copy.deepcopy(input_ids)
|
||||
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
||||
label[:source_len] = IGNORE_INDEX
|
||||
return dict(input_ids=input_ids, labels=labels)
|
||||
|
||||
|
||||
def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int) -> Dict:
|
||||
"""Preprocess the conversation data by tokenizing."""
|
||||
conversations = []
|
||||
intermediates = []
|
||||
for source in sources:
|
||||
header = f"{default_conversation.system}"
|
||||
conversation, intermediate = _add_speaker_and_signal(header, source)
|
||||
conversations.append(conversation)
|
||||
intermediates.append(intermediate)
|
||||
|
||||
conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
|
||||
input_ids = conversations_tokenized["input_ids"]
|
||||
targets = copy.deepcopy(input_ids)
|
||||
|
||||
assert len(targets) == len(intermediates)
|
||||
for target, inters in zip(targets, intermediates):
|
||||
mask = torch.zeros_like(target, dtype=torch.bool)
|
||||
for inter in inters:
|
||||
tokenized = _tokenize_fn(inter, tokenizer, max_length)
|
||||
|
||||
start_idx = tokenized["input_ids"][0].size(0) - 1
|
||||
end_idx = tokenized["input_ids"][1].size(0)
|
||||
|
||||
mask[start_idx:end_idx] = True
|
||||
target[~mask] = IGNORE_INDEX
|
||||
|
||||
return dict(input_ids=input_ids, labels=targets)
|
||||
|
||||
|
||||
def _add_speaker_and_signal(header: str,
|
||||
source: List[Dict],
|
||||
get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
|
||||
END_SIGNAL = DEFAULT_EOS_TOKEN
|
||||
conversation = header
|
||||
intermediate = []
|
||||
for sentence in source:
|
||||
from_str = sentence["from"]
|
||||
if from_str.lower() == "human":
|
||||
from_str = default_conversation.roles[0]
|
||||
elif from_str.lower() == "gpt":
|
||||
from_str = default_conversation.roles[1]
|
||||
else:
|
||||
from_str = 'unknown'
|
||||
|
||||
value = from_str + ": " + sentence["value"] + END_SIGNAL
|
||||
if sentence["from"].lower() == "gpt":
|
||||
start = conversation + from_str + ": "
|
||||
end = conversation + value
|
||||
intermediate.append([start, end])
|
||||
if get_conversation:
|
||||
conversation += value
|
||||
return conversation, intermediate
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
@@ -177,10 +113,10 @@ class SupervisedDataset(Dataset):
|
||||
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_datasets_size: int = None,
|
||||
max_length: int = 512):
|
||||
super(SupervisedDataset, self).__init__()
|
||||
super().__init__()
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
@@ -190,52 +126,25 @@ class SupervisedDataset(Dataset):
|
||||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
logger.info("Formatting inputs...")
|
||||
if "conversations" not in list_data_dict[0]:
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
sources = [
|
||||
prompt_input.format_map(example)
|
||||
if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
|
||||
]
|
||||
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
sources = [
|
||||
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
|
||||
for example in list_data_dict
|
||||
]
|
||||
targets = [
|
||||
example['output'] + tokenizer.eos_token
|
||||
for example in list_data_dict
|
||||
]
|
||||
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
|
||||
data_dict = preprocess(sources, targets, tokenizer, max_length)
|
||||
else:
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
|
||||
sources = [conv["conversations"] for conv in list_data_dict]
|
||||
data_dict = preprocess_conversation(sources, tokenizer, max_length)
|
||||
|
||||
if is_rank_0():
|
||||
logger.info("Tokenizing finish.")
|
||||
|
||||
self.input_ids = data_dict["input_ids"]
|
||||
self.labels = data_dict["labels"]
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.input_ids)
|
||||
length = self.input_ids.shape[0]
|
||||
return length
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSupervisedDataset(object):
|
||||
"""Collate examples for supervised fine-tuning."""
|
||||
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
|
||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
||||
)
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
|
applications/Chat/coati/experience_buffer/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer

__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
@@ -4,8 +4,8 @@ from typing import Any
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
|
||||
class ReplayBuffer(ABC):
|
||||
"""Replay buffer base class. It stores experience.
|
||||
class ExperienceBuffer(ABC):
|
||||
"""Experience buffer base class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
@@ -4,12 +4,12 @@ from typing import List
|
||||
import torch
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
from .base import ReplayBuffer
|
||||
from .base import ExperienceBuffer
|
||||
from .utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
|
||||
|
||||
class NaiveReplayBuffer(ReplayBuffer):
|
||||
"""Naive replay buffer class. It stores experience.
|
||||
class NaiveExperienceBuffer(ExperienceBuffer):
|
||||
"""Naive experience buffer class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
@@ -33,7 +33,8 @@ class BufferItem:
|
||||
def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
||||
batch_size = experience.sequences.size(0)
|
||||
batch_kwargs = [{} for _ in range(batch_size)]
|
||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
keys = ('sequences', 'action_log_probs', 'values',
|
||||
'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
for key in keys:
|
||||
value = getattr(experience, key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
@@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
||||
return items
|
||||
|
||||
|
||||
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
||||
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
||||
assert side in ('left', 'right')
|
||||
max_len = max(seq.size(0) for seq in sequences)
|
||||
padded_sequences = []
|
||||
@@ -62,11 +63,12 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> tor
|
||||
def make_experience_batch(items: List[BufferItem]) -> Experience:
|
||||
kwargs = {}
|
||||
to_pad_keys = set(('action_log_probs', 'action_mask'))
|
||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
keys = ('sequences', 'action_log_probs', 'values',
|
||||
'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
for key in keys:
|
||||
vals = [getattr(item, key) for item in items]
|
||||
if key in to_pad_keys:
|
||||
batch_data = zero_pad_sequences(vals)
|
||||
batch_data = _zero_pad_sequences(vals)
|
||||
else:
|
||||
batch_data = torch.stack(vals, dim=0)
|
||||
kwargs[key] = batch_data
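The renamed `_zero_pad_sequences` helper pads variable-length tensors on one side before stacking. A minimal sketch of the left-padding case (standalone toy example, not the library code):

```python
# Minimal sketch: left-side zero padding of variable-length 1-D tensors.
import torch
import torch.nn.functional as F

def zero_pad_left(seqs):
    max_len = max(s.size(0) for s in seqs)
    return torch.stack([F.pad(s, (max_len - s.size(0), 0)) for s in seqs], dim=0)

batch = zero_pad_left([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
print(batch)  # tensor([[1, 2, 3], [0, 4, 5]])
```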
|
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from coati.models.generation import generate_with_actor
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward, normalize
|
||||
import torch.nn.functional as F
|
||||
from coati.models.generation import generate
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward
|
||||
|
||||
from .base import Experience, ExperienceMaker
|
||||
|
||||
@@ -17,10 +18,27 @@ class NaiveExperienceMaker(ExperienceMaker):
|
||||
self.initial_model.eval()
|
||||
self.reward_model.eval()
|
||||
|
||||
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
|
||||
input_ids,
|
||||
return_action_mask=True,
|
||||
**generate_kwargs)
|
||||
# generate sequences
|
||||
sequences = generate(self.actor, input_ids, **generate_kwargs)
|
||||
|
||||
# calculate auxiliary tensors
|
||||
attention_mask = None
|
||||
pad_token_id = generate_kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id)\
|
||||
.to(dtype=torch.long, device=sequences.device)
|
||||
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = generate_kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
|
||||
num_actions = action_mask.size(1)
|
||||
|
||||
actor_output = self.actor(sequences, attention_mask)
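The action mask inlined above marks which generated positions (up to and including the first eos) count as actions. A toy walk-through of the same cumsum/pad arithmetic, with made-up token ids (2 standing in for `eos_token_id`):

```python
# Minimal sketch: eos-based action mask, mirroring the logic in the diff above.
import torch
import torch.nn.functional as F

input_len = 3
eos_token_id = 2
sequences = torch.tensor([[5, 6, 7, 11, 12, 2, 0, 0]])  # 3 prompt tokens + generated tokens

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
print(action_mask)  # tensor([[ True,  True,  True, False, False]]) -> 11, 12, eos are actions
```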
|
||||
|
@@ -1,8 +1,8 @@
|
||||
from .base import Actor, Critic, RewardModel
|
||||
from .lora import LoRAModule, convert_to_lora_module
|
||||
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
|
||||
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||
|
||||
__all__ = [
|
||||
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
|
||||
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
|
||||
'LoRAModule', 'convert_to_lora_module'
|
||||
]
|
||||
|
@@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
@@ -32,7 +30,6 @@ class BLOOMCritic(Critic):
|
||||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
@@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
@@ -30,8 +28,7 @@ class BLOOMRM(RewardModel):
|
||||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
||||
|
@@ -1,9 +1,9 @@
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import Actor
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
@@ -16,9 +16,9 @@ except ImportError:
|
||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
|
||||
def prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
def _prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
processor_list = LogitsProcessorList()
|
||||
if temperature is not None and temperature != 1.0:
|
||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||
@@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
|
||||
return unfinished_sequences.max() == 0
|
||||
|
||||
|
||||
def sample(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
def _sample(model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
if input_ids.size(1) >= max_length:
|
||||
return input_ids
|
||||
|
||||
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
|
||||
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(input_ids.size(1), max_length):
|
||||
@@ -89,7 +89,8 @@ def sample(model: nn.Module,
|
||||
return input_ids
|
||||
|
||||
|
||||
def generate(model: nn.Module,
|
||||
@torch.no_grad()
|
||||
def generate(model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
num_beams: int = 1,
|
||||
@@ -128,51 +129,19 @@ def generate(model: nn.Module,
|
||||
raise NotImplementedError
|
||||
elif is_sample_gen_mode:
|
||||
# run sample
|
||||
return sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
return _sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
elif is_beam_gen_mode:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("Unsupported generation mode")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_with_actor(
|
||||
actor_model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
"""Generate token sequence with actor model. Refer to `generate` for more details.
|
||||
"""
|
||||
# generate sequences
|
||||
sequences = generate(actor_model, input_ids, **kwargs)
|
||||
|
||||
# calculate auxiliary tensors
|
||||
attention_mask = None
|
||||
pad_token_id = kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
||||
|
@@ -14,7 +14,6 @@ class GPTCritic(Critic):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -22,7 +21,6 @@ class GPTCritic(Critic):
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
@@ -32,7 +30,6 @@ class GPTCritic(Critic):
|
||||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
@@ -14,7 +14,6 @@ class GPTRM(RewardModel):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -22,7 +21,6 @@ class GPTRM(RewardModel):
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
@@ -31,8 +29,6 @@ class GPTRM(RewardModel):
|
||||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
|
||||
|
@@ -13,7 +13,6 @@ class LlamaCritic(Critic):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -21,7 +20,6 @@ class LlamaCritic(Critic):
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
@@ -33,9 +31,5 @@ class LlamaCritic(Critic):
|
||||
else:
|
||||
model = LlamaModel(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
@@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
@@ -32,8 +30,6 @@ class LlamaRM(RewardModel):
|
||||
else:
|
||||
model = LlamaModel(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
|
||||
|
@@ -98,18 +98,18 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
|
||||
|
||||
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
||||
return lora_linear
|
||||
|
||||
|
||||
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
setattr(module, name, lora_linear_wrapper(child, lora_rank))
|
||||
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
|
||||
else:
|
||||
convert_to_lora_recursively(child, lora_rank)
|
||||
_convert_to_lora_recursively(child, lora_rank)
|
||||
|
||||
|
||||
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
|
||||
@@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
|
||||
"""
|
||||
if lora_rank <= 0:
|
||||
return module
|
||||
convert_to_lora_recursively(module, lora_rank)
|
||||
_convert_to_lora_recursively(module, lora_rank)
|
||||
lora.mark_only_lora_as_trainable(module, lora_train_bias)
|
||||
return module
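The recursive replacement pattern used by `_convert_to_lora_recursively` generalizes to any module tree. A minimal sketch with a stand-in wrapper class (hypothetical `LinearWithAdapter`, not the real `LoraLinear`):

```python
# Minimal sketch: recursively swapping nn.Linear layers, same traversal as above.
import torch.nn as nn

class LinearWithAdapter(nn.Module):  # stand-in for LoraLinear
    def __init__(self, linear: nn.Linear, rank: int):
        super().__init__()
        self.linear, self.rank = linear, rank

    def forward(self, x):
        return self.linear(x)

def convert_recursively(module: nn.Module, rank: int) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithAdapter(child, rank))
        else:
            convert_recursively(child, rank)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
convert_recursively(model, rank=4)
print(model)  # both Linear layers are now wrapped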
|
||||
|
||||
|
@@ -68,31 +68,6 @@ class ValueLoss(nn.Module):
|
||||
return 0.5 * loss
|
||||
|
||||
|
||||
class PPOPtxActorLoss(nn.Module):
|
||||
"""
|
||||
To Do:
|
||||
|
||||
PPO-ptx Actor Loss
|
||||
"""
|
||||
|
||||
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
|
||||
super().__init__()
|
||||
self.pretrain_coef = pretrain_coef
|
||||
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
|
||||
self.pretrain_loss_fn = pretrain_loss_fn
|
||||
|
||||
def forward(self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
lm_logits: torch.Tensor,
|
||||
lm_input_ids: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
|
||||
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
|
||||
return policy_loss + self.pretrain_coef * lm_loss
|
||||
|
||||
|
||||
class LogSigLoss(nn.Module):
|
||||
"""
|
||||
Pairwise Loss for Reward Model
|
||||
|
@@ -14,7 +14,6 @@ class OPTCritic(Critic):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -22,7 +21,6 @@ class OPTCritic(Critic):
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
@@ -32,7 +30,6 @@ class OPTCritic(Critic):
|
||||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
||||
|
@@ -13,7 +13,6 @@ class OPTRM(RewardModel):
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
@@ -21,7 +20,6 @@ class OPTRM(RewardModel):
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
@@ -30,8 +28,6 @@ class OPTRM(RewardModel):
|
||||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
|
||||
|
@@ -1,14 +1,12 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_approx_kl(log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def _compute_approx_kl(log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Compute the approximate KL divergence between two distributions.
|
||||
Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||
@@ -35,12 +33,12 @@ def compute_reward(r: Union[torch.Tensor, float],
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if kl_coef <= 0.0:
|
||||
return r
|
||||
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
reward = r - kl_coef * kl
|
||||
return reward
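`compute_reward` subtracts a KL penalty from the raw reward so the policy is discouraged from drifting away from the reference model. A toy sketch using the simple `log_probs - log_probs_base` (k1) estimator referenced by the Schulman blog link above; the numbers are made up and the real helper also applies the action mask:

```python
# Minimal sketch: KL-penalised reward r - kl_coef * approx_KL.
import torch

log_probs = torch.tensor([[-0.5, -1.0, -0.2]])       # actor log-probs of taken actions
log_probs_base = torch.tensor([[-0.6, -0.9, -0.4]])  # reference (initial) model log-probs
r, kl_coef = torch.tensor([1.0]), 0.1

approx_kl = (log_probs - log_probs_base).mean(dim=1)  # per-sequence estimate
reward = r - kl_coef * approx_kl
print(reward)  # tensor([0.9933]) -- reward shrinks as the policy drifts from the reference
```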
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return log_probs_labels.squeeze(-1)
|
||||
@@ -58,7 +56,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
|
||||
torch.Tensor: Action log probs.
|
||||
"""
|
||||
logits = output['logits']
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
|
||||
@@ -68,41 +66,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
|
||||
mask_sum = mask.sum(dim=dim)
|
||||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
||||
|
||||
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
mean = masked_mean(tensor, mask, dim=dim)
|
||||
mean_centered = tensor - mean
|
||||
var = masked_mean(mean_centered**2, mask, dim=dim)
|
||||
return mean_centered * var.clamp(min=eps).rsqrt()
|
||||
|
||||
|
||||
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
|
||||
mean = tensor.mean(dim)
|
||||
mean_centered = tensor - mean
|
||||
var = (mean_centered**2).mean(dim)
|
||||
norm = mean_centered * var.clamp(min=eps).rsqrt()
|
||||
return norm
|
||||
|
||||
|
||||
def convert_to_lora(model: nn.Module,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
lora_rank: int = 16,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True):
|
||||
if lora_rank > min(input_size, output_size):
|
||||
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
module._modules[name] = lora.Linear(input_size,
|
||||
output_size,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
merge_weights=merge_weights)
|
||||
|
@@ -115,12 +115,12 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
|
||||
f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
|
||||
f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
|
||||
f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
||||
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
|
||||
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
||||
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
)
|
||||
|
||||
|
||||
@@ -204,9 +204,9 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
|
||||
f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
|
||||
f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
+
|
||||
f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
||||
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
||||
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
)
|
||||
|
@@ -6,9 +6,9 @@ from typing import Any, List
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer import ExperienceBuffer
|
||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.experience_maker.base import Experience
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
# from torch.multiprocessing import Queue
|
||||
from ray.util.queue import Queue
|
||||
|
||||
|
@@ -4,8 +4,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer.utils import BufferItem
|
||||
from coati.experience_maker import Experience
|
||||
from coati.replay_buffer.utils import BufferItem
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
import ray
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic, RewardModel
|
||||
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.trainer.callbacks import Callback
|
||||
from coati.trainer.strategies import Strategy
|
||||
from coati.trainer.strategies.sampler import DistributedSampler
|
||||
@@ -19,13 +19,9 @@ from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
|
||||
from .utils import (get_model_numel,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
is_rank_0,
|
||||
set_dist_env,
|
||||
state_dict_to)
|
||||
from .lora_constructor import LoRAConstructor
|
||||
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
|
||||
|
||||
|
||||
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
|
||||
class ExperienceMakerHolder:
|
||||
@@ -41,7 +37,7 @@ class ExperienceMakerHolder:
|
||||
self,
|
||||
detached_trainer_name_list: List[str],
|
||||
strategy_fn: Callable[[], Strategy],
|
||||
# a function returns (actor, critic, reward_model, initial_model)
|
||||
# a function returns (actor, critic, reward_model, initial_model)
|
||||
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
|
||||
env_info: Dict[str, str] = None,
|
||||
sync_models_from_trainers: bool = False,
|
||||
@@ -205,15 +201,19 @@ class ExperienceMakerHolder:
|
||||
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
|
||||
else:
|
||||
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
|
||||
self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
|
||||
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
|
||||
new_actor_state_dict, new_actor_lora_config_dict)
|
||||
self.actor_lora_constructor.load_state_dict_increase(
|
||||
self.experience_maker.actor.model, state_dict_increase)
|
||||
if new_critic_state_dict is not None:
|
||||
if not self._update_lora_weights or fully_update:
|
||||
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
|
||||
else:
|
||||
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
|
||||
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
|
||||
self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
|
||||
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
|
||||
new_critic_state_dict, new_critic_lora_config_dict)
|
||||
self.critic_lora_constructor.load_state_dict_increase(
|
||||
self.experience_maker.critic, state_dict_increase)
|
||||
|
||||
# the lock must be released after both actor and critic being updated
|
||||
if chunk_end:
|
||||
|
@@ -1,11 +1,11 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loralib.layers import LoRALayer
|
||||
from coati.models.lora import LoraLinear
|
||||
from loralib.layers import LoRALayer
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,19 +23,19 @@ class LoRAConstructor:
|
||||
Usage:
|
||||
Step 1 (Sender):
|
||||
filter_state_dict_lora()
|
||||
|
||||
|
||||
Step 2 (Sender, Optional):
|
||||
extract_lora_config()
|
||||
|
||||
|
||||
Step 3 (Sender):
|
||||
send state_dict_lora and lora_config_dict
|
||||
|
||||
|
||||
Step 4 (Receiver):
|
||||
reconstruct_increase()
|
||||
|
||||
|
||||
Step 5 (Receiver):
|
||||
load_state_dict_increase()
|
||||
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
|
@@ -1,4 +0,0 @@
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer

__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
@@ -4,8 +4,8 @@ from typing import List
|
||||
|
||||
import torch.nn as nn
|
||||
import tqdm
|
||||
from coati.experience_buffer import NaiveExperienceBuffer
|
||||
from coati.experience_maker import Experience
|
||||
from coati.replay_buffer import NaiveReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@@ -62,7 +62,7 @@ class OnPolicyTrainer(ABC):
|
||||
|
||||
Args:
|
||||
strategy (Strategy):the strategy to use for training
|
||||
buffer (NaiveReplayBuffer): the buffer to collect experiences
|
||||
data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
|
||||
sample_buffer (bool, defaults to False): whether to sample from buffer
|
||||
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
@@ -70,13 +70,13 @@ class OnPolicyTrainer(ABC):
|
||||
|
||||
def __init__(self,
|
||||
strategy: Strategy,
|
||||
buffer: NaiveReplayBuffer,
|
||||
data_buffer: NaiveExperienceBuffer,
|
||||
sample_buffer: bool,
|
||||
dataloader_pin_memory: bool,
|
||||
callbacks: List[Callback] = []) -> None:
|
||||
super().__init__()
|
||||
self.strategy = strategy
|
||||
self.buffer = buffer
|
||||
self.data_buffer = data_buffer
|
||||
self.sample_buffer = sample_buffer
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
@@ -144,7 +144,7 @@ class OnPolicyTrainer(ABC):
|
||||
self._on_make_experience_start()
|
||||
experience = self._make_experience(collect_step)
|
||||
self._on_make_experience_end(experience)
|
||||
self.buffer.append(experience)
|
||||
self.data_buffer.append(experience)
|
||||
|
||||
def _update_phase(self, update_step: int):
|
||||
self._on_learn_epoch_start(update_step)
|
||||
@@ -181,8 +181,8 @@ class OnPolicyTrainer(ABC):
|
||||
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
|
||||
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
|
||||
# I only call strategy.setup_dataloader() to setup dataloader.
|
||||
self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
|
||||
self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
|
||||
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
|
||||
self._update_phase(update_step)
|
||||
# NOTE: this is for on-policy algorithms
|
||||
self.buffer.clear()
|
||||
self.data_buffer.clear()
|
||||
|
@@ -171,13 +171,13 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)

print_rank_0(
f'Performance summary:\n' +
f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+
f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+
f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
f'Performance summary:\n'
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
)
@@ -1,11 +1,11 @@
from typing import Dict, List

import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
@@ -86,9 +86,9 @@ class PPOTrainer(OnPolicyTrainer):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"

buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
strategy, buffer,
strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
@@ -170,7 +170,7 @@ class PPOTrainer(OnPolicyTrainer):

# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.buffer.sample()
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self._training_step(experience)
@@ -4,7 +4,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -45,7 +45,7 @@ class Strategy(ABC):
pass

@abstractmethod
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
pass

def model_init_context(self):
@@ -4,7 +4,6 @@ from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
@@ -44,7 +43,7 @@ class LowLevelZeroStrategy(DDPStrategy):
"""

def __init__(self,
stage: int = 3,
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
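
LowLevelZeroStrategy is the ZeRO-2 strategy, so a default of stage=3 contradicted its name and would have requested stage-3 partitioning for callers relying on the default. After this change a plain instantiation matches the intent; the snippet below uses only the constructor arguments visible in this hunk, and the import path is an assumption.

```python
# Assumes the import path used elsewhere in the Chat application; only the
# constructor arguments visible in the hunk above are shown.
from coati.trainer.strategies import LowLevelZeroStrategy

strategy = LowLevelZeroStrategy()                                # stage now defaults to 2
explicit = LowLevelZeroStrategy(stage=2, precision='fp16', seed=42)
```
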
@@ -214,14 +213,3 @@ class GeminiStrategy(DDPStrategy):
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module

def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')

def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(self.plugin, GeminiPlugin)
yield from super().get_model_state_dict_shard(model, **config)
@@ -7,7 +7,8 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -71,13 +72,13 @@ class DDPStrategy(Strategy):
np.random.seed(seed)
torch.manual_seed(seed)

def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(data_buffer,
batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
collate_fn=data_buffer.collate_fn)

def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
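
setup_dataloader() treats the experience buffer as a map-style dataset and only relies on two extra attributes, `sample_batch_size` and `collate_fn`. The toy buffer below shows a minimal interface that would satisfy the call above; it is an illustration, not coati's NaiveExperienceBuffer.

```python
# Toy sketch of the interface setup_dataloader() relies on: the experience
# buffer is consumed as a map-style dataset exposing sample_batch_size and
# collate_fn. Not coati's NaiveExperienceBuffer.
from torch.utils.data import DataLoader

class TinyExperienceBuffer:
    def __init__(self, sample_batch_size: int):
        self.sample_batch_size = sample_batch_size
        self.items = []

    def append(self, experience):
        self.items.append(experience)

    def clear(self):
        self.items.clear()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

    @staticmethod
    def collate_fn(batch):
        # the real buffer stacks tensors field by field; keep it trivial here
        return batch

buffer = TinyExperienceBuffer(sample_batch_size=4)
loader = DataLoader(buffer,
                    batch_size=buffer.sample_batch_size,
                    shuffle=True,
                    drop_last=True,
                    collate_fn=buffer.collate_fn)
```
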
@@ -92,13 +93,33 @@ class DDPStrategy(Strategy):
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if not only_rank0 or dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
assert isinstance(pretrained_model, PreTrainedModel)
# HACK: only use hf save_pretrained to save config
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model,
model_path,
only_rank0=only_rank0)

def _replace_keys(model_path: str,
replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {
replace_fn(k): v
for k, v in state_dict.items()
}
torch.save(state_dict, model_path)

# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))

def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
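
The _replace_keys hack exists because the wrapper modules (Actor, Critic, RewardModel) hold the Hugging Face model under a `model` attribute, so save_model() writes keys with a leading "model." prefix that a plain from_pretrained() load would not match. A standalone illustration of the renaming step follows; the wrapper class and the output file name here are illustrative only.

```python
# Standalone illustration of stripping the leading "model." prefix from
# checkpoint keys, mirroring the _replace_keys() hack above.
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 4)   # stands in for the wrapped HF model

wrapped_sd = Wrapper().state_dict()
print(list(wrapped_sd))                      # ['model.weight', 'model.bias']

# strip the first "model." so keys match the unwrapped pretrained model
fixed_sd = {k.replace("model.", "", 1): v for k, v in wrapped_sd.items()}
print(list(fixed_sd))                        # ['weight', 'bias']

torch.save(fixed_sd, "pytorch_model.bin")    # hypothetical output path
```
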
@@ -27,7 +27,6 @@ class DistributedSampler:
assert len(indices) == self.num_samples
self.indices = indices


def sample(self, batch_size: int) -> list:
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
return [self.dataset[idx] for idx in sampled_indices]
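
For reference, np.random.choice(..., replace=False) draws batch_size distinct indices, so a sampled batch never contains duplicates; a quick standalone check:

```python
# Quick standalone check of the sampling call used above.
import numpy as np

indices = np.arange(10)
batch = np.random.choice(indices, 4, replace=False)
assert len(set(batch.tolist())) == 4   # no duplicate indices within a batch
print(batch)
```
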
@@ -21,9 +21,13 @@ class CycledDataLoader:
self.dataloader = dataloader

self.count = 0
self.dataloader_iter = iter(dataloader)
self.dataloader_iter = None

def next(self):
# defer initialization
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)

self.count += 1
try:
return next(self.dataloader_iter)
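
The hunk ends inside the try block, so the wrap-around branch is not shown. The sketch below reproduces the deferred-initialization pattern and assumes the usual cycling behaviour on StopIteration; treat the except branch as an assumption rather than a quote of the real class.

```python
# Minimal sketch of a cycled dataloader with the deferred-iterator pattern
# shown above. The except branch is an assumption: the hunk ends at the try
# block, but restarting on StopIteration is the usual way to cycle.
from torch.utils.data import DataLoader

class CycledDataLoaderSketch:
    def __init__(self, dataloader: DataLoader) -> None:
        self.dataloader = dataloader
        self.count = 0
        # defer iter(dataloader) until first use, so constructing this wrapper
        # has no side effects on the dataloader's sampling state
        self.dataloader_iter = None

    def next(self):
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)
        self.count += 1
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            # assumed behaviour: wrap around and start a new pass over the data
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)
```
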