diff --git a/gpt4all-api/triton/README.md b/gpt4all-api/triton/README.md
new file mode 100644
index 00000000..21ba3c94
--- /dev/null
+++ b/gpt4all-api/triton/README.md
@@ -0,0 +1,13 @@
+# Running the Inference Server
+
+Convert a Hugging Face checkpoint into a Triton model repository (this writes the TorchScript model and its config into `model_store/`):
+
+python convert_to_triton.py --model=<hf-model-name>
+
+Start the Triton inference server, mounting the model repository:
+
+docker run --gpus=1 --rm --net=host -v ${PWD}/model_store:/model_store nvcr.io/nvidia/tritonserver:23.01-py3 tritonserver --model-repository=/model_store
+
+Query it with the interactive client, passing the same model name:
+
+python client.py --model=<hf-model-name>
\ No newline at end of file
diff --git a/gpt4all-api/triton/client.py b/gpt4all-api/triton/client.py
new file mode 100644
index 00000000..21e24b6b
--- /dev/null
+++ b/gpt4all-api/triton/client.py
@@ -0,0 +1,85 @@
+import torch
+import tritonclient.grpc.aio as grpcclient
+
+
+def prepare_inference_inputs(
+    input_ids: torch.IntTensor, new_tokens: int = 1, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0
+):
+    batch_size = input_ids.shape[0]
+
+    input_ids_input = grpcclient.InferInput("input_ids", input_ids.shape, "INT32")
+    input_ids_input.set_data_from_numpy(input_ids.int().cpu().numpy())
+
+    new_tokens_input = grpcclient.InferInput(
+        "tensor_of_seq_len", [batch_size, new_tokens], "INT32"
+    )
+    new_tokens_input.set_data_from_numpy(
+        torch.zeros(batch_size, new_tokens, dtype=torch.int32).cpu().numpy()
+    )
+
+    temperature_input = grpcclient.InferInput("temperature", [batch_size, 1], "FP32")
+    temperature_input.set_data_from_numpy(
+        torch.full([batch_size, 1], temperature, dtype=torch.float32).cpu().numpy()
+    )
+
+    top_k_input = grpcclient.InferInput("top_k", [batch_size, 1], "INT32")
+    top_k_input.set_data_from_numpy(
+        torch.full([batch_size, 1], top_k, dtype=torch.int32).cpu().numpy()
+    )
+
+    top_p_input = grpcclient.InferInput("top_p", [batch_size, 1], "FP32")
+    top_p_input.set_data_from_numpy(
+        torch.full([batch_size, 1], top_p, dtype=torch.float32).cpu().numpy()
+    )
+
+    inputs = [input_ids_input, new_tokens_input, temperature_input, top_k_input, top_p_input]
+    outputs = [
+        grpcclient.InferRequestedOutput("logits"),
+        grpcclient.InferRequestedOutput("output_ids"),
+    ]
+    return inputs, outputs
+
+
+async def infer(
+    triton_client, model_name, input_ids, new_tokens: int = 1, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0
+):
+    inputs, outputs = prepare_inference_inputs(input_ids, new_tokens, temperature, top_k, top_p)
+
+    triton_model_name = model_name.replace("/", "--")
+
+    result = await triton_client.infer(
+        model_name=triton_model_name, inputs=inputs, outputs=outputs
+    )
+
+    logits = torch.tensor(result.as_numpy("logits").copy(), requires_grad=False)
+    output_ids = torch.tensor(result.as_numpy("output_ids").copy(), requires_grad=False)
+
+    return logits, output_ids
+
+def Client(url: str):
+    return grpcclient.InferenceServerClient(url=url)
+
+if __name__ == "__main__":
+    import argparse
+    from transformers import AutoTokenizer
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="localhost:8001")
+    parser.add_argument("--model", type=str, default="gpt2")
+
+    args = parser.parse_args()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    async def main():
+        async with Client(args.url) as triton_client:
+            while True:
+                prompt = input("Prompt: ")
+                input_ids = tokenizer.encode(prompt, return_tensors="pt")
+                last_logits, output_ids = await infer(
+                    triton_client, args.model, input_ids, new_tokens=128, temperature=1.0, top_k=0, top_p=0.9,
+                )
+                print(tokenizer.decode(output_ids[0]))
+
+    import asyncio
+    asyncio.run(main())
\ No newline at end of file
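For reference, a minimal sketch of driving the helpers defined in client.py programmatically instead of through its interactive `__main__` loop. This is a sketch under stated assumptions: the Triton server from the README is already running on `localhost:8001`, the client's default `gpt2` checkpoint has already been converted into `model_store/`, and this file is importable as `client`.

```python
# Sketch only: uses the Client/infer helpers defined in client.py above.
# Assumes the Triton server is running on localhost:8001 and that "gpt2"
# (the client's default model) was already converted into model_store/.
import asyncio

from transformers import AutoTokenizer

from client import Client, infer


async def generate(prompt: str, model_name: str = "gpt2") -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    async with Client("localhost:8001") as triton_client:
        # new_tokens bounds how many sampling steps the scripted model runs;
        # temperature, top_k and top_p are sent as per-request tensors.
        _logits, output_ids = await infer(
            triton_client, model_name, input_ids,
            new_tokens=64, temperature=1.0, top_k=0, top_p=0.9,
        )
    return tokenizer.decode(output_ids[0])


if __name__ == "__main__":
    print(asyncio.run(generate("The capital of France is")))
```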
diff --git a/gpt4all-api/triton/convert_to_triton.py b/gpt4all-api/triton/convert_to_triton.py
new file mode 100644
index 00000000..c1e6560c
--- /dev/null
+++ b/gpt4all-api/triton/convert_to_triton.py
@@ -0,0 +1,158 @@
+import argparse
+import os
+from string import Template
+
+import torch
+from torch import nn
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from gpt4all.falcon.modelling_RW import RWForCausalLM
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--model", type=str, required=True, help="Path to HF checkpoint with the base model"
+)
+
+parser.add_argument(
+    "--max-batch-size", type=int, default=4, help="Maximum batch size for inference"
+)
+
+parser.add_argument(
+    "--revision",
+    type=str,
+    required=False,
+    help="Optional branch/commit of the HF checkpoint",
+)
+
+parser.add_argument("--device", type=int, default=0)
+args = parser.parse_args()
+
+device = torch.device(args.device)
+
+
+class ModelLogits(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    @torch.inference_mode()
+    def forward(self, input_ids: torch.Tensor):
+        return self.model(input_ids).logits
+
+
+class InferModel(nn.Module):
+    def __init__(self, traced_model, eos_token_id):
+        super().__init__()
+        self.traced_model = traced_model
+        self.eos_token_id = eos_token_id
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        tensor_of_seq_len: torch.Tensor,
+        temperature: torch.Tensor,
+        top_k: torch.Tensor,
+        top_p: torch.Tensor,
+    ):
+        with torch.no_grad():
+            for _ in range(tensor_of_seq_len.shape[1] - 1):
+                logits = self.traced_model(input_ids).float()
+                next_token_logits = logits[:, -1, :]
+                next_token_logits = next_token_logits / temperature
+
+                next_token_logits = self.top_k(next_token_logits, top_k)
+                next_token_logits = self.top_p(next_token_logits, top_p)
+
+                next_token = torch.multinomial(
+                    torch.softmax(next_token_logits, dim=-1), 1
+                ).squeeze(1)
+                # stop early once the EOS token is sampled
+                if next_token.item() == self.eos_token_id:
+                    return input_ids.int(), logits
+
+                input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
+
+            # in TorchScript, `logits` defined inside the loop does not escape its scope, so the final step is unrolled here
+            logits = self.traced_model(input_ids).float()
+            next_token_logits = logits[:, -1, :]
+            next_token_logits = next_token_logits / temperature
+
+            next_token_logits = self.top_k(next_token_logits, top_k)
+            next_token_logits = self.top_p(next_token_logits, top_p)
+
+            next_token = torch.multinomial(
+                torch.softmax(next_token_logits, dim=-1), 1
+            ).squeeze(1)
+
+            input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)
+
+            return input_ids.int(), logits
+
+    def top_p(self, scores: torch.Tensor, top_p: torch.Tensor):
+        if top_p.squeeze().item() >= 1.0:
+            return scores
+        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+        # remove tokens whose cumulative probability (from the tail up) is at most 1 - top_p, keeping the nucleus of mass >= top_p
+        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+
+        # scatter sorted tensors back to the original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        scores[indices_to_remove] = float("-inf")
+        return scores
+
+
+    def top_k(self, scores: torch.Tensor, top_k: torch.Tensor):
+        if top_k.squeeze().item() <= 0:
+            return scores
+        # remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = scores < torch.topk(scores, top_k.squeeze().item())[0][..., -1, None]
+        scores[indices_to_remove] = float("-inf")
+        return scores
+print(f"Converting {args.model} to TorchScript...") +tokenizer = AutoTokenizer.from_pretrained(args.model) +model = ModelLogits(AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True, revision=args.revision)) +model.eval() +model.requires_grad_(False) +model = model.half().to(device) + +input = tokenizer("annotator model's hash is 0x", return_tensors="pt").to(device) +print(f"{model(input.input_ids)=}") + +traced_script_module = torch.jit.trace(model, input.input_ids) + +print(f"{traced_script_module(input.input_ids)=}") + +print("Scripting generation wrapper...") + +# need to script this as we have data conditional flow +scripted_generator_model = torch.jit.script(InferModel(traced_script_module, tokenizer.eos_token_id)) +print(scripted_generator_model.code) + +print(f"{input.input_ids=}") +x = input.input_ids, torch.empty(1, 5), torch.full([1, 1], 1.0).cuda(), torch.full([1, 1], len(tokenizer) // 2).cuda(), torch.full([1, 1], 0.9).cuda() +# x = input.input_ids, torch.empty(1, 5), torch.full([1, 1], 1.0), torch.full([1, 1], len(tokenizer) // 2), torch.full([1, 1], 0.9) +# print(f"{(scripted_generator_model(*x))=}") +print(f"{tokenizer.decode(scripted_generator_model(*x)[0][0])=}") + +sanitized_name = args.model.replace("/", "--") +print("Model renamed to ", sanitized_name) + +print("Saving TorchScript model...") + +os.makedirs(f"model_store/{sanitized_name}/1", exist_ok=True) +scripted_generator_model.save(f"model_store/{sanitized_name}/1/traced-model.pt") + +config_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt" +) +with open(config_path) as f: + template = Template(f.read()) +config = template.substitute( + {"model_name": sanitized_name, "max_batch_size": args.max_batch_size} +) +with open(f"model_store/{sanitized_name}/config.pbtxt", "w") as f: + f.write(config) \ No newline at end of file diff --git a/gpt4all-api/triton/requirements.txt b/gpt4all-api/triton/requirements.txt new file mode 100644 index 00000000..450e71f3 --- /dev/null +++ b/gpt4all-api/triton/requirements.txt @@ -0,0 +1,5 @@ +transformers +triton +triton-client +einops +pandas \ No newline at end of file diff --git a/gpt4all-api/triton/triton_config.pbtxt b/gpt4all-api/triton/triton_config.pbtxt new file mode 100644 index 00000000..35417c61 --- /dev/null +++ b/gpt4all-api/triton/triton_config.pbtxt @@ -0,0 +1,78 @@ +name: "${model_name}" +backend: "pytorch" +default_model_filename: "traced-model.pt" +max_batch_size: ${max_batch_size} + +dynamic_batching { } + +parameters { + key: "model_name" + value: { + string_value: "${model_name}" + } +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + gpus: [0] + } +] + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "tensor_of_seq_len" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [-1] + }, + { + name: "top_k" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [-1] + } +] + +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "logits" + data_type: TYPE_FP32 + dims: [-1] + } +] + +parameters { + key: "data_type" + value: { + string_value: "fp16" + } +} + +parameters: { + key: "INFERENCE_MODE" + value: { + string_value: "true" + } +} + +version_policy: {specific: {versions: [1]}} \ No newline at end of file