fork file replace import

csunny 2023-05-07 05:14:43 +08:00
parent 529f077409
commit 539e98f1dc
4 changed files with 144 additions and 16 deletions

View File

@@ -3,6 +3,71 @@
import torch

@torch.inference_mode()
def generate_stream(model, tokenizer, params, device,
                    context_len=2048, stream_interval=2):
    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    for i in range(max_new_tokens):
        if i == 0:
            out = model(
                torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            attention_mask = torch.ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device)
            out = model(input_ids=torch.as_tensor([[token]], device=device),
                        use_cache=True,
                        attention_mask=attention_mask,
                        past_key_values=past_key_values)
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                output = output[:pos]
                stopped = True
            yield output

        if stopped:
            break

    del past_key_values


@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=2048, stream_interval=2):
    prompt = params["prompt"]
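For reference, a minimal sketch of how the forked generate_stream generator might be driven end to end. It assumes generate_stream is importable from pilot.model.inference (the import the server hunk below switches to); the model name, prompt, and sampling parameters are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from pilot.model.inference import generate_stream

device = "cuda" if torch.cuda.is_available() else "cpu"
# "gpt2" is only a stand-in model for this sketch; the pilot server normally
# loads its configured model through ModerLoader instead.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

params = {
    "prompt": "Q: What is a vector database?\nA:",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stop": "\nQ:",
}

# The generator yields the full decoded text each time it grows, so a caller
# usually keeps only the most recent value.
output = ""
for output in generate_stream(model, tokenizer, params, device):
    pass
print(output)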

View File

@@ -5,6 +5,7 @@ import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
+    AutoModel
)

from fastchat.serve.compression import compress_module
@@ -23,20 +24,39 @@ class ModerLoader:
        "device_map": "auto",
    }

-    def loader(self, load_8bit=False, debug=False):
-        tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
-        model = AutoModelForCausalLM.from_pretrained(self.model_path, low_cpu_mem_usage=True, **self.kwargs)
+    def loader(self, num_gpus, load_8bit=False, debug=False):
+        if self.device == "cpu":
+            kwargs = {}
+        elif self.device == "cuda":
+            kwargs = {"torch_dtype": torch.float16}
+            if num_gpus == "auto":
+                kwargs["device_map"] = "auto"
+            else:
+                num_gpus = int(num_gpus)
+                if num_gpus != 1:
+                    kwargs.update({
+                        "device_map": "auto",
+                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
+                    })
+        else:
+            raise ValueError(f"Invalid device: {self.device}")
+
+        if "chatglm" in self.model_path:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
+            model = AutoModelForCausalLM.from_pretrained(self.model_path,
+                                                         low_cpu_mem_usage=True, **kwargs)
+
+        if load_8bit:
+            compress_module(model, self.device)
+
+        if (self.device == "cuda" and num_gpus == 1):
+            model.to(self.device)

        if debug:
            print(model)

-        if load_8bit:
-            compress_module(model, self.device)
-
-        # if self.device == "cuda":
-        #     model.to(self.device)
-
        return model, tokenizer
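A hypothetical call against the new loader signature, to show how num_gpus now drives device placement; the model path is a placeholder, not something defined in this diff.

from pilot.model.loader import ModerLoader

# Placeholder path; the real path comes from LLM_MODEL_CONFIG in the server below.
ml = ModerLoader(model_path="/models/vicuna-13b")
model, tokenizer = ml.loader(num_gpus=1, load_8bit=False, debug=True)

# With device == "cuda" and num_gpus == 1 the model is moved onto the single GPU
# explicitly; with num_gpus="auto" or a value greater than 1, placement is left
# to device_map="auto" (optionally capped by the per-GPU max_memory budget).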

View File

@@ -7,10 +7,12 @@ import json
from typing import Optional, List
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
-from fastchat.serve.inference import generate_stream
+from pilot.model.inference import generate_stream
from pydantic import BaseModel
from pilot.model.inference import generate_output, get_embeddings
from fastchat.serve.inference import load_model
from pilot.model.loader import ModerLoader
from pilot.configs.model_config import *

@@ -20,9 +22,9 @@ model_path = LLM_MODEL_CONFIG[LLM_MODEL]
global_counter = 0
model_semaphore = None

-# ml = ModerLoader(model_path=model_path)
-# model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
-model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+ml = ModerLoader(model_path=model_path)
+model, tokenizer = ml.loader(num_gpus=1, load_8bit=ISLOAD_8BIT, debug=ISDEBUG)
+#model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)

class ModelWorker:
    def __init__(self):
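The hunk above wires generate_stream and ModerLoader into the model server. Below is a rough sketch of how a streaming endpoint could sit on top of the module-level model, tokenizer, and DEVICE set up here; the route name, request shape, and chunk delimiter are assumptions for illustration, not necessarily what ModelWorker actually exposes.

import json

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

from pilot.model.inference import generate_stream

app = FastAPI()

@app.post("/generate_stream")              # hypothetical route name
async def api_generate_stream(request: Request):
    params = await request.json()          # e.g. {"prompt": ..., "temperature": ..., "stop": ...}

    def stream():
        # Each yielded chunk is the full decoded text so far; a NUL byte is
        # used here as an assumed chunk delimiter.
        for output in generate_stream(model, tokenizer, params, DEVICE):
            yield json.dumps({"text": output}).encode() + b"\0"

    return StreamingResponse(stream())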

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader

VECTOR_SEARCH_TOP_K = 5


class BaseKnownLedgeQA:

    llm: object = None
    embeddings: object = None
    top_k: int = VECTOR_SEARCH_TOP_K

    def __init__(self) -> None:
        pass

    def init_vector_store(self):
        pass

    def load_knownlege(self):
        pass

    def _load_file(self, filename):
        # Load the file and split it into documents.
        if filename.lower().endswith(".pdf"):
            loader = UnstructuredFileLoader(filename)
            text_splitor = CharacterTextSplitter()
            docs = loader.load_and_split(text_splitor)
        else:
            loader = UnstructuredFileLoader(filename, mode="elements")
            text_splitor = CharacterTextSplitter()
            docs = loader.load_and_split(text_splitor)
        return docs

    def _load_from_url(self, url):
        pass
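A short sketch of how the documents returned by _load_file could be indexed and queried with the imported Chroma store; the embedding model name and file path are assumptions, since this file leaves self.embeddings and the vector store unimplemented for now.

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

qa = BaseKnownLedgeQA()
docs = qa._load_file("docs/knowledge.pdf")     # placeholder path

# Assumed embedding model; the class above does not yet set its embeddings.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
store = Chroma.from_documents(docs, embeddings)

hits = store.similarity_search("How do I configure the vector store?", k=VECTOR_SEARCH_TOP_K)
for doc in hits:
    print(doc.page_content[:200])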