mirror of https://github.com/hpcaitech/ColossalAI.git

commit 4d18e7d772 (parent 6fff36dd63)

    spot a possible bug
@@ -80,8 +80,9 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         if self.num_generations > 1:
             input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
             attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        generate_config = kwargs.get("generate_config", self.generate_config)
         out = self.model.generate(
-            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+            input_ids, attention_mask=attention_mask, **kwargs, **generate_config, tokenizer=self.tokenizer
         )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]
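Note on this hunk: `kwargs.get("generate_config", self.generate_config)` reads the override but does not remove the key from `kwargs`, so when a caller supplies one, `model.generate` receives it twice: once as a literal `generate_config=` keyword through `**kwargs`, and once unpacked through `**generate_config`. A minimal sketch of the safer variant, reusing the names from the hunk above:

    # Sketch only: pop() removes the override from kwargs so the key is not
    # forwarded to model.generate a second time through **kwargs.
    generate_config = kwargs.pop("generate_config", self.generate_config)
    out = self.model.generate(
        input_ids, attention_mask=attention_mask, **kwargs, **generate_config, tokenizer=self.tokenizer
    )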
@@ -93,12 +93,14 @@ class BaseProducer:
         )

         self.eval_dataset_config = eval_dataset_config
-        if self.eval_dataset_config is not None and self.eval_interval > 0:
+        if self.eval_dataset_config is not None:
             self.eval_dataloaders = {}
             for eval_task_name in self.eval_dataset_config:
-                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
                 eval_dataset = RawConversationDataset(
-                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+                    self.tokenizer,
+                    eval_dataset_config[eval_task_name]["path"],
+                    eval_dataset_config[eval_task_name]["max_length"],
+                    eval_dataset_config[eval_task_name]["system_prompt"],
                 )
                 print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
                 self.eval_dataloaders[eval_task_name] = DataLoader(
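The dataset is now built from explicit keys instead of `pop("path")` plus `**config`: the per-task config dict gains extra entries (`max_new_tokens`, `temperature`, `top_p`, `top_k`; see the `__main__` hunk below) that `RawConversationDataset` does not accept, and `pop` also mutated the shared config in place. A sketch of the same selection; the keyword names are assumptions, the diff only shows positional arguments:

    # Sketch, assuming RawConversationDataset accepts these as keywords:
    cfg = eval_dataset_config[eval_task_name]
    eval_dataset = RawConversationDataset(
        self.tokenizer,
        path=cfg["path"],
        max_length=cfg["max_length"],
        system_prompt=cfg["system_prompt"],
    )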
@@ -171,7 +173,14 @@ class BaseProducer:
                 for eval_batch_id, eval_batch in tqdm.tqdm(
                     enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
                 ):
-                    eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        eval_outputs = self.rollout(
+                            **eval_batch, sample_params=self.eval_sample_params[eval_task_name]
+                        )
+                    elif isinstance(self.model, BACKEND_MAP["transformers"]):
+                        eval_outputs = self.rollout(
+                            **eval_batch, generate_config=self.eval_generation_config[eval_task_name]
+                        )
                     eval_results = eval_results + [
                         self.evaluation_function(
                             eval_outputs["input_ids"][m][n],
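Evaluation rollouts now dispatch on the backend class from `BACKEND_MAP`, passing the per-task `SamplingParams` for vLLM or the per-task generation config for transformers. If the model matches neither entry, no branch fires: the first batch then raises `UnboundLocalError` and later batches would silently reuse a stale `eval_outputs`. A hedged sketch that fails loudly instead:

    # Sketch: explicit else-branch so an unsupported backend fails loudly
    # rather than reusing a stale eval_outputs or raising UnboundLocalError.
    if isinstance(self.model, BACKEND_MAP["vllm"]):
        eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params[eval_task_name])
    elif isinstance(self.model, BACKEND_MAP["transformers"]):
        eval_outputs = self.rollout(**eval_batch, generate_config=self.eval_generation_config[eval_task_name])
    else:
        raise NotImplementedError(f"Evaluation not implemented for backend {type(self.model).__name__}")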
@@ -179,6 +188,7 @@ class BaseProducer:
                             eval_outputs["response_idx"][m][n],
                             tokenizer=self.tokenizer,
                             eval_mode=True,
+                            max_new_tokens=self.eval_dataset_config[eval_task_name]["max_new_tokens"],
                         )
                         for m in range(eval_outputs["input_ids"].size(0))
                         for n in range(eval_outputs["input_ids"].size(1))
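`max_new_tokens` is read from the per-task config with direct indexing, so the key must exist for every task; the `__main__` hunk below always sets it. A defensive sketch, with the fallback value purely illustrative:

    # Sketch: read with an illustrative fallback to avoid a KeyError for eval
    # configs created before "max_new_tokens" was part of the schema.
    max_new_tokens = self.eval_dataset_config[eval_task_name].get("max_new_tokens", 512)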
@@ -302,11 +312,29 @@ class SimpleProducer(BaseProducer):
             eval_save_dir=eval_save_dir,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
-        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
-        self.eval_generation_config.update(
-            {"n": 1, "temperature": 0.6, "top_p": 0.95}
-        )  # use 1 generation for evaluation
-        self.eval_sample_params = SamplingParams(**self.eval_generation_config)
+        self.eval_sample_params = {}
+        for eval_task_name in eval_dataset_config:
+            eval_generation_config = copy.deepcopy(self.model.generate_config)
+            if isinstance(self.model, BACKEND_MAP["vllm"]):
+                eval_generation_config.update(
+                    {
+                        "n": 1,
+                        "temperature": eval_dataset_config[eval_task_name]["temperature"],
+                        "top_p": eval_dataset_config[eval_task_name]["top_p"],
+                        "top_k": eval_dataset_config[eval_task_name]["top_k"],
+                    }
+                )  # use 1 generation for evaluation
+                self.eval_sample_params[eval_task_name] = SamplingParams(**eval_generation_config)
+            elif isinstance(self.model, BACKEND_MAP["transformers"]):
+                eval_generation_config.update(
+                    {
+                        "num_return_sequences": 1,
+                        "temperature": eval_dataset_config[eval_task_name]["temperature"],
+                        "top_p": eval_dataset_config[eval_task_name]["top_p"],
+                        "top_k": eval_dataset_config[eval_task_name]["top_k"],
+                    }
+                )  # use 1 generation for evaluation
+                self.eval_generation_config[eval_task_name] = copy.deepcopy(eval_generation_config)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
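One thing this hunk leaves open: the removed line was the only place `self.eval_generation_config` was created, yet the transformers branch still assigns into it per task. Unless it is initialized elsewhere, `self.eval_generation_config[eval_task_name] = ...` raises AttributeError on first use, which may be the "possible bug" the commit title points at. A compact sketch with both containers initialized up front (names taken from the diff):

    # Sketch: initialize both per-task containers before the loop.
    self.eval_sample_params = {}
    self.eval_generation_config = {}
    for task, cfg in eval_dataset_config.items():
        gen_cfg = copy.deepcopy(self.model.generate_config)
        common = {"temperature": cfg["temperature"], "top_p": cfg["top_p"], "top_k": cfg["top_k"]}
        if isinstance(self.model, BACKEND_MAP["vllm"]):
            gen_cfg.update({"n": 1, **common})  # one generation per prompt for eval
            self.eval_sample_params[task] = SamplingParams(**gen_cfg)
        elif isinstance(self.model, BACKEND_MAP["transformers"]):
            gen_cfg.update({"num_return_sequences": 1, **common})
            self.eval_generation_config[task] = gen_cfg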
@@ -246,7 +246,7 @@ if __name__ == "__main__":
        #     "zero_stage": 2,
        # },  # for zero
        plugin_config={
-            "tp_size": 4,
+            "tp_size": 2,
            "pp_size": 2,
            "microbatch_size": max(
                1, args.train_microbatch_size // 2
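With `tp_size` dropped from 4 to 2 and `pp_size` kept at 2, each trainer model replica now spans 2 × 2 = 4 GPUs instead of 4 × 2 = 8, assuming no other parallelism in the plugin config.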
@@ -261,7 +261,15 @@ if __name__ == "__main__":
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
         eval_dataset_config={
-            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+            k: {
+                "path": v["path"],
+                "max_length": args.max_prompt_tokens + args.max_new_tokens,
+                "max_new_tokens": args.max_new_tokens,
+                "system_prompt": args.system_prompt,
+                "temperature": v.get("temperature", 0.6),
+                "top_p": v.get("top_p", 0.95),
+                "top_k": v.get("top_k", 50),
+            }
             for k, v in json.loads(args.eval_dataset).items()
         },
         eval_interval=args.eval_interval,
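This changes the expected shape of `--eval-dataset`: each task now maps to an object with at least a `path`, while `temperature`, `top_p`, and `top_k` fall back to 0.6, 0.95, and 50 via `v.get(...)`. Note also that `max_length` now budgets prompt plus completion (`args.max_prompt_tokens + args.max_new_tokens`). A hypothetical value under the new schema (task name and path invented for illustration):

    # Hypothetical --eval-dataset value; only "path" is required per task,
    # the sampling keys fall back to the defaults shown in the hunk above.
    eval_dataset_json = '{"my_eval_task": {"path": "data/eval.jsonl", "temperature": 0.6, "top_p": 0.95, "top_k": 50}}'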