Fix multi-device error: run the embedding lookup on the device the embedding table actually lives on

This commit is contained in:
csunny 2023-04-30 14:47:33 +08:00
parent bb6c1865e1
commit e493f54804

View File

@ -81,8 +81,9 @@ def generate_output(model, tokenizer, params, device, context_len=2048):
@torch.inference_mode()
def get_embeddings(model, tokenizer, prompt):
    """Return the mean input embedding of *prompt* as a detached CPU tensor.

    Args:
        model: a model exposing ``get_input_embeddings()`` (HF-style).
        tokenizer: callable returning an object with an ``input_ids`` list.
        prompt: the text to embed.

    Returns:
        1-D ``torch.Tensor`` of size ``embedding_dim`` — the token-wise mean
        of the prompt's input embeddings, detached and on CPU.
    """
    input_ids = tokenizer(prompt).input_ids
    input_embeddings = model.get_input_embeddings()
    # Use the device the embedding table actually lives on, instead of
    # forcing cuda:0 via torch.cuda.is_available() — with model sharding
    # (device_map / multi-GPU) the table may sit on another device, and
    # mismatched devices is exactly the "multi device error" being fixed.
    device = input_embeddings.weight.device
    # Build the index tensor directly on the right device; torch.tensor
    # replaces the legacy torch.LongTensor constructor.
    token_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    embeddings = input_embeddings(token_tensor)
    # Detach and move to CPU once; the old code did .cpu() and then
    # .to(device) again — a wasted round-trip re-uploading the result.
    return torch.mean(embeddings[0], 0).detach().cpu()