mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)
* Support GSM, Data Leakage Evaluation and Tensor Parallel
* remove redundant code and update inference.py in examples/gpt_evaluation

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
@@ -10,6 +10,7 @@ from tqdm import tqdm
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.logging import DistributedLogger
+from colossalai.shardformer import ShardConfig, ShardFormer
 
 from .base import BaseModel
 
@@ -30,6 +31,7 @@ class HuggingFaceModel(BaseModel):
         prompt_template: The model's prompt template.
         batch_size: Batch size for inference.
         logger: Logger for the model.
+        shard_config: Shard config for tensor parallel.
 
     """
 
@@ -44,6 +46,7 @@ class HuggingFaceModel(BaseModel):
         prompt_template: Conversation = None,
         batch_size: int = 1,
         logger: DistributedLogger = None,
+        shard_config: ShardConfig = None,
     ):
         super().__init__(
             path=path,
@@ -54,7 +57,7 @@ class HuggingFaceModel(BaseModel):
         )
         self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)
 
-        self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path)
+        self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path, shard_config=shard_config)
 
     def _get_choices_indices(self, language: str):
         """
@@ -100,7 +103,9 @@ class HuggingFaceModel(BaseModel):
             # Qwen has an eod token "<|endoftext|>".
             self.tokenizer.pad_token_id = self.tokenizer.eod_id
 
-    def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+    def _load_model(
+        self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
+    ):
         """
         Load model.
 
@@ -108,17 +113,29 @@ class HuggingFaceModel(BaseModel):
             path: The path to the model.
             model_kwargs: Keyword arguments for the model.
             peft_path: The path to the peft model.
+            shard_config: Shard config for tensor parallel.
 
         """
-        model_kwargs.setdefault("torch_dtype", torch.float16)
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
 
+        model_kwargs.setdefault("torch_dtype", torch.float16)
         if "config" in model_kwargs:
             model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
 
-        self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
-        if peft_path is not None:
-            self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+        if shard_config is not None:
+            self.model = AutoModel.from_pretrained(path, **model_kwargs)
+            shard_former = ShardFormer(shard_config)
+            self.model, sharded_parameters = shard_former.optimize(self.model)
+            self.model.to(torch.cuda.current_device())
+
+            if peft_path is not None:
+                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
+        else:
+            self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+            if peft_path is not None:
+                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
         self.model.eval()
 
     def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]:
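Note on the tensor-parallel branch above: when a shard_config is passed, the checkpoint is first loaded unsharded, ShardFormer rewrites the model according to the config, and only then is the model moved to the current device. Below is a minimal sketch of the same steps outside the wrapper, assuming a distributed process group has already been initialized; the checkpoint name and the process group used here are illustrative, not part of the commit.

import torch
import torch.distributed as dist
from transformers import AutoModel
from colossalai.shardformer import ShardConfig, ShardFormer

# Assumption: torch.distributed is already initialized (e.g. the script was
# launched with torchrun / colossalai.launch_from_torch).
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # illustrative: all ranks form the TP group
    enable_tensor_parallelism=True,
)
model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
shard_former = ShardFormer(shard_config)
model, shared_params = shard_former.optimize(model)  # returns the sharded model and shared parameters
model.to(torch.cuda.current_device())
model.eval()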
@@ -152,7 +169,7 @@ class HuggingFaceModel(BaseModel):
         loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
 
-        lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
+        lens = (labels[..., 1:] != IGNORE_INDEX).sum(-1).cpu().numpy()
 
         loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
         return loss_sum.tolist(), lens.tolist()
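The lens fix above matters for perplexity-style metrics: the loss is computed on shifted logits against labels[..., 1:], so the per-sample token count must also be taken over the shifted labels, otherwise each sequence is normalized by one token too many. A hedged sketch of how the two returned lists would typically be combined downstream (the function name is illustrative):

import numpy as np

def perplexity_from_loss(loss_sum, lens):
    # loss_sum[i] is the summed cross-entropy of sample i over its answer tokens,
    # lens[i] is the number of shifted label tokens that are not IGNORE_INDEX.
    return np.exp(np.asarray(loss_sum) / np.asarray(lens))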
@@ -239,7 +256,13 @@ class HuggingFaceModel(BaseModel):
 
         """
         if pretrain:
-            return self._get_input_ids_and_labels_pretrain(batch_prompt)
+            batch = []
+            # Concatenate prompt and target answers.
+            # You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
+            for p, b in zip(batch_prompt, batch_target):
+                batch.append(p + b[0])
+
+            return self._get_input_ids_and_labels_pretrain(batch)
 
         input_ids_list = []
         labels_list = []
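For pretrain-style (perplexity) datasets such as GSM, the prompt and the first reference answer are concatenated before tokenization, as the comment above notes. A small illustrative example of that concatenation (the values are made up):

batch_prompt = ["Question: 2 + 3 = ?\nAnswer:"]
batch_target = [[" 2 + 3 = 5. The answer is 5."]]  # first reference answer per sample
batch = [p + b[0] for p, b in zip(batch_prompt, batch_target)]
# -> ["Question: 2 + 3 = ?\nAnswer: 2 + 3 = 5. The answer is 5."]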
@@ -380,7 +403,7 @@ class HuggingFaceModel(BaseModel):
 
             loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
 
-        probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist()
+        probs = scores.numpy().tolist()
         probs = [
             {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
         ]
@@ -393,7 +416,7 @@ class HuggingFaceModel(BaseModel):
                 answers[i + j]["output"] = batch_decodes[j].strip()
 
                 if isinstance(scores, torch.Tensor):
-                    answers[i + j]["softmax_over_choices"] = probs[j]
+                    answers[i + j]["logits_over_choices"] = probs[j]
 
                     if calculate_loss:
                         answers[i + j]["loss_over_choices"] = loss_over_choices[j]
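With the two changes above, the per-choice values stored in each answer record are now raw logits rather than a softmax distribution. If a probability distribution is still wanted downstream, it can be recovered from the stored field; a hedged sketch with made-up values:

import torch

logits_over_choices = {"A": 1.2, "B": -0.3, "C": 0.4, "D": 0.1}  # example record field
probs = torch.softmax(torch.tensor(list(logits_over_choices.values())), dim=-1)
softmax_over_choices = dict(zip(logits_over_choices, probs.tolist()))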
@@ -445,7 +468,13 @@ class HuggingFaceModel(BaseModel):
 
         # Set output_scores=True to get prediction scores.
         outputs = self.model.generate(
-            **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+            **encoded_inputs,
+            max_new_tokens=max_new_tokens,
+            return_dict_in_generate=True,
+            output_scores=True,
+            do_sample=False,
+            use_cache=True,
+            **kwargs,
         )
 
         # We only need to decode predicted tokens.
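The expanded generate() call above pins generation to greedy decoding (do_sample=False) and enables the key/value cache, making evaluation outputs reproducible and decoding faster. A standalone sketch of the same settings with an illustrative checkpoint and prompt:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

encoded_inputs = tokenizer("The answer is", return_tensors="pt")
outputs = model.generate(
    **encoded_inputs,
    max_new_tokens=8,
    return_dict_in_generate=True,
    output_scores=True,
    do_sample=False,  # greedy decoding for reproducible evaluation
    use_cache=True,   # reuse the key/value cache while decoding
)
# Decode only the newly generated tokens, as the surrounding code does.
new_tokens = outputs.sequences[:, encoded_inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))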
@@ -540,10 +569,13 @@ class HuggingFaceCausalLM(HuggingFaceModel):
         prompt_template: The model's prompt template.
         batch_size: Batch size for inference.
         logger: Logger for the model.
+        shard_config: Shard config for tensor parallel.
 
     """
 
-    def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+    def _load_model(
+        self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
+    ):
         """
         Load model.
 
@@ -551,17 +583,29 @@ class HuggingFaceCausalLM(HuggingFaceModel):
             path: The path to the model.
             model_kwargs: Keyword arguments for the model.
            peft_path: The path to the peft model.
+            shard_config: Shard config for tensor parallel.
 
         """
 
-        model_kwargs.setdefault("torch_dtype", torch.float16)
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
 
         if "config" in model_kwargs:
             model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
 
-        self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
-        if peft_path is not None:
-            self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+        model_kwargs.setdefault("torch_dtype", torch.float16)
+        if shard_config is not None:
+            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
+            shard_former = ShardFormer(shard_config)
+            self.model, sharded_parameters = shard_former.optimize(self.model)
+            self.model.to(torch.cuda.current_device())
+
+            if peft_path is not None:
+                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+            if peft_path is not None:
+                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
 
         self.model.eval()
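Finally, a hedged usage sketch of the new shard_config parameter end to end: build a ShardConfig over the current tensor-parallel process group (as in the sketch after the first _load_model hunk) and pass it to the wrapper so that _load_model() takes the ShardFormer branch. The model path, model_kwargs values, and process-group setup below are illustrative, not taken from the commit.

import torch.distributed as dist
from colossalai.shardformer import ShardConfig

# Assumption: the script is launched with torchrun / colossalai so a process group exists.
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
)
model = HuggingFaceCausalLM(
    path="meta-llama/Llama-2-7b-hf",                 # illustrative checkpoint
    model_kwargs={"torch_dtype": "torch.float16"},   # exercises the new eval() path above
    shard_config=shard_config,
)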