diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index cea1b2dbb..f701cfdf9 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -187,6 +187,14 @@ class DataCollatorForPreferenceDataset(object):
                 f"but now `{self.tokenizer.pad_token_id}`"
             )
 
+        torch.set_printoptions(profile="full")
+
+        for ins in instances:
+            if sum(ins["chosen_loss_mask"][1:]) == 0:
+                print("Before truncated", ins["chosen_loss_mask"], len(ins["chosen_loss_mask"]))
+            if sum(ins["rejected_loss_mask"][1:]) == 0:
+                print("Before truncated", ins["rejected_loss_mask"], len(ins["rejected_loss_mask"]))
+
         (
             chosen_input_ids,
             chosen_loss_mask,  # [batch_size * seq_len]
@@ -199,6 +207,23 @@ class DataCollatorForPreferenceDataset(object):
             chuncate_sequence([ins["rejected_loss_mask"] for ins in instances], self.max_length, torch.bool),
         )
 
+        for i in range(len(chosen_loss_mask)):
+            if sum(chosen_loss_mask[i][1:]) == 0:
+                print(
+                    "After truncated",
+                    chosen_loss_mask[i],
+                    len(chosen_loss_mask[i]),
+                    len(instances[i]["chosen_input_ids"]),
+                )
+        for i in range(len(reject_loss_mask)):
+            if sum(reject_loss_mask[i][1:]) == 0:
+                print(
+                    "After truncated",
+                    reject_loss_mask[i],
+                    len(reject_loss_mask[i]),
+                    len(instances[i]["rejected_input_ids"]),
+                )
+
         padding_side = self.tokenizer.padding_side
         chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids]
         reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids]
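Note: the collator prints above flag preference samples whose loss mask is zero everywhere past the first position, both before and after truncation; such pairs contribute no gradient, because next-token log-probabilities are only computed for positions 1..n-1. A minimal sketch of the invariant being checked, written as a standalone filter (the helper name `has_trainable_tokens` is illustrative, not part of this patch):

    import torch

    def has_trainable_tokens(loss_mask: torch.Tensor) -> bool:
        # Position 0 never receives a loss: next-token log-probs start at index 1,
        # so a mask that is all zeros past the first token yields an empty loss.
        return bool(loss_mask[1:].sum() > 0)

    # Example: the chosen mask below only supervises position 0, so the pair
    # would be degenerate for DPO and should be dropped or logged.
    chosen = torch.tensor([1, 0, 0, 0], dtype=torch.bool)
    rejected = torch.tensor([0, 1, 1, 0], dtype=torch.bool)
    print(has_trainable_tokens(chosen), has_trainable_tokens(rejected))  # False True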
diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index 34828cbaf..27addcb0d 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -73,9 +73,12 @@ def supervised_tokenize_sft(
     lo, hi = 0, len(turns)
     while lo < hi:
         mid = (lo + hi) // 2
-        if max_length - 1 < len(
-            tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
-        ):
+        prompt = template.get_prompt(2 * turns[mid] - 1)
+        chunks, require_loss = split_templated_prompt_into_chunks(
+            template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
+        )
+        tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
+        if max_length - 1 < len(tokenized):
             hi = mid
         else:
             lo = mid + 1
@@ -114,6 +117,7 @@ def supervised_tokenize_sft(
                 to_truncate_len += 1
             else:
                 break
+        to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
         tokenized = tokenized[: len(tokenized) - to_truncate_len]
         labels = labels[: len(labels) - to_truncate_len]
 
@@ -356,48 +360,24 @@ def tokenize_rlhf(
         rejected_loss_mask,
         rejected_label_decode,
     ) = (None, None, None, None, None, None)
 
-    if (
-        len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
-        <= max_length - 1
-        and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
-        <= max_length - 1
-    ):
-        chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
-        (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
-            chosen_data_packed["input_ids"],
-            chosen_data_packed["loss_mask"],
-            chosen_data_packed["label_decode"],
-        )
-        rejected_data_packed = apply_rlhf_data_format(
-            rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
-        )
-        (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
-            rejected_data_packed["input_ids"],
-            rejected_data_packed["loss_mask"],
-            rejected_data_packed["label_decode"],
-        )
+    chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
+    (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
+        chosen_data_packed["input_ids"],
+        chosen_data_packed["loss_mask"],
+        chosen_data_packed["label_decode"],
+    )
 
-        # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
-        if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
-            return dict(
-                chosen_input_ids=None,
-                chosen_loss_mask=None,
-                chosen_label_decode=None,
-                rejected_input_ids=None,
-                rejected_loss_mask=None,
-                rejected_label_decode=None,
-            )
+    rejected_data_packed = apply_rlhf_data_format(
+        rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
+    )
+    (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
+        rejected_data_packed["input_ids"],
+        rejected_data_packed["loss_mask"],
+        rejected_data_packed["label_decode"],
+    )
 
-        return {
-            "chosen_input_ids": chosen_input_ids,
-            "chosen_loss_mask": chosen_loss_mask,
-            "chosen_label_decode": chosen_label_decode,
-            "rejected_input_ids": rejected_input_ids,
-            "rejected_loss_mask": rejected_loss_mask,
-            "rejected_label_decode": rejected_label_decode,
-        }
-    else:
+    if len(chosen_input_ids) > max_length or len(rejected_input_ids) > max_length:
         return dict(
             chosen_input_ids=None,
             chosen_loss_mask=None,
@@ -406,3 +386,22 @@ def tokenize_rlhf(
             rejected_loss_mask=None,
             rejected_label_decode=None,
         )
+    # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
+    if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
+        return dict(
+            chosen_input_ids=None,
+            chosen_loss_mask=None,
+            chosen_label_decode=None,
+            rejected_input_ids=None,
+            rejected_loss_mask=None,
+            rejected_label_decode=None,
+        )
+
+    return {
+        "chosen_input_ids": chosen_input_ids,
+        "chosen_loss_mask": chosen_loss_mask,
+        "chosen_label_decode": chosen_label_decode,
+        "rejected_input_ids": rejected_input_ids,
+        "rejected_loss_mask": rejected_loss_mask,
+        "rejected_label_decode": rejected_label_decode,
+    }
diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py
index e3df0b148..8ed8d3401 100755
--- a/applications/ColossalChat/coati/models/utils.py
+++ b/applications/ColossalChat/coati/models/utils.py
@@ -109,8 +109,6 @@ def calc_masked_log_probs(
     if not length_normalization:
         return log_probs * mask
     else:
-        if torch.any(mask.sum(dim=-1) == 0):
-            print("Mask should not be all zeros.")
         return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)
 
 
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
index 5eba46be8..af5a04e2a 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh
@@ -57,7 +57,7 @@ colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 31313 train_
     --beta 0.1 \
     --mixed_precision "bf16" \
     --grad_clip 1.0 \
-    --max_length 1024 \
+    --max_length 4096 \
    --weight_decay 0.01 \
    --warmup_steps 60 \
    --grad_checkpoint \
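Note: the binary-search change above measures prompt length with the same split_templated_prompt_into_chunks + tokenize_and_concatenate pipeline that later builds the training sample, instead of re-tokenizing the concatenated prompt string. The two counts can disagree, because a BPE tokenizer may merge characters across a chunk boundary when it sees the joined string in one pass. A toy illustration of that failure mode with a Hugging Face tokenizer (the model name and chunk split are illustrative only):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    chunks = ["Hello wor", "ld"]  # a chunk boundary falling inside a word

    # One-pass tokenization of the joined string...
    joined = tokenizer("".join(chunks), add_special_tokens=False)["input_ids"]
    # ...versus tokenizing each chunk separately and concatenating, which is
    # what the loss-mask-aware pipeline effectively does.
    per_chunk = [t for c in chunks for t in tokenizer(c, add_special_tokens=False)["input_ids"]]

    # The token counts can differ, so truncation decisions must be made with
    # the same scheme that produces the final sample.
    print(len(joined), len(per_chunk))

The added `to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)` then acts as a hard cap: even if the end-of-assistant scan found fewer tokens to drop, the sequence can never leave supervised_tokenize_sft longer than max_length.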
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh
index 04c3b4814..d5ba6261e 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.sh
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh
@@ -15,24 +15,24 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
 # export CUDA_VISIBLE_DEVICES=4,5,6
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
 
 PROJECT_NAME="sft"
-PARENT_SAVE_DIR="/home/yeanbang/data/experiment/rlhf_cont/dpo/ckpt" # Path to a folder to save checkpoints
-PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/rlhf_cont/dpo/log" # Path to a folder to save logs
-PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/rlhf_cont/dpo/log" # Path to a folder to save training config logs
-PRETRAINED_MODEL_PATH="/home/yeanbang/data/models/Sheared-LLaMA-1.3B" # huggingface or local model path
-PRETRAINED_TOKENIZER_PATH="/home/yeanbang/data/models/Sheared-LLaMA-1.3B" # huggingface or local tokenizer path
+PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
+PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs
+PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
+PRETRAINED_MODEL_PATH="" # huggingface or local model path
+PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
 
 declare -a dataset=(
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00000
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00001
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00002
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00003
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00004
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00005
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00006
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00007
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00008
-    /home/yeanbang/data/experiment/rlhf_cont/dpo/dataset_tokenized/sft/arrow/part-00009
+    /Your/Preference/Data/arrow/part-00000
+    /Your/Preference/Data/arrow/part-00001
+    /Your/Preference/Data/arrow/part-00002
+    /Your/Preference/Data/arrow/part-00003
+    /Your/Preference/Data/arrow/part-00004
+    /Your/Preference/Data/arrow/part-00005
+    /Your/Preference/Data/arrow/part-00006
+    /Your/Preference/Data/arrow/part-00007
+    /Your/Preference/Data/arrow/part-00008
+    /Your/Preference/Data/arrow/part-00009
 )
 
 TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -43,7 +43,7 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
 echo $(which colossalai)
 echo $(which python)
 # the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
-colossalai run --nproc_per_node 1 --master_port 31312 --hostfile ./hostfile train_sft.py \
+colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
     --pretrain $PRETRAINED_MODEL_PATH \
     --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
     --save_interval 4000 \
@@ -56,7 +56,7 @@ colossalai run --nproc_per_node 1 --master_port 31312 --hostfile ./hostfile trai
     --max_epochs 1 \
     --accumulation_steps 4 \
     --lr 5e-5 \
-    --max_len 1000 \
+    --max_len 4096 \
     --grad_checkpoint \
     --use_wandb \
     --use_flash_attn
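Note: with the launch command now using --nproc_per_node 4, the script's own comment about the real batch size scales accordingly, and gradient accumulation multiplies the per-optimizer-step batch further. A quick sanity check (node count and per-device batch size are illustrative assumptions; only nproc_per_node and accumulation_steps appear in the script above):

    # effective samples per optimizer step =
    #   number_of_node_in_hostfile * nproc_per_node * train_batch_size * accumulation_steps
    nodes = 1               # assume a single host in ./hostfile
    nproc_per_node = 4      # from the updated launch command
    train_batch_size = 4    # illustrative per-device micro-batch size
    accumulation_steps = 4  # from the script
    print(nodes * nproc_per_node * train_batch_size * accumulation_steps)  # 64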