Adapt code

This commit is contained in:
mudler
2023-06-01 14:37:14 +02:00
committed by AT
parent fca2578a81
commit 79cef86bec
6 changed files with 24 additions and 90 deletions

View File

@@ -2,14 +2,11 @@
#include "../../gpt4all-backend/llmodel.h"
#include "../../gpt4all-backend/llama.cpp/llama.h"
#include "../../gpt4all-backend/llmodel_c.cpp"
#include "../../gpt4all-backend/mpt.h"
#include "../../gpt4all-backend/mpt.cpp"
#include "../../gpt4all-backend/llamamodel.h"
#include "../../gpt4all-backend/gptj.h"
#include "binding.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
@@ -19,46 +16,24 @@
#include <iostream>
#include <unistd.h>
void* load_mpt_model(const char *fname, int n_threads) {
void* load_gpt4all_model(const char *fname, int n_threads) {
// load the model
auto gptj = llmodel_mpt_create();
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
auto gptj4all = llmodel_model_create(fname);
if (gptj4all == NULL ){
return nullptr;
}
llmodel_setThreadCount(gptj4all, n_threads);
if (!llmodel_loadModel(gptj4all, fname)) {
return nullptr;
}
return gptj;
}
void* load_llama_model(const char *fname, int n_threads) {
// load the model
auto gptj = llmodel_llama_create();
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
return nullptr;
}
return gptj;
}
void* load_gptj_model(const char *fname, int n_threads) {
// load the model
auto gptj = llmodel_gptj_create();
llmodel_setThreadCount(gptj, n_threads);
if (!llmodel_loadModel(gptj, fname)) {
return nullptr;
}
return gptj;
return gptj4all;
}
// Global accumulator string; presumably collects generated tokens during
// prompting — usage is not visible in this excerpt, TODO confirm.
std::string res = "";
// Global opaque pointer; presumably the last-used model handle — usage is
// not visible in this excerpt, TODO confirm. NOTE(review): file-scope
// mutable globals make this binding non-reentrant / not thread-safe.
void * mm;
void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
float top_p, float temp, int n_batch,float ctx_erase)
{
llmodel_model* model = (llmodel_model*) m;
@@ -120,8 +95,8 @@ void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_la
free(prompt_context);
}
void gptj_free_model(void *state_ptr) {
void gpt4all_free_model(void *state_ptr) {
llmodel_model* ctx = (llmodel_model*) state_ptr;
llmodel_llama_destroy(ctx);
llmodel_model_destroy(*ctx);
}