Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-20 09:01:06 +00:00)
[chat]: update rm, add wandb and fix bugs (#4471)
* feat: modify forward fn of critic and reward model
* feat: modify calc_action_log_probs
* to: add wandb in sft and rm trainer
* feat: update train_sft
* feat: update train_rm
* style: modify type annotation and add warning
* feat: pass tokenizer to ppo trainer
* to: modify trainer base and maker base
* feat: add wandb in ppo trainer
* feat: pass tokenizer to generate
* test: update generate fn tests
* test: update train tests
* fix: remove action_mask
* feat: remove unused code
* fix: fix wrong ignore_index
* fix: fix mock tokenizer
* chore: update requirements
* revert: modify make_experience
* fix: fix inference
* fix: add padding side
* style: modify _on_learn_batch_end
* test: use mock tokenizer
* fix: use bf16 to avoid overflow
* fix: fix workflow
* [chat] fix gemini strategy
* [chat] fix
* sync: update colossalai strategy
* fix: fix args and model dtype
* fix: fix checkpoint test
* fix: fix requirements
* fix: fix missing import and wrong arg
* fix: temporarily skip gemini test in stage 3
* style: apply pre-commit
* fix: temporarily skip gemini test in stage 1&2

---------

Co-authored-by: Mingyan Jiang <1829166702@qq.com>
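Several items above add wandb logging to the SFT, RM, and PPO trainers and touch the _on_learn_batch_end hook. The sketch below is only a generic illustration of that logging pattern, written against the public wandb API; the ToyTrainer class, project name, and metric names are made up and are not coati's actual trainer code.

# Generic illustration of per-batch wandb logging from a trainer hook.
# ToyTrainer, the project name, and the metrics dict are hypothetical.
import wandb


class ToyTrainer:
    def __init__(self, use_wandb: bool = False) -> None:
        self.writer = None
        if use_wandb:
            # mode="offline" keeps the sketch runnable without a wandb account
            self.writer = wandb.init(project="toy-rlhf", mode="offline")

    def _on_learn_batch_end(self, metrics: dict, step: int) -> None:
        # Log scalar training metrics (e.g. reward, losses) once per batch.
        if self.writer is not None:
            wandb.log(metrics, step=step)


trainer = ToyTrainer(use_wandb=True)
trainer._on_learn_batch_end({"reward": 0.42, "actor_loss": 1.3}, step=0)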
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import copy
-from typing import Dict, Sequence, Tuple
+from typing import Dict, Optional, Sequence, Tuple
 
 import torch
 from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@@ -57,6 +57,7 @@ def _preprocess(
         sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
     )
 
+    assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
     labels = copy.deepcopy(sequences_token["input_ids"])
     for i in range(labels.shape[0]):
         source_len = sources_token["attention_mask"][i].sum().item()
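The assert added above guards the assumption that the tokenizer returns a plain (batch_size, max_length) attention mask; ChatGLM-style seq2seq inputs are routed to _preprocess_chatglm instead. Below is a quick standalone check of that shape assumption, separate from the diff, using a stock Hugging Face tokenizer ("gpt2" is only an example choice):

# Standalone shape check; "gpt2" is just an example tokenizer, not the one used in coati.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

batch = tokenizer(
    ["first prompt", "a second, slightly longer prompt"],
    max_length=16,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
assert batch["attention_mask"].dim() == 2   # shape: (batch_size, max_length)
print(batch["attention_mask"].sum(dim=1))   # per-sample token counts, as used for source_len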
@@ -64,9 +65,10 @@ def _preprocess(
         if tokenizer.padding_side == "right":
             # |prompt|completion|eos|pad|
             labels[i][:source_len] = IGNORE_INDEX
+            labels[i][-pad_len:] = IGNORE_INDEX
         elif tokenizer.padding_side == "left":
             # |pad|prompt|completion|eos|
-            labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
+            labels[i][: pad_len + source_len] = IGNORE_INDEX
         else:
             raise RuntimeError()
 
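The ignore_index fix above masks both the prompt and the padding positions, so the SFT loss is computed only on completion (and eos) tokens. Below is a minimal, self-contained sketch of the same masking rule, separate from the diff; the IGNORE_INDEX value of -100 matches the default ignore_index of PyTorch's cross-entropy loss, and the toy lengths are made up.

# Toy illustration of the masking rule: prompt and pad positions get IGNORE_INDEX,
# completion/eos positions keep their labels. Lengths here are arbitrary examples.
import torch

IGNORE_INDEX = -100
max_length = 8
source_len = 3                              # prompt tokens
pad_len = 2                                 # padding tokens
labels = torch.arange(max_length)           # stand-in for a row of input_ids

right = labels.clone()                      # |prompt|completion|eos|pad|
right[:source_len] = IGNORE_INDEX
right[-pad_len:] = IGNORE_INDEX             # note: only valid while pad_len > 0

left = labels.clone()                       # |pad|prompt|completion|eos|
left[: pad_len + source_len] = IGNORE_INDEX

print(right)  # tensor([-100, -100, -100,    3,    4,    5, -100, -100])
print(left)   # tensor([-100, -100, -100, -100, -100,    5,    6,    7])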
@@ -126,6 +128,8 @@ class SFTDataset(Dataset):
 
         sources = [data["prompt"] for data in dataset]
         targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
+
+        logger.info("Tokenizing inputs... This may take some time...")
         if isinstance(tokenizer, ChatGLMTokenizer):
             self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                 sources, targets, tokenizer, max_length
@@ -133,6 +137,8 @@ class SFTDataset(Dataset):
         else:
             self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
 
+        logger.info("Loaded dataset.")
+
     def __len__(self):
         length = self.input_ids.shape[0]
         return length
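As the two hunks above show, SFTDataset tokenizes "prompt"/"completion" records (appending the tokenizer's eos token to each completion) and keeps the results as input_ids, labels, and attention_mask tensors. A hedged usage sketch follows; the constructor signature SFTDataset(dataset, tokenizer, max_length) and the import path are inferred rather than shown in this diff, and the records and tokenizer choice are illustrative.

# Assumed usage; the import path and constructor signature are inferred, not shown above.
from coati.dataset import SFTDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer only
tokenizer.pad_token = tokenizer.eos_token

records = [
    {"prompt": "Q: What is 2 + 2?\nA: ", "completion": "4"},
    {"prompt": "Translate 'bonjour' to English: ", "completion": "hello"},
]
dataset = SFTDataset(records, tokenizer, max_length=64)
print(len(dataset), dataset.input_ids.shape)  # 2 samples, padded to (2, 64)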
@@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
 
     def __init__(
-        self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
+        self,
+        data_path: str,
+        tokenizer: PreTrainedTokenizer,
+        max_datasets_size: Optional[int] = None,
+        max_length: int = 512,
     ):
         super().__init__()
         logger.info("Loading data...")
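The __init__ change above is mostly stylistic: the arguments move to one per line and the None default now carries a matching Optional[int] annotation instead of the bare int = None. Below is a hedged usage sketch against this new signature, separate from the diff; the import path, data file, and the exact semantics of max_datasets_size are assumptions, not shown here.

# Assumed usage of the reformatted signature; the path and import are placeholders.
from coati.dataset import SupervisedDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

dataset = SupervisedDataset(
    data_path="data/sft_example.json",   # placeholder path to an instruction dataset
    tokenizer=tokenizer,
    max_datasets_size=1000,              # Optional[int]; None presumably means "use all samples"
    max_length=512,
)
print(len(dataset))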
@@ -175,6 +185,8 @@ class SupervisedDataset(Dataset):
         else:
             self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
 
+        logger.info("Loaded dataset.")
+
     def __len__(self):
         length = self.input_ids.shape[0]
         return length