This commit is contained in:
csunny 2023-04-30 14:52:03 +08:00
parent e493f54804
commit 909045c0d6

View File

@ -9,11 +9,8 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
temperature = float(params.get("temperature", 1.0)) temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 256)) max_new_tokens = int(params.get("max_new_tokens", 256))
stop_parameter = params.get("stop", None) stop_parameter = params.get("stop", None)
print(tokenizer.__dir__())
if stop_parameter == tokenizer.eos_token: if stop_parameter == tokenizer.eos_token:
stop_parameter = None stop_parameter = None
stop_strings = [] stop_strings = []
if isinstance(stop_parameter, str): if isinstance(stop_parameter, str):
stop_strings.append(stop_parameter) stop_strings.append(stop_parameter)
@ -43,13 +40,14 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
past_key_values=past_key_values, past_key_values=past_key_values,
) )
logits = out.logits logits = out.logits
past_key_values = out.past_key_value past_key_values = out.past_key_values
last_token_logits = logits[0][-1] last_token_logits = logits[0][-1]
if temperature < 1e-4: if temperature < 1e-4:
token = int(torch.argmax(last_token_logits)) token = int(torch.argmax(last_token_logits))
else: else:
probs = torch.softmax(last_token_logits / temperature, dim=1) probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1)) token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token) output_ids.append(token)
@ -64,12 +62,12 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
pos = output.rfind(stop_str) pos = output.rfind(stop_str)
if pos != -1: if pos != -1:
output = output[:pos] output = output[:pos]
stoppped = True stopped = True
break break
else: else:
pass pass
if stoppped: if stopped:
break break
del past_key_values del past_key_values