mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[feature] ColossalEval: Evaluation Pipeline for LLMs (#4786)
* Add ColossalEval * Delete evaluate in Chat --------- Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com>
This commit is contained in:
78
applications/ColossalEval/colossal_eval/models/base.py
Normal file
78
applications/ColossalEval/colossal_eval/models/base.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from abc import abstractclassmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from colossal_eval.utils import Conversation, prompt_templates
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
class BaseModel:
|
||||
"""
|
||||
Base class for model wrapper.
|
||||
|
||||
Args:
|
||||
path: The path to the model.
|
||||
model_max_length: The maximum sequence length of the model.
|
||||
prompt_template: The model's prompt template.
|
||||
batch_size: Batch size for inference.
|
||||
logger: Logger for the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
model_max_length: int = 2048,
|
||||
prompt_template: Conversation = None,
|
||||
batch_size: int = 1,
|
||||
logger: DistributedLogger = None,
|
||||
):
|
||||
self.path = path
|
||||
self.model_max_length = model_max_length
|
||||
|
||||
if prompt_template:
|
||||
self.prompt_template = prompt_template
|
||||
else:
|
||||
self.prompt_template = prompt_templates["plain"]
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.logger = logger
|
||||
|
||||
@abstractclassmethod
|
||||
def inference(self, data: List[Dict]) -> None:
|
||||
"""
|
||||
Infer the given data.
|
||||
This function will call self.generate() to get model outputs and also self.model(input) to get logits.
|
||||
|
||||
Args:
|
||||
data: The data for inference.
|
||||
"""
|
||||
|
||||
@abstractclassmethod
|
||||
def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]:
|
||||
"""
|
||||
Generate results given a list of inputs.
|
||||
|
||||
Args:
|
||||
inputs: A list of strings.
|
||||
max_new_tokens: The maximum length of the output.
|
||||
|
||||
Returns:
|
||||
A list of generated strings.
|
||||
"""
|
||||
|
||||
@abstractclassmethod
|
||||
def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]:
|
||||
"""
|
||||
Get loss given batch and batch with target.
|
||||
Use their length difference after tokenization to mask the loss and only compute loss at target tokens.
|
||||
|
||||
Args:
|
||||
batch: batch prompt without target answer.
|
||||
batch_target: batch prompt with target answer.
|
||||
|
||||
Returns:
|
||||
A list of loss.
|
||||
"""
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
Reference in New Issue
Block a user