diff --git a/gpt4all-bindings/golang/Makefile b/gpt4all-bindings/golang/Makefile
new file mode 100644
index 00000000..bdb70527
--- /dev/null
+++ b/gpt4all-bindings/golang/Makefile
@@ -0,0 +1,172 @@
+INCLUDE_PATH := $(abspath ./)
+LIBRARY_PATH := $(abspath ./)
+CMAKEFLAGS=
+
+ifndef UNAME_S
+UNAME_S := $(shell uname -s)
+endif
+
+ifndef UNAME_P
+UNAME_P := $(shell uname -p)
+endif
+
+ifndef UNAME_M
+UNAME_M := $(shell uname -m)
+endif
+
+CCV := $(shell $(CC) --version | head -n 1)
+CXXV := $(shell $(CXX) --version | head -n 1)
+
+# Mac OS + Arm can report x86_64
+# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
+ifeq ($(UNAME_S),Darwin)
+    ifneq ($(UNAME_P),arm)
+        SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
+        ifeq ($(SYSCTL_M),1)
+            # UNAME_P := arm
+            # UNAME_M := arm64
+            warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
+        endif
+    endif
+endif
+
+#
+# Compile flags
+#
+
+# keep standard at C11 and C++17
+CFLAGS = -I. -I../../gpt4all-backend/llama.cpp -I../../gpt4all-backend -O3 -DNDEBUG -std=c11 -fPIC
+CXXFLAGS = -I. 
-I../../gpt4all-backend/llama.cpp -I../../gpt4all-backend -O3 -DNDEBUG -std=c++17 -fPIC +LDFLAGS = + +# warnings +CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function +CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar + +# OS specific +# TODO: support Windows +ifeq ($(UNAME_S),Linux) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif +ifeq ($(UNAME_S),Darwin) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif +ifeq ($(UNAME_S),FreeBSD) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif +ifeq ($(UNAME_S),NetBSD) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif +ifeq ($(UNAME_S),OpenBSD) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif +ifeq ($(UNAME_S),Haiku) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif + +# Architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue +ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686)) + # Use all CPU extensions that are available: + CFLAGS += -march=native -mtune=native + CXXFLAGS += -march=native -mtune=native +endif +ifneq ($(filter ppc64%,$(UNAME_M)),) + POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) + ifneq (,$(findstring POWER9,$(POWER9_M))) + CFLAGS += -mcpu=power9 + CXXFLAGS += -mcpu=power9 + endif + # Require c++23's std::byteswap for big-endian support. + ifeq ($(UNAME_M),ppc64) + CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN + endif +endif +ifndef LLAMA_NO_ACCELERATE + # Mac M1 - include Accelerate framework. + # `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time). 
+    ifeq ($(UNAME_S),Darwin)
+        CFLAGS += -DGGML_USE_ACCELERATE
+        LDFLAGS += -framework Accelerate
+    endif
+endif
+ifdef LLAMA_OPENBLAS
+    CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
+    LDFLAGS += -lopenblas
+endif
+ifdef LLAMA_GPROF
+    CFLAGS += -pg
+    CXXFLAGS += -pg
+endif
+ifneq ($(filter aarch64%,$(UNAME_M)),)
+    CFLAGS += -mcpu=native
+    CXXFLAGS += -mcpu=native
+endif
+ifneq ($(filter armv6%,$(UNAME_M)),)
+    # Raspberry Pi 1, 2, 3
+    CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
+endif
+ifneq ($(filter armv7%,$(UNAME_M)),)
+    # Raspberry Pi 4
+    CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+endif
+ifneq ($(filter armv8%,$(UNAME_M)),)
+    # Raspberry Pi 4
+    CFLAGS += -mfp16-format=ieee -mno-unaligned-access
+endif
+
+#
+# Print build information
+#
+
+$(info I go-gpt4all build info: )
+$(info I UNAME_S: $(UNAME_S))
+$(info I UNAME_P: $(UNAME_P))
+$(info I UNAME_M: $(UNAME_M))
+$(info I CFLAGS: $(CFLAGS))
+$(info I CXXFLAGS: $(CXXFLAGS))
+$(info I LDFLAGS: $(LDFLAGS))
+$(info I CMAKEFLAGS: $(CMAKEFLAGS))
+$(info I CC: $(CCV))
+$(info I CXX: $(CXXV))
+$(info )
+
+llama.o:
+	mkdir -p buildllama
+	cd buildllama && cmake ../../../gpt4all-backend/llama.cpp $(CMAKEFLAGS) && $(MAKE) VERBOSE=1 llama.o && cp -rf CMakeFiles/llama.dir/llama.cpp.o ../llama.o
+
+llmodel.o:
+	mkdir -p buildllm
+	cd buildllm && cmake ../../../gpt4all-backend/ $(CMAKEFLAGS) && $(MAKE) VERBOSE=1 llmodel ggml common
+	cd buildllm && cp -rf CMakeFiles/llmodel.dir/llmodel_c.cpp.o ../llmodel.o
+	cd buildllm && cp -rfv CMakeFiles/llmodel.dir/llama.cpp/examples/common.cpp.o ../common.o
+	cd buildllm && cp -rf CMakeFiles/llmodel.dir/gptj.cpp.o ../gptj.o
+	cd buildllm && cp -rf CMakeFiles/llmodel.dir/llamamodel.cpp.o ../llamamodel.o
+	cd buildllm && cp -rf CMakeFiles/llmodel.dir/utils.cpp.o ../utils.o
+	cd buildllm && cp -rf llama.cpp/CMakeFiles/ggml.dir/ggml.c.o ../ggml.o
+
+clean:
+	rm -f *.o
+	rm -f *.a
+	rm -rf buildllm
+	rm -rf 
buildllama
+	rm -rf example/main
+
+binding.o:
+	$(CXX) $(CXXFLAGS) binding.cpp -o binding.o -c $(LDFLAGS)
+
+libgpt4all.a: binding.o llmodel.o llama.o
+	ar src libgpt4all.a ggml.o common.o llama.o llamamodel.o utils.o llmodel.o gptj.o binding.o
+
+test: libgpt4all.a
+	@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./...
+
+example/main: libgpt4all.a
+	C_INCLUDE_PATH=$(INCLUDE_PATH) LIBRARY_PATH=$(INCLUDE_PATH) go build -o example/main ./example/
diff --git a/gpt4all-bindings/golang/README.md b/gpt4all-bindings/golang/README.md
new file mode 100644
index 00000000..9ba346bd
--- /dev/null
+++ b/gpt4all-bindings/golang/README.md
@@ -0,0 +1,58 @@
+# GPT4All Golang bindings
+
+The golang bindings have been tested on:
+- MacOS
+- Linux
+
+### Usage
+
+```
+import (
+    gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
+)
+
+func main() {
+    // Load the model
+    model, err := gpt4all.New("model.bin", gpt4all.SetModelType(gpt4all.GPTJType))
+    if err != nil {
+        panic(err)
+    }
+    defer model.Free()
+
+    model.SetTokenCallback(func(s string) bool {
+        fmt.Print(s)
+        return true
+    })
+
+    _, err = model.Predict("Here are 4 steps to create a website:", gpt4all.SetTemperature(0.1))
+    if err != nil {
+        panic(err)
+    }
+}
+```
+
+## Building
+
+In order to use the bindings you will need to build `libgpt4all.a`:
+
+```
+git clone https://github.com/nomic-ai/gpt4all
+cd gpt4all/gpt4all-bindings/golang
+make libgpt4all.a
+```
+
+To use the bindings in your own software:
+
+- Import `github.com/nomic-ai/gpt4all/gpt4all-bindings/golang`;
+- Compile `libgpt4all.a` (you can use `make libgpt4all.a` in the bindings/golang directory);
+- Link your go binary against gpt4all by setting the environment variables `C_INCLUDE_PATH` and `LIBRARY_PATH` to point to the `binding.h` file directory and `libgpt4all.a` file directory respectively. 
+ +## Testing + +To run tests, run `make test`: + +``` +git clone https://github.com/nomic-ai/gpt4all +cd gpt4all/gpt4all-bindings/golang +make test +``` \ No newline at end of file diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp new file mode 100644 index 00000000..867117ef --- /dev/null +++ b/gpt4all-bindings/golang/binding.cpp @@ -0,0 +1,127 @@ +#include "../../gpt4all-backend/llmodel_c.h" +#include "../../gpt4all-backend/llmodel.h" +#include "../../gpt4all-backend/llama.cpp/llama.h" +#include "../../gpt4all-backend/llmodel_c.cpp" +#include "../../gpt4all-backend/mpt.h" +#include "../../gpt4all-backend/mpt.cpp" + +#include "../../gpt4all-backend/llamamodel.h" +#include "../../gpt4all-backend/gptj.h" +#include "binding.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void* load_mpt_model(const char *fname, int n_threads) { + // load the model + auto gptj = llmodel_mpt_create(); + + llmodel_setThreadCount(gptj, n_threads); + if (!llmodel_loadModel(gptj, fname)) { + return nullptr; + } + + return gptj; +} + +void* load_llama_model(const char *fname, int n_threads) { + // load the model + auto gptj = llmodel_llama_create(); + + llmodel_setThreadCount(gptj, n_threads); + if (!llmodel_loadModel(gptj, fname)) { + return nullptr; + } + + return gptj; +} + +void* load_gptj_model(const char *fname, int n_threads) { + // load the model + auto gptj = llmodel_gptj_create(); + + llmodel_setThreadCount(gptj, n_threads); + if (!llmodel_loadModel(gptj, fname)) { + return nullptr; + } + + return gptj; +} + +std::string res = ""; +void * mm; + +void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, + float top_p, float temp, int n_batch,float ctx_erase) +{ + llmodel_model* model = (llmodel_model*) m; + + // std::string res = ""; + + auto lambda_prompt = [](int token_id, const char *promptchars) { + 
return true; + }; + + mm=model; + res=""; + + auto lambda_response = [](int token_id, const char *responsechars) { + res.append((char*)responsechars); + return !!getTokenCallback(mm, (char*)responsechars); + }; + + auto lambda_recalculate = [](bool is_recalculating) { + // You can handle recalculation requests here if needed + return is_recalculating; + }; + + llmodel_prompt_context* prompt_context = new llmodel_prompt_context{ + .logits = NULL, + .logits_size = 0, + .tokens = NULL, + .tokens_size = 0, + .n_past = 0, + .n_ctx = 1024, + .n_predict = 50, + .top_k = 10, + .top_p = 0.9, + .temp = 1.0, + .n_batch = 1, + .repeat_penalty = 1.2, + .repeat_last_n = 10, + .context_erase = 0.5 + }; + + prompt_context->n_predict = tokens; + prompt_context->repeat_last_n = repeat_last_n; + prompt_context->repeat_penalty = repeat_penalty; + prompt_context->n_ctx = n_ctx; + prompt_context->top_k = top_k; + prompt_context->context_erase = ctx_erase; + prompt_context->top_p = top_p; + prompt_context->temp = temp; + prompt_context->n_batch = n_batch; + + llmodel_prompt(model, prompt, + lambda_prompt, + lambda_response, + lambda_recalculate, + prompt_context ); + + strcpy(result, res.c_str()); + + free(prompt_context); +} + +void gptj_free_model(void *state_ptr) { + llmodel_model* ctx = (llmodel_model*) state_ptr; + llmodel_llama_destroy(ctx); +} + diff --git a/gpt4all-bindings/golang/binding.h b/gpt4all-bindings/golang/binding.h new file mode 100644 index 00000000..6b49a03e --- /dev/null +++ b/gpt4all-bindings/golang/binding.h @@ -0,0 +1,22 @@ +#ifdef __cplusplus +extern "C" { +#endif + +#include + +void* load_mpt_model(const char *fname, int n_threads); + +void* load_llama_model(const char *fname, int n_threads); + +void* load_gptj_model(const char *fname, int n_threads); + +void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, + float top_p, float temp, int n_batch,float ctx_erase); + +void 
gptj_free_model(void *state_ptr); + +extern unsigned char getTokenCallback(void *, char *); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/gpt4all-bindings/golang/example/main.go b/gpt4all-bindings/golang/example/main.go new file mode 100644 index 00000000..f3a103a7 --- /dev/null +++ b/gpt4all-bindings/golang/example/main.go @@ -0,0 +1,82 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "os" + "runtime" + "strings" + + gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" +) + +var ( + threads = 4 + tokens = 128 +) + +func main() { + var model string + + flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flags.StringVar(&model, "m", "./models/7B/ggml-model-q4_0.bin", "path to q4_0.bin model file to load") + flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation") + flags.IntVar(&tokens, "n", 512, "number of tokens to predict") + + err := flags.Parse(os.Args[1:]) + if err != nil { + fmt.Printf("Parsing program arguments failed: %s", err) + os.Exit(1) + } + l, err := gpt4all.New(model, gpt4all.SetModelType(gpt4all.GPTJType), gpt4all.SetThreads(threads)) + if err != nil { + fmt.Println("Loading the model failed:", err.Error()) + os.Exit(1) + } + fmt.Printf("Model loaded successfully.\n") + + l.SetTokenCallback(func(token string) bool { + fmt.Print(token) + return true + }) + + reader := bufio.NewReader(os.Stdin) + + for { + text := readMultiLineInput(reader) + + _, err := l.Predict(text, gpt4all.SetTokens(tokens), gpt4all.SetTopK(90), gpt4all.SetTopP(0.86)) + if err != nil { + panic(err) + } + fmt.Printf("\n\n") + } +} + +// readMultiLineInput reads input until an empty line is entered. 
+func readMultiLineInput(reader *bufio.Reader) string { + var lines []string + fmt.Print(">>> ") + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + os.Exit(0) + } + fmt.Printf("Reading the prompt failed: %s", err) + os.Exit(1) + } + + if len(strings.TrimSpace(line)) == 0 { + break + } + + lines = append(lines, line) + } + + text := strings.Join(lines, "") + return text +} diff --git a/gpt4all-bindings/golang/go.mod b/gpt4all-bindings/golang/go.mod new file mode 100644 index 00000000..e45c3dad --- /dev/null +++ b/gpt4all-bindings/golang/go.mod @@ -0,0 +1,20 @@ +module github.com/nomic-ai/gpt4all/gpt4all-bindings/golang + +go 1.19 + +require ( + github.com/onsi/ginkgo/v2 v2.9.4 + github.com/onsi/gomega v1.27.6 +) + +require ( + github.com/go-logr/logr v1.2.4 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + golang.org/x/net v0.9.0 // indirect + golang.org/x/sys v0.7.0 // indirect + golang.org/x/text v0.9.0 // indirect + golang.org/x/tools v0.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/gpt4all-bindings/golang/go.sum b/gpt4all-bindings/golang/go.sum new file mode 100644 index 00000000..fa0bcd86 --- /dev/null +++ b/gpt4all-bindings/golang/go.sum @@ -0,0 +1,40 @@ +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
+github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= +github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= +golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gpt4all-bindings/golang/gpt4all.go b/gpt4all-bindings/golang/gpt4all.go new file mode 100644 index 00000000..2da5dee9 --- /dev/null +++ b/gpt4all-bindings/golang/gpt4all.go @@ -0,0 +1,113 @@ +package gpt4all + +// #cgo CFLAGS: -I../../gpt4all-backend/ -I../../gpt4all-backend/llama.cpp -I./ +// #cgo CXXFLAGS: -std=c++17 -I../../gpt4all-backend/ -I../../gpt4all-backend/llama.cpp -I./ +// #cgo darwin LDFLAGS: -framework Accelerate +// #cgo darwin CXXFLAGS: -std=c++17 +// #cgo LDFLAGS: -lgpt4all -lm -lstdc++ +// void* load_mpt_model(const char *fname, int n_threads); +// void* load_llama_model(const char *fname, int n_threads); +// void* load_gptj_model(const char *fname, int n_threads); +// void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, +// float top_p, float temp, int n_batch,float ctx_erase); +// void 
gptj_free_model(void *state_ptr); +// extern unsigned char getTokenCallback(void *, char *); +import "C" +import ( + "fmt" + "runtime" + "strings" + "sync" + "unsafe" +) + +// The following code is https://github.com/go-skynet/go-llama.cpp with small adaptations +type Model struct { + state unsafe.Pointer +} + +func New(model string, opts ...ModelOption) (*Model, error) { + ops := NewModelOptions(opts...) + var state unsafe.Pointer + + switch ops.ModelType { + case LLaMAType: + state = C.load_llama_model(C.CString(model), C.int(ops.Threads)) + case GPTJType: + state = C.load_gptj_model(C.CString(model), C.int(ops.Threads)) + case MPTType: + state = C.load_mpt_model(C.CString(model), C.int(ops.Threads)) + } + + if state == nil { + return nil, fmt.Errorf("failed loading model") + } + + gpt := &Model{state: state} + // set a finalizer to remove any callbacks when the struct is reclaimed by the garbage collector. + runtime.SetFinalizer(gpt, func(g *Model) { + setTokenCallback(g.state, nil) + }) + + return gpt, nil +} + +func (l *Model) Predict(text string, opts ...PredictOption) (string, error) { + + po := NewPredictOptions(opts...) 
+ + input := C.CString(text) + if po.Tokens == 0 { + po.Tokens = 99999999 + } + out := make([]byte, po.Tokens) + + C.gptj_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), + C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase)) + + res := C.GoString((*C.char)(unsafe.Pointer(&out[0]))) + res = strings.TrimPrefix(res, " ") + res = strings.TrimPrefix(res, text) + res = strings.TrimPrefix(res, "\n") + res = strings.TrimSuffix(res, "<|endoftext|>") + + return res, nil +} + +func (l *Model) Free() { + C.gptj_free_model(l.state) +} + +func (l *Model) SetTokenCallback(callback func(token string) bool) { + setTokenCallback(l.state, callback) +} + +var ( + m sync.Mutex + callbacks = map[uintptr]func(string) bool{} +) + +//export getTokenCallback +func getTokenCallback(statePtr unsafe.Pointer, token *C.char) bool { + m.Lock() + defer m.Unlock() + + if callback, ok := callbacks[uintptr(statePtr)]; ok { + return callback(C.GoString(token)) + } + + return true +} + +// setCallback can be used to register a token callback for LLama. Pass in a nil callback to +// remove the callback. +func setTokenCallback(statePtr unsafe.Pointer, callback func(string) bool) { + m.Lock() + defer m.Unlock() + + if callback == nil { + delete(callbacks, uintptr(statePtr)) + } else { + callbacks[uintptr(statePtr)] = callback + } +} diff --git a/gpt4all-bindings/golang/gpt4all_suite_test.go b/gpt4all-bindings/golang/gpt4all_suite_test.go new file mode 100644 index 00000000..3f379b1e --- /dev/null +++ b/gpt4all-bindings/golang/gpt4all_suite_test.go @@ -0,0 +1,13 @@ +package gpt4all_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestGPT(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "go-gpt4all-j test suite") +} diff --git a/gpt4all-bindings/golang/gpt4all_test.go b/gpt4all-bindings/golang/gpt4all_test.go new file mode 100644 index 00000000..1d99dd66 --- /dev/null +++ b/gpt4all-bindings/golang/gpt4all_test.go @@ -0,0 +1,27 @@ +package gpt4all_test + +import ( + . "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LLama binding", func() { + Context("Declaration", func() { + It("fails with no model", func() { + model, err := New("not-existing") + Expect(err).To(HaveOccurred()) + Expect(model).To(BeNil()) + }) + It("fails with no model", func() { + model, err := New("not-existing", SetModelType(MPTType)) + Expect(err).To(HaveOccurred()) + Expect(model).To(BeNil()) + }) + It("fails with no model", func() { + model, err := New("not-existing", SetModelType(LLaMAType)) + Expect(err).To(HaveOccurred()) + Expect(model).To(BeNil()) + }) + }) +}) diff --git a/gpt4all-bindings/golang/options.go b/gpt4all-bindings/golang/options.go new file mode 100644 index 00000000..573f9abc --- /dev/null +++ b/gpt4all-bindings/golang/options.go @@ -0,0 +1,127 @@ +package gpt4all + +type PredictOptions struct { + ContextSize, RepeatLastN, Tokens, TopK, Batch int + TopP, Temperature, ContextErase, RepeatPenalty float64 +} + +type PredictOption func(p *PredictOptions) + +var DefaultOptions PredictOptions = PredictOptions{ + Tokens: 200, + TopK: 10, + TopP: 0.90, + Temperature: 0.96, + Batch: 1, + ContextErase: 0.55, + ContextSize: 1024, + RepeatLastN: 10, + RepeatPenalty: 1.2, +} + +var DefaultModelOptions ModelOptions = ModelOptions{ + Threads: 4, + ModelType: GPTJType, +} + +type ModelOptions struct { + Threads int + ModelType ModelType +} +type ModelOption func(p *ModelOptions) + +type ModelType int + +const ( + LLaMAType ModelType = 0 + GPTJType ModelType = iota + MPTType ModelType 
= iota +) + +// SetTokens sets the number of tokens to generate. +func SetTokens(tokens int) PredictOption { + return func(p *PredictOptions) { + p.Tokens = tokens + } +} + +// SetTopK sets the value for top-K sampling. +func SetTopK(topk int) PredictOption { + return func(p *PredictOptions) { + p.TopK = topk + } +} + +// SetTopP sets the value for nucleus sampling. +func SetTopP(topp float64) PredictOption { + return func(p *PredictOptions) { + p.TopP = topp + } +} + +// SetRepeatPenalty sets the repeat penalty. +func SetRepeatPenalty(ce float64) PredictOption { + return func(p *PredictOptions) { + p.RepeatPenalty = ce + } +} + +// SetRepeatLastN sets the RepeatLastN. +func SetRepeatLastN(ce int) PredictOption { + return func(p *PredictOptions) { + p.RepeatLastN = ce + } +} + +// SetContextErase sets the context erase %. +func SetContextErase(ce float64) PredictOption { + return func(p *PredictOptions) { + p.ContextErase = ce + } +} + +// SetTemperature sets the temperature value for text generation. +func SetTemperature(temp float64) PredictOption { + return func(p *PredictOptions) { + p.Temperature = temp + } +} + +// SetBatch sets the batch size. +func SetBatch(size int) PredictOption { + return func(p *PredictOptions) { + p.Batch = size + } +} + +// Create a new PredictOptions object with the given options. +func NewPredictOptions(opts ...PredictOption) PredictOptions { + p := DefaultOptions + for _, opt := range opts { + opt(&p) + } + return p +} + +// SetThreads sets the number of threads to use for text generation. +func SetThreads(c int) ModelOption { + return func(p *ModelOptions) { + p.Threads = c + } +} + +// SetModelType sets the model type. +func SetModelType(c ModelType) ModelOption { + return func(p *ModelOptions) { + p.ModelType = c + } +} + +// Create a new PredictOptions object with the given options. +func NewModelOptions(opts ...ModelOption) ModelOptions { + p := DefaultModelOptions + for _, opt := range opts { + opt(&p) + } + return p +}