[ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)

* Support GSM, Data Leakage Evaluation and Tensor Parallel

* remove redundant code and update inference.py in examples/gpt_evaluation

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Authored by Yuanchen on 2023-12-12 14:47:35 +08:00; committed by GitHub.
parent b07a6f4e27
commit cefdc32615
19 changed files with 578 additions and 100 deletions


@@ -10,6 +10,7 @@ from tqdm import tqdm
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer

 from colossalai.logging import DistributedLogger
+from colossalai.shardformer import ShardConfig, ShardFormer

 from .base import BaseModel
@@ -30,6 +31,7 @@ class HuggingFaceModel(BaseModel):
         prompt_template: The model's prompt template.
         batch_size: Batch size for inference.
         logger: Logger for the model.
+        shard_config: Shard config for tensor parallel.

     """
@@ -44,6 +46,7 @@ class HuggingFaceModel(BaseModel):
         prompt_template: Conversation = None,
         batch_size: int = 1,
         logger: DistributedLogger = None,
+        shard_config: ShardConfig = None,
     ):
         super().__init__(
             path=path,
@@ -54,7 +57,7 @@
         )

         self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)
-        self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path)
+        self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path, shard_config=shard_config)

     def _get_choices_indices(self, language: str):
         """
@@ -100,7 +103,9 @@
             # Qwen has an eod token "<|endoftext|>".
             self.tokenizer.pad_token_id = self.tokenizer.eod_id

-    def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+    def _load_model(
+        self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
+    ):
         """
         Load model.
@@ -108,17 +113,29 @@
             path: The path to the model.
             model_kwargs: Keyword arguments for the model.
             peft_path: The path to the peft model.
+            shard_config: Shard config for tensor parallel.

         """
-        model_kwargs.setdefault("torch_dtype", torch.float16)
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+        model_kwargs.setdefault("torch_dtype", torch.float16)

         if "config" in model_kwargs:
             model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])

-        self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
-        if peft_path is not None:
-            self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+        if shard_config is not None:
+            self.model = AutoModel.from_pretrained(path, **model_kwargs)
+            shard_former = ShardFormer(shard_config)
+            self.model, sharded_parameters = shard_former.optimize(self.model)
+            self.model.to(torch.cuda.current_device())
+
+            if peft_path is not None:
+                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
+        else:
+            self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+            if peft_path is not None:
+                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)

         self.model.eval()

     def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]:
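For orientation, here is a minimal sketch of how a caller might build the new shard_config argument; the process-group choice and the helper name are illustrative assumptions, not part of this commit:

    # Hypothetical caller-side setup (illustrative only; not part of this diff).
    import torch.distributed as dist

    from colossalai.shardformer import ShardConfig

    def build_shard_config() -> ShardConfig:
        # Assumes torch.distributed is already initialized, e.g. via colossalai.launch_from_torch.
        return ShardConfig(
            tensor_parallel_process_group=dist.group.WORLD,  # shard across all ranks
            enable_tensor_parallelism=True,
        )

ShardFormer(shard_config).optimize(model) then swaps the model's layers for tensor-parallel counterparts, which is why the PEFT path above raises NotImplementedError rather than wrapping an already-sharded model.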
@@ -152,7 +169,7 @@
         loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())

-        lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
+        lens = (labels[..., 1:] != IGNORE_INDEX).sum(-1).cpu().numpy()

         loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
         return loss_sum.tolist(), lens.tolist()
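This fix matters because the loss is computed over shifted tensors (shift_labels = labels[..., 1:]): position 0 never receives a loss term, so the per-sequence token count used for averaging must drop it as well. A small self-check, assuming the usual IGNORE_INDEX value of -100:

    import torch

    IGNORE_INDEX = -100
    labels = torch.tensor([[11, 42, 7, 9]])  # no masked positions

    # Old count includes position 0, which never contributes to the shifted loss.
    assert (labels != IGNORE_INDEX).sum(-1).item() == 4
    # New count matches shift_labels = labels[..., 1:].
    assert (labels[..., 1:] != IGNORE_INDEX).sum(-1).item() == 3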
@@ -239,7 +256,13 @@
         """
         if pretrain:
-            return self._get_input_ids_and_labels_pretrain(batch_prompt)
+            batch = []
+            # Concatenate prompts and target answers.
+            # The concatenation character is chosen in the corresponding dataset script under the dataset folder;
+            # for example, in line 119 of dataset/gsm.py the concatenation character is a space.
+            for p, b in zip(batch_prompt, batch_target):
+                batch.append(p + b[0])
+            return self._get_input_ids_and_labels_pretrain(batch)

         input_ids_list = []
         labels_list = []
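In the pretrain branch, scoring the prompt and the gold answer as one sequence is what the perplexity-style data-leakage evaluation relies on. An illustrative example of the concatenation (strings are made up):

    batch_prompt = ["Question: 2 + 2 = ? Answer:"]
    batch_target = [[" 4"]]  # first gold answer per sample; the leading space is the concatenation character
    batch = [p + t[0] for p, t in zip(batch_prompt, batch_target)]
    # batch == ["Question: 2 + 2 = ? Answer: 4"]; per-token loss is then computed over this text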
@@ -380,7 +403,7 @@
             loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()

-            probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist()
+            probs = scores.numpy().tolist()
             probs = [
                 {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
             ]
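Because raw logits are now recorded instead of softmax values, normalized probabilities can still be recovered downstream when needed. A sketch with illustrative choice labels and scores:

    import torch

    logits_over_choices = {"A": 1.2, "B": -0.3, "C": 0.5, "D": -1.1}  # illustrative values
    scores = torch.tensor(list(logits_over_choices.values()))
    probs = torch.softmax(scores, dim=-1).tolist()
    softmax_over_choices = dict(zip(logits_over_choices, probs))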
@@ -393,7 +416,7 @@
                 answers[i + j]["output"] = batch_decodes[j].strip()

                 if isinstance(scores, torch.Tensor):
-                    answers[i + j]["softmax_over_choices"] = probs[j]
+                    answers[i + j]["logits_over_choices"] = probs[j]

                     if calculate_loss:
                         answers[i + j]["loss_over_choices"] = loss_over_choices[j]
@@ -445,7 +468,13 @@
         # Set output_scores=True to get prediction scores.
         outputs = self.model.generate(
-            **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+            **encoded_inputs,
+            max_new_tokens=max_new_tokens,
+            return_dict_in_generate=True,
+            output_scores=True,
+            do_sample=False,
+            use_cache=True,
+            **kwargs,
         )

         # We only need to decode predicted tokens.
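Passing do_sample=False pins generation to greedy decoding, so repeated evaluation runs produce identical outputs; use_cache=True merely enables the KV cache for speed. A minimal determinism check (gpt2 is a stand-in model, not the one under evaluation):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The answer is", return_tensors="pt")
    out1 = model.generate(**inputs, max_new_tokens=8, do_sample=False, use_cache=True)
    out2 = model.generate(**inputs, max_new_tokens=8, do_sample=False, use_cache=True)
    assert torch.equal(out1, out2)  # greedy decoding is deterministic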
@@ -540,10 +569,13 @@ class HuggingFaceCausalLM(HuggingFaceModel):
         prompt_template: The model's prompt template.
         batch_size: Batch size for inference.
         logger: Logger for the model.
+        shard_config: Shard config for tensor parallel.

     """

-    def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+    def _load_model(
+        self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
+    ):
         """
         Load model.
@@ -551,17 +583,29 @@ class HuggingFaceCausalLM(HuggingFaceModel):
             path: The path to the model.
             model_kwargs: Keyword arguments for the model.
             peft_path: The path to the peft model.
+            shard_config: Shard config for tensor parallel.

         """
-        model_kwargs.setdefault("torch_dtype", torch.float16)
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])

         if "config" in model_kwargs:
             model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])

+        model_kwargs.setdefault("torch_dtype", torch.float16)

-        self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
-        if peft_path is not None:
-            self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+        if shard_config is not None:
+            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
+            shard_former = ShardFormer(shard_config)
+            self.model, sharded_parameters = shard_former.optimize(self.model)
+            self.model.to(torch.cuda.current_device())
+
+            if peft_path is not None:
+                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+            if peft_path is not None:
+                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)

         self.model.eval()
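The torch_dtype handling in both _load_model variants exists because model_kwargs may arrive from a config file where the dtype is a string such as "torch.float16"; eval() converts it to the actual dtype object, and setdefault keeps float16 as the fallback. A small illustration:

    import torch

    model_kwargs = {"torch_dtype": "torch.float16"}  # illustrative config-file value
    if "torch_dtype" in model_kwargs:
        model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
    model_kwargs.setdefault("torch_dtype", torch.float16)
    assert model_kwargs["torch_dtype"] is torch.float16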