checkout load_model to fastchat

csunny 2023-04-30 14:21:58 +08:00
parent 4381e21d0c
commit 7c69dc248a
2 changed files with 13 additions and 4 deletions


@@ -84,4 +84,5 @@ def get_embeddings(model, tokenizer, prompt):
     input_embeddings = model.get_input_embeddings()
     embeddings = input_embeddings(torch.LongTensor([input_ids]))
     mean = torch.mean(embeddings[0], 0).cpu().detach()
     return mean
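
For context, get_embeddings mean-pools the prompt's input-token embeddings into a single vector. A minimal usage sketch, assuming a Hugging Face-style model and tokenizer are already loaded; the helper name here is illustrative and not part of the diff:

# Sketch only: mirrors the get_embeddings logic shown in the hunk above.
import torch

def mean_pooled_embedding(model, tokenizer, prompt):
    input_ids = tokenizer(prompt).input_ids           # token ids for the prompt
    input_embeddings = model.get_input_embeddings()   # embedding lookup table
    embeddings = input_embeddings(torch.LongTensor([input_ids]))
    return torch.mean(embeddings[0], 0).cpu().detach()  # mean over tokens -> 1-D vector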


@@ -1,17 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+import uvicorn
 from typing import Optional, List
 from fastapi import FastAPI
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
+from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 model_path = llm_model_config[LLM_MODEL]
-ml = ModerLoader(model_path=model_path)
-model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+# ml = ModerLoader(model_path=model_path)
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
 app = FastAPI()
 class PromptRequest(BaseModel):
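
The change above swaps the project's ModerLoader for fastchat's load_model, which returns a (model, tokenizer) pair ready for inference. As a rough sketch of what such a loader typically does under the hood (an illustration built on standard transformers calls, not fastchat's actual implementation):

# Illustrative sketch, not fastchat code: load a causal LM plus tokenizer, optionally in 8-bit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_sketch(model_path, device="cuda", load_8bit=False):
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        load_in_8bit=load_8bit,                       # requires bitsandbytes when True
        device_map="auto" if load_8bit else None,     # 8-bit weights are placed automatically
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    if not load_8bit:
        model.to(device)
    return model, tokenizer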
@@ -45,4 +48,9 @@ def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
     return {"response": [float(x) for x in output]}
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
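
With the __main__ block added, the module can be run directly and serves the FastAPI app via uvicorn (port 8000 by default, since none is passed). A hypothetical client call against the embeddings endpoint; the /embeddings route path is an assumption, as the route decorator sits outside this hunk:

# Hypothetical client; route path and port are assumptions, not shown in the diff.
import requests

resp = requests.post(
    "http://localhost:8000/embeddings",
    json={"prompt": "hello world"},
)
print(resp.json()["response"][:10])  # first few floats of the mean-pooled embedding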