[chat]: update rm, add wandb and fix bugs (#4471)

* feat: modify forward fn of critic and reward model

* feat: modify calc_action_log_probs

* to: add wandb in sft and rm trainer

* feat: update train_sft

* feat: update train_rm

* style: modify type annotation and add warning

* feat: pass tokenizer to ppo trainer

* to: modify trainer base and maker base

* feat: add wandb in ppo trainer

* feat: pass tokenizer to generate

* test: update generate fn tests

* test: update train tests

* fix: remove action_mask

* feat: remove unused code

* fix: fix wrong ignore_index

* fix: fix mock tokenizer

* chore: update requirements

* revert: modify make_experience

* fix: fix inference

* fix: add padding side

* style: modify _on_learn_batch_end

* test: use mock tokenizer

* fix: use bf16 to avoid overflow

* fix: fix workflow

* [chat] fix gemini strategy

* [chat] fix

* sync: update colossalai strategy

* fix: fix args and model dtype

* fix: fix checkpoint test

* fix: fix requirements

* fix: fix missing import and wrong arg

* fix: temporarily skip gemini test in stage 3

* style: apply pre-commit

* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
Author: Wenhao Chen
Date: 2023-09-20 15:53:58 +08:00 (committed by GitHub)
Commit: 7b9b86441f (parent 07c2e3d09c)
36 changed files with 382 additions and 332 deletions


@@ -13,7 +13,7 @@
 # limitations under the License.
 import copy
-from typing import Dict, Sequence, Tuple
+from typing import Dict, Optional, Sequence, Tuple
 
 import torch
 from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@@ -57,6 +57,7 @@ def _preprocess(
         sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
     )
 
+    assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
     labels = copy.deepcopy(sequences_token["input_ids"])
     for i in range(labels.shape[0]):
         source_len = sources_token["attention_mask"][i].sum().item()
@@ -64,9 +65,10 @@ def _preprocess(
         if tokenizer.padding_side == "right":
             # |prompt|completion|eos|pad|
             labels[i][:source_len] = IGNORE_INDEX
+            labels[i][-pad_len:] = IGNORE_INDEX
         elif tokenizer.padding_side == "left":
             # |pad|prompt|completion|eos|
-            labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
+            labels[i][: pad_len + source_len] = IGNORE_INDEX
         else:
             raise RuntimeError()
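The hunk above is the substance of the "fix: fix wrong ignore_index" and "fix: add padding side" items in the commit message: prompt tokens and pad tokens are both excluded from the SFT loss, for either padding side. Below is a minimal standalone sketch of that masking rule, not the repository code; `IGNORE_INDEX = -100` is assumed here, and the `pad_len == 0` guard is an addition for this sketch (a `[-0:]` slice would otherwise cover the whole row), not something the diff shows.

```python
# Standalone illustration (not coati code) of masking prompt and pad tokens in SFT labels.
import torch

IGNORE_INDEX = -100  # assumption: sentinel value that the loss function ignores


def mask_labels(input_ids: torch.Tensor, source_len: int, pad_len: int, padding_side: str) -> torch.Tensor:
    """Return a copy of input_ids with prompt and padding positions set to IGNORE_INDEX."""
    labels = input_ids.clone()
    if padding_side == "right":
        # |prompt|completion|eos|pad|
        labels[:source_len] = IGNORE_INDEX  # mask the prompt
        if pad_len > 0:  # guard added for this sketch only
            labels[-pad_len:] = IGNORE_INDEX  # mask trailing padding
    elif padding_side == "left":
        # |pad|prompt|completion|eos|
        labels[: pad_len + source_len] = IGNORE_INDEX  # mask padding and prompt in one slice
    else:
        raise RuntimeError(f"unsupported padding side: {padding_side}")
    return labels


# Example: 3 prompt tokens, 4 completion tokens, 2 pad tokens (token ids are made up).
right_padded = torch.tensor([11, 12, 13, 21, 22, 23, 24, 0, 0])
left_padded = torch.tensor([0, 0, 11, 12, 13, 21, 22, 23, 24])
print(mask_labels(right_padded, source_len=3, pad_len=2, padding_side="right"))
# tensor([-100, -100, -100,   21,   22,   23,   24, -100, -100])
print(mask_labels(left_padded, source_len=3, pad_len=2, padding_side="left"))
# tensor([-100, -100, -100, -100, -100,   21,   22,   23,   24])
```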
@@ -126,6 +128,8 @@ class SFTDataset(Dataset):
         sources = [data["prompt"] for data in dataset]
         targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
+
         logger.info("Tokenizing inputs... This may take some time...")
+
         if isinstance(tokenizer, ChatGLMTokenizer):
             self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                 sources, targets, tokenizer, max_length
@@ -133,6 +137,8 @@ class SFTDataset(Dataset):
         else:
             self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
+
         logger.info("Loaded dataset.")
+
     def __len__(self):
         length = self.input_ids.shape[0]
         return length
@@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
 
     def __init__(
-        self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
+        self,
+        data_path: str,
+        tokenizer: PreTrainedTokenizer,
+        max_datasets_size: Optional[int] = None,
+        max_length: int = 512,
     ):
         super().__init__()
         logger.info("Loading data...")
@@ -175,6 +185,8 @@ class SupervisedDataset(Dataset):
         else:
             self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
+
         logger.info("Loaded dataset.")
+
     def __len__(self):
         length = self.input_ids.shape[0]
         return length
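For context, a hedged usage sketch of the widened `SupervisedDataset` constructor shown above. The import path, tokenizer choice, and data path are assumptions for illustration, not part of this diff; only the parameter names and defaults come from the hunk.

```python
# Hypothetical usage of the updated constructor; adjust imports and paths to your setup.
from transformers import AutoTokenizer

from coati.dataset import SupervisedDataset  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
tokenizer.padding_side = "right"  # both "right" and "left" are handled by _preprocess above

dataset = SupervisedDataset(
    data_path="data/train.json",  # hypothetical path to the instruction data
    tokenizer=tokenizer,
    max_datasets_size=None,  # now annotated Optional[int]; None keeps every sample
    max_length=512,
)
print(len(dataset))
```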