Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
fix logprob, add filtering, temperature annealing, lr descent
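Only the num_generations plumbing appears in the hunks below; the filtering, temperature annealing, and lr descent (presumably learning-rate decay) named in the title live in other files of this commit. For orientation, a rough sketch of what a temperature-annealing schedule typically looks like; the function name, defaults, and linear shape are hypothetical, not taken from this commit:

def annealed_temperature(step: int, total_steps: int, start: float = 1.0, end: float = 0.7) -> float:
    # Hypothetical linear schedule: decay the sampling temperature from
    # `start` to `end` as training progresses.
    frac = min(step / max(total_steps, 1), 1.0)
    return start + (end - start) * frac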
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
 
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
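With this change the caller picks the number of completions sampled per prompt at construction time instead of inheriting the hard-coded 8. A minimal usage sketch; the model path, generation settings, and num_generations value are illustrative, not taken from the repository:

from transformers import AutoTokenizer

# All concrete values below are placeholders for illustration.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
backend = TransformersInferenceBackend(
    model_config={"path": "Qwen/Qwen2.5-7B-Instruct"},  # "path" is popped off before the rest is forwarded
    generate_config={"max_new_tokens": 512, "do_sample": True, "temperature": 1.0},
    tokenizer=tokenizer,
    num_generations=4,  # previously fixed at 8
)

The SGLang and vLLM constructors in the hunks below gain the same num_generations keyword with the same default, so the three backends stay interchangeable.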
@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
 
 
 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")
@@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )
 
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
@@ -186,9 +203,10 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
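The vLLM hunk is the one place where the new parameter changes generation itself rather than bookkeeping: n moves out of FORCE_GENERATE_CONFIG and into the per-instance generate_config, so it reaches SamplingParams, and n is what makes vLLM return several completions per prompt. A standalone sketch of that mechanism; the model and prompt are illustrative:

from vllm import LLM, SamplingParams

# n = completions per prompt; logprobs=0 returns the log-probability of each
# sampled token, mirroring the logprobs setting kept in FORCE_GENERATE_CONFIG.
params = SamplingParams(n=4, temperature=0.8, max_tokens=64, logprobs=0)
llm = LLM(model="facebook/opt-125m")  # illustrative small model

outputs = llm.generate(["Write a haiku about GPUs."], params)
for completion in outputs[0].outputs:  # one CompletionOutput per sample, len == n
    print(completion.text)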