Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-03 17:19:51 +00:00)

fix: fix sft (#3568)

This commit is contained in:
parent 6e7e43c6fe
commit 7788e0b0a5
@@ -53,29 +53,25 @@ class SFTDataset(Dataset):

     def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
         super().__init__()
-        # self.prompts = []
         self.input_ids = []

         for data in tqdm(dataset, disable=not is_rank_0()):
-            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
+            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
             prompt_token = tokenizer(prompt,
                                      max_length=max_length,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")

-            # self.prompts.append(prompt_token)s
-            self.input_ids.append(prompt_token)
+            self.input_ids.append(prompt_token['input_ids'][0])
         self.labels = copy.deepcopy(self.input_ids)

     def __len__(self):
-        length = len(self.prompts)
+        length = len(self.input_ids)
         return length

     def __getitem__(self, idx):
-        # dict(input_ids=self.input_ids[i], labels=self.labels[i])
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-        # return dict(self.prompts[idx], self.prompts[idx])


 def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:

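The substantive change in this hunk is twofold: the end-of-sequence marker is taken from tokenizer.eos_token instead of the hard-coded GPT string "<|endoftext|>", and the dataset stores the bare id tensor prompt_token['input_ids'][0] rather than the whole tokenizer output. A minimal sketch of the fixed tokenization step, assuming a Hugging Face tokenizer (the "gpt2" checkpoint and the sample text below are only illustrative, not the script's defaults):

# Sketch only: "gpt2" is an illustrative checkpoint, not the repo's default.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 ships without a pad token

sample = {"prompt": "Q: what is 2+2?\nA:", "completion": " 4"}
text = sample["prompt"] + sample["completion"] + tokenizer.eos_token

prompt_token = tokenizer(text,
                         max_length=512,
                         padding="max_length",
                         truncation=True,
                         return_tensors="pt")

input_ids = prompt_token["input_ids"][0]    # 1-D tensor of shape (512,)
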
@@ -96,7 +96,7 @@ class SFTTrainer(ABC):
                 loss = outputs.loss
                 prompt_logits = outputs.logits

-                if loss >= 2.5:
+                if loss >= 2.5 and is_rank_0():
                     logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

                 loss = loss / self.accimulation_steps
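Gating the abnormal-loss warning on is_rank_0() keeps it from being emitted once per data-parallel worker. A common shape for such a helper, sketched with torch.distributed; the repo's own is_rank_0() may differ in detail:

import torch.distributed as dist

def is_rank_0() -> bool:
    # Single-process runs (no init_process_group) count as rank 0,
    # so warnings and wandb logging still happen there.
    return not dist.is_initialized() or dist.get_rank() == 0
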
@@ -110,12 +110,13 @@ class SFTTrainer(ABC):
                     self.strategy.optimizer_step(self.optimizer)
                     self.optimizer.zero_grad()
                     self.scheduler.step()
-                    wandb.log({
-                        "loss": total_loss / self.accimulation_steps,
-                        "lr": self.scheduler.get_last_lr()[0],
-                        "epoch": epoch,
-                        "batch_id": batch_id
-                    })
+                    if is_rank_0():
+                        wandb.log({
+                            "loss": total_loss / self.accimulation_steps,
+                            "lr": self.scheduler.get_last_lr()[0],
+                            "epoch": epoch,
+                            "batch_id": batch_id
+                        })
                     total_loss = 0
                     step_bar.update()

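The same guard now wraps the wandb call, so metrics are written once per optimizer step rather than once per process. A simplified sketch of the accumulate-then-log loop this hunk sits in; model, dataloader, accumulation_steps and train_epoch are placeholders, and is_rank_0 is as sketched above:

import wandb

def train_epoch(model, optimizer, scheduler, dataloader, accumulation_steps, epoch):
    total_loss = 0.0
    for batch_id, batch in enumerate(dataloader):
        loss = model(**batch).loss / accumulation_steps
        loss.backward()
        total_loss += loss.item()
        if (batch_id + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            if is_rank_0():                     # only rank 0 talks to wandb
                wandb.log({
                    "loss": total_loss,         # already averaged over the window
                    "lr": scheduler.get_last_lr()[0],
                    "epoch": epoch,
                    "batch_id": batch_id,
                })
            total_loss = 0.0
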
@@ -111,7 +111,7 @@ def train(args):
                                 max_datasets_size=args.max_datasets_size,
                                 max_length=max_len)
     eval_dataset = None
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

     if dist.is_initialized() and dist.get_world_size() > 1:
         train_sampler = DistributedSampler(train_dataset,

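This last hunk touches the sampler selection in the training script: a DistributedSampler is only built when more than one process is running, presumably falling back to an ordinary DataLoader otherwise. A sketch of that pattern; the helper name, batch size and collate function are placeholders:

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def build_train_dataloader(train_dataset, data_collator, batch_size=4):
    if dist.is_initialized() and dist.get_world_size() > 1:
        # Each rank draws a disjoint shard of the dataset.
        sampler = DistributedSampler(train_dataset, shuffle=True)
        return DataLoader(train_dataset, batch_size=batch_size,
                          sampler=sampler, collate_fn=data_collator)
    # Single-process fallback: ordinary shuffling, no sampler.
    return DataLoader(train_dataset, batch_size=batch_size,
                      shuffle=True, collate_fn=data_collator)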