commit ca6093a582
Author: YeAnbang
Date:   2025-04-23 14:46:33 +08:00

3 changed files with 10 additions and 6 deletions

View File

@@ -352,12 +352,14 @@ def apply_chat_template_and_mask(
     tokenizer: PreTrainedTokenizer,
     chat: List[Dict[str, str]],
     max_length: Optional[int] = None,
+    system_prompt: str = None,
     padding: bool = True,
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
-    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+    if system_prompt is None:
+        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
     system_element = {
         "role": "system",
@@ -419,7 +421,7 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """

-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
@@ -427,6 +429,7 @@ class RawConversationDataset(Dataset):
                 self.raw_texts.append(line)
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
+        self.system_prompt = system_prompt

     def __len__(self) -> int:
         return len(self.raw_texts)
@@ -434,6 +437,6 @@ class RawConversationDataset(Dataset):
     def __getitem__(self, index: int):
         if self.tokenized_texts[index] is None:
             message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
+            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
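
The dataset tokenizes lazily: __getitem__ tokenizes a conversation on first access and caches the result, so the override only needs to be threaded through the constructor. A usage sketch under assumed names (the model and JSONL path are placeholders; each JSONL line holds one chat):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
    dataset = RawConversationDataset(
        tokenizer,
        "data/train.jsonl",  # placeholder path
        max_length=300,
        system_prompt=None,  # None -> apply_chat_template_and_mask uses its built-in prompt
    )
    sample = dataset[0]  # tokenized on first access, cached for later epochs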

View File

@@ -122,7 +122,7 @@ class BaseConsumer:
                 assert len(self.buffer) == 0
                 if self.lr_scheduler is not None:
                     self.lr_scheduler.step()
-                if (step + 1) % self.save_interval == 0:
+                if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
                     save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
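
The widened condition guarantees a checkpoint at the end of every episode, even when num_update_per_episode is not a multiple of save_interval. A self-contained illustration of which steps now trigger a save (the values are made up; the variable names mirror the diff):

    save_interval = 4
    num_update_per_episode = 10

    saved = [
        step + 1
        for step in range(num_update_per_episode)
        if (step + 1) % save_interval == 0 or (step + 1) == num_update_per_episode
    ]
    print(saved)  # [4, 8, 10]; step 10 is the new end-of-episode save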

View File

@@ -55,6 +55,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
+    parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
     args = parser.parse_args()

     assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
@@ -149,13 +150,13 @@ if __name__ == "__main__":
         num_producers=args.num_inferencer,
         num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
         num_consumer_procs=args.num_trainers,
-        num_episodes=10,
+        num_episodes=1,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
         train_microbatch_size=args.train_microbatch_size,
-        dataset_config={"path": args.dataset, "max_length": 300},
+        dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt},
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
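
End to end, the new flag flows from the CLI through dataset_config into RawConversationDataset and finally apply_chat_template_and_mask. Since args.system_prompt defaults to None, a launch that omits -s/--system-prompt (e.g. python rl_example.py --dataset ..., where the script name is an assumption, not shown in the diff) still gets the built-in math-reasoning prompt. A minimal self-contained sketch of the flag's parsing behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")

    args = parser.parse_args(["-s", "You are a concise coding assistant.\n\n"])
    assert args.system_prompt is not None  # argparse maps --system-prompt to args.system_prompt

    args = parser.parse_args([])
    assert args.system_prompt is None  # downstream, None selects the built-in prompt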