Fix multi-device error in get_embeddings

This commit is contained in:
csunny 2023-04-30 14:47:33 +08:00
parent bb6c1865e1
commit e493f54804

View File

@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
    """Return the mean input-embedding vector for ``prompt``.

    Tokenizes the prompt, looks up token embeddings on the available
    device (CUDA if present, otherwise CPU), and averages them over the
    sequence dimension.

    Args:
        model: a model exposing ``get_input_embeddings()`` (an embedding
            module callable on a LongTensor of token ids).
        tokenizer: callable returning an object with an ``input_ids``
            list for the given prompt.
        prompt: the text to embed.

    Returns:
        A 1-D tensor of shape ``(hidden_dim,)`` on ``device`` — the
        token-averaged input embedding, detached from any graph.
    """
    input_ids = tokenizer(prompt).input_ids
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep the embedding layer and the id tensor on the same device to
    # avoid the original multi-device mismatch.
    input_embeddings = model.get_input_embeddings().to(device)
    # embeddings: (1, seq_len, hidden) — batch of one prompt.
    embeddings = input_embeddings(torch.LongTensor([input_ids]).to(device))
    # Bug fix: the previous code did ``.cpu().detach()`` and then
    # ``return mean.to(device)`` — a pointless CPU round-trip on CUDA.
    # Compute the mean on ``device`` and return it there directly.
    mean = torch.mean(embeddings[0], 0).detach()
    return mean