Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-18 00:07:45 +00:00)
Commit 7c69dc248a (parent 4381e21d0c)

    checkout load_model to fastchat

This commit switches the LLM server from DB-GPT's own ModerLoader to FastChat's load_model for loading the model and tokenizer.
pilot/model/inference.py
@@ -84,4 +84,5 @@ def get_embeddings(model, tokenizer, prompt):
     input_embeddings = model.get_input_embeddings()
     embeddings = input_embeddings(torch.LongTensor([input_ids]))
     mean = torch.mean(embeddings[0], 0).cpu().detach()
     return mean
+
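For context, get_embeddings mean-pools the model's input embeddings for a prompt. A minimal self-contained sketch of the function, assuming the tokenization step that the hunk does not show (the input_ids line is a reconstruction):

    import torch

    def get_embeddings(model, tokenizer, prompt):
        # Assumed: encode the prompt to token ids (not visible in the hunk).
        input_ids = tokenizer(prompt).input_ids
        # Look up the model's input embedding table and embed the ids.
        input_embeddings = model.get_input_embeddings()
        embeddings = input_embeddings(torch.LongTensor([input_ids]))
        # Mean-pool over the sequence dimension and detach to a CPU tensor.
        mean = torch.mean(embeddings[0], 0).cpu().detach()
        return mean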
@@ -1,17 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import uvicorn
 from typing import Optional, List
 from fastapi import FastAPI
 from pydantic import BaseModel
 from pilot.model.inference import generate_output, get_embeddings
+from fastchat.serve.inference import load_model
 from pilot.model.loader import ModerLoader
 from pilot.configs.model_config import *
 
 model_path = llm_model_config[LLM_MODEL]
-ml = ModerLoader(model_path=model_path)
-model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
+# ml = ModerLoader(model_path=model_path)
+# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
 
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
 app = FastAPI()
 
 class PromptRequest(BaseModel):
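The key change is the loading path: the commented-out ModerLoader calls give way to FastChat's loader, with 8-bit loading hard-coded on. A sketch of the new call in isolation, using the keyword arguments visible in the diff (the concrete path and device values here are placeholders):

    from fastchat.serve.inference import load_model

    # In the server, model_path comes from llm_model_config[LLM_MODEL] and
    # DEVICE from pilot.configs.model_config; the literals here are illustrative.
    model, tokenizer = load_model(
        model_path="/path/to/model",
        device="cuda",     # DEVICE in the diff
        num_gpus=1,
        load_8bit=True,    # 8-bit weights to cut GPU memory use
        debug=False,
    )

Loading in 8 bits roughly halves weight memory relative to fp16 at a small quality cost, which is presumably why it is enabled here.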
@@ -45,4 +48,9 @@ def embeddings(prompt_request: EmbeddingRequest):
     params = {"prompt": prompt_request.prompt}
     print("Received prompt: ", params["prompt"])
     output = get_embeddings(model, tokenizer, params["prompt"])
     return {"response": [float(x) for x in output]}
+
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", log_level="info")
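With uvicorn now started directly from the module, the embeddings endpoint can be exercised over HTTP. A hypothetical client call, assuming the route is mounted at /embeddings and uvicorn's default port 8000 (neither the route decorator nor the port appears in this diff):

    import requests

    resp = requests.post(
        "http://localhost:8000/embeddings",   # assumed route and port
        json={"prompt": "What is DB-GPT?"},
    )
    # The server returns the mean-pooled embedding as a list of floats.
    print(resp.json()["response"][:5])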