// Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-11-03 23:47:16 +00:00.
// Go source file: 104 lines, 2.9 KiB.

package gpt4all
 | 
						|
 | 
						|
// #cgo CFLAGS: -I${SRCDIR}../../gpt4all-backend/ -I${SRCDIR}../../gpt4all-backend/llama.cpp -I./
// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR}../../gpt4all-backend/ -I${SRCDIR}../../gpt4all-backend/llama.cpp -I./
// #cgo darwin LDFLAGS: -framework Accelerate
// #cgo darwin CXXFLAGS: -std=c++17
// #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -ldl
// #include <stdlib.h>
// void* load_model(const char *fname, int n_threads);
// void model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k,
//                            float top_p, float temp, int n_batch,float ctx_erase);
// void free_model(void *state_ptr);
// extern unsigned char getTokenCallback(void *, char *);
import "C"

import (
	"fmt"
	"runtime"
	"strings"
	"sync"
	"unsafe"
)
 | 
						|
 | 
						|
// The following code is adapted, with small changes, from
// https://github.com/go-skynet/go-llama.cpp.

// Model is a handle to a model loaded by the native gpt4all backend.
// Create one with New and release its native memory with Free.
type Model struct {
	// state is the opaque pointer returned by the C load_model function.
	// Its address also keys this model's entry in the token-callback registry.
	state unsafe.Pointer
}
 | 
						|
 | 
						|
func New(model string, opts ...ModelOption) (*Model, error) {
 | 
						|
	ops := NewModelOptions(opts...)
 | 
						|
 | 
						|
	state := C.load_model(C.CString(model), C.int(ops.Threads))
 | 
						|
 | 
						|
	if state == nil {
 | 
						|
		return nil, fmt.Errorf("failed loading model")
 | 
						|
	}
 | 
						|
 | 
						|
	gpt := &Model{state: state}
 | 
						|
	// set a finalizer to remove any callbacks when the struct is reclaimed by the garbage collector.
 | 
						|
	runtime.SetFinalizer(gpt, func(g *Model) {
 | 
						|
		setTokenCallback(g.state, nil)
 | 
						|
	})
 | 
						|
 | 
						|
	return gpt, nil
 | 
						|
}
 | 
						|
 | 
						|
func (l *Model) Predict(text string, opts ...PredictOption) (string, error) {
 | 
						|
 | 
						|
	po := NewPredictOptions(opts...)
 | 
						|
 | 
						|
	input := C.CString(text)
 | 
						|
	if po.Tokens == 0 {
 | 
						|
		po.Tokens = 99999999
 | 
						|
	}
 | 
						|
	out := make([]byte, po.Tokens)
 | 
						|
 | 
						|
	C.model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize),
 | 
						|
		C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase))
 | 
						|
 | 
						|
	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
 | 
						|
	res = strings.TrimPrefix(res, " ")
 | 
						|
	res = strings.TrimPrefix(res, text)
 | 
						|
	res = strings.TrimPrefix(res, "\n")
 | 
						|
	res = strings.TrimSuffix(res, "<|endoftext|>")
 | 
						|
 | 
						|
	return res, nil
 | 
						|
}
 | 
						|
 | 
						|
func (l *Model) Free() {
 | 
						|
	C.free_model(l.state)
 | 
						|
}
 | 
						|
 | 
						|
func (l *Model) SetTokenCallback(callback func(token string) bool) {
 | 
						|
	setTokenCallback(l.state, callback)
 | 
						|
}
 | 
						|
 | 
						|
// Token-callback registry shared by every Model in the process.
var (
	// m guards callbacks; setTokenCallback and getTokenCallback both lock it.
	m sync.Mutex
	// callbacks maps a model's native state pointer (as a uintptr) to the Go
	// function that receives that model's tokens.
	callbacks = make(map[uintptr]func(string) bool)
)
 | 
						|
 | 
						|
//export getTokenCallback
 | 
						|
func getTokenCallback(statePtr unsafe.Pointer, token *C.char) bool {
 | 
						|
	m.Lock()
 | 
						|
	defer m.Unlock()
 | 
						|
 | 
						|
	if callback, ok := callbacks[uintptr(statePtr)]; ok {
 | 
						|
		return callback(C.GoString(token))
 | 
						|
	}
 | 
						|
 | 
						|
	return true
 | 
						|
}
 | 
						|
 | 
						|
// setCallback can be used to register a token callback for LLama. Pass in a nil callback to
 | 
						|
// remove the callback.
 | 
						|
func setTokenCallback(statePtr unsafe.Pointer, callback func(string) bool) {
 | 
						|
	m.Lock()
 | 
						|
	defer m.Unlock()
 | 
						|
 | 
						|
	if callback == nil {
 | 
						|
		delete(callbacks, uintptr(statePtr))
 | 
						|
	} else {
 | 
						|
		callbacks[uintptr(statePtr)] = callback
 | 
						|
	}
 | 
						|
}
 |