ColossalAI/colossalai/legacy/inference/dynamic_batching/sampling_params.py
Xu Kai fd6482ad8c
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
2023-11-19 21:05:05 +08:00

84 lines
3.2 KiB
Python

# Adapted from https://github.com/ModelTC/lightllm
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
# Temperatures in [0, _SAMPLING_EPS) are treated as "effectively zero" and
# converted to greedy decoding in SamplingParams.__init__.
_SAMPLING_EPS = 1e-5


class SamplingParams:
    """Per-request text-generation parameters.

    When sampling is disabled (``do_sample`` falsy), or the temperature lies in
    ``[0, _SAMPLING_EPS)``, the constructor rewrites the parameters to greedy
    decoding (``temperature=1.0``, ``top_k=1``).
    """

    def __init__(
        self,
        do_sample: bool = False,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,  # -1 is for all
        ignore_eos: bool = False,
        max_new_tokens: int = 256,
        stop_sequences: Optional[Union[str, List[str]]] = None,  # conditions to stop generation
    ) -> None:
        """Store the parameters, normalising degenerate settings to greedy decoding.

        Args:
            do_sample: If falsy, temperature/top_p are reset to 1.0 and top_k to 1.
            presence_penalty: Must be >= 0.0 (checked by ``verify``).
            frequency_penalty: Must be >= 0.0 (checked by ``verify``).
            temperature: Values in ``[0, _SAMPLING_EPS)`` switch to greedy decoding.
            top_p: Must be in ``(0.0, 1.0]`` (checked by ``verify``).
            top_k: ``-1`` means all tokens; otherwise must be >= 1.
            ignore_eos: Stored as-is; not included in ``to_dict``.
            max_new_tokens: Must be >= 1 (checked by ``verify``).
            stop_sequences: A string, a list of strings, or ``None``. It is
                converted to token ids by ``stop_sentences_to_token_ids``.

        Note: no validation happens here; call ``verify`` explicitly.
        """
        self.do_sample = do_sample
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.ignore_eos = ignore_eos
        self.max_new_tokens = max_new_tokens
        self.stop_sequences = stop_sequences
        if not self.do_sample:
            # Sampling disabled: force greedy decoding.
            self.temperature = 1.0
            self.top_p = 1.0
            self.top_k = 1
        if 0.0 <= self.temperature < _SAMPLING_EPS:
            # Temperature is too low to sample meaningfully; fall back to greedy search.
            self.temperature = 1.0
            self.top_k = 1

    def verify(self) -> None:
        """Validate parameter ranges.

        Raises:
            ValueError: If any parameter is outside its allowed range.
        """
        if self.presence_penalty < 0.0:
            raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}")
        if self.frequency_penalty < 0.0:
            raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}")
        if self.temperature <= 0.0:
            raise ValueError(f"temperature must > 0.0, got {self.temperature}")
        if self.top_p <= 0.0 or self.top_p > 1.0:
            raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
        if self.max_new_tokens < 1:
            raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")

    def stop_sentences_to_token_ids(self, tokenizer) -> None:
        """Replace ``self.stop_sequences`` with a list of token-id lists.

        Each stop string is encoded with ``tokenizer.encode``, and its first id
        is dropped. A stop string is skipped when it encodes to ``None``, to an
        empty sequence, or to nothing after that first id is removed.
        ``None`` stop sequences become ``[]``.

        Args:
            tokenizer: Any object exposing ``encode(str)``. It is presumably a
                Hugging Face tokenizer returning a list of ints; TODO confirm.
        """
        if self.stop_sequences is None:
            self.stop_sequences = []
            return
        stop_strs = [self.stop_sequences] if isinstance(self.stop_sequences, str) else self.stop_sequences
        new_stop_sequences = []
        for stop_str in stop_strs:
            stop_str_ids = tokenizer.encode(stop_str)
            # Previously a None result crashed on len(None); skip it instead.
            if stop_str_ids is None or len(stop_str_ids) == 0:
                continue
            # NOTE(review): the first id is always dropped on the assumption that
            # encode() prepends a BOS token. For tokenizers that do not, this
            # drops a real token. Confirm against the tokenizers in use.
            stop_str_ids = stop_str_ids[1:]
            if len(stop_str_ids) > 0:
                new_stop_sequences.append(stop_str_ids)
        self.stop_sequences = new_stop_sequences

    def to_dict(self) -> dict:
        """Return the sampling-related fields as a plain dict.

        ``ignore_eos``, ``max_new_tokens`` and ``stop_sequences`` are not included.
        """
        return {
            "do_sample": self.do_sample,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }