diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt index aab1e98d..cb1e2675 100644 --- a/gpt4all-backend/CMakeLists.txt +++ b/gpt4all-backend/CMakeLists.txt @@ -9,7 +9,9 @@ if(APPLE) set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE) else() # Build for the host architecture on macOS - set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE) + if(NOT CMAKE_OSX_ARCHITECTURES) + set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE) + endif() endif() endif() diff --git a/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj b/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj index 6fd881b0..9eb01e14 100644 --- a/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj +++ b/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj @@ -1,18 +1,31 @@ - - Exe - net7.0 - enable - enable - + + Exe + net7.0 + enable + enable + - - - + + + - - - + + + + + + + + + + + + + + + + diff --git a/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj b/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj index 56211651..a2918628 100644 --- a/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj +++ b/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj @@ -21,7 +21,24 @@ - + + + + + + + + + + + + + + + + + + diff --git a/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs b/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs index 6465c8df..19d91488 100644 --- a/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs +++ b/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs @@ -14,18 +14,18 @@ public class ModelFactoryTests [Fact] public void CanLoadLlamaModel() { - using var model = _modelFactory.LoadLlamaModel(Constants.LLAMA_MODEL_PATH); + using var model = _modelFactory.LoadModel(Constants.LLAMA_MODEL_PATH); } [Fact] public void CanLoadGptjModel() { - using var model = _modelFactory.LoadGptjModel(Constants.GPTJ_MODEL_PATH); + using var model = _modelFactory.LoadModel(Constants.GPTJ_MODEL_PATH); } [Fact] public void CanLoadMptModel() { - using var model = _modelFactory.LoadMptModel(Constants.MPT_MODEL_PATH); + using var model = _modelFactory.LoadModel(Constants.MPT_MODEL_PATH); } } diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs index 206b00cf..55defe09 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs @@ -1,247 +1,222 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; - -namespace Gpt4All.Bindings; - -/// -/// Arguments for the response processing callback -/// -/// The token id of the response -/// The response string. NOTE: a token_id of -1 indicates the string is an error string -/// -/// A bool indicating whether the model should keep generating -/// -public record ModelResponseEventArgs(int TokenId, string Response) -{ - public bool IsError => TokenId == -1; -} - -/// -/// Arguments for the prompt processing callback -/// -/// The token id of the prompt -/// -/// A bool indicating whether the model should keep processing -/// -public record ModelPromptEventArgs(int TokenId) -{ -} - -/// -/// Arguments for the recalculating callback -/// -/// whether the model is recalculating the context. 
-/// -/// A bool indicating whether the model should keep generating -/// -public record ModelRecalculatingEventArgs(bool IsRecalculating); - -/// -/// Base class and universal wrapper for GPT4All language models built around llmodel C-API. -/// -public class LLModel : ILLModel -{ - protected readonly IntPtr _handle; - private readonly ModelType _modelType; - private readonly ILogger _logger; - private bool _disposed; - - public ModelType ModelType => _modelType; - - internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null) - { - _handle = handle; - _modelType = modelType; - _logger = logger ?? NullLogger.Instance; - } - - /// - /// Create a new model from a pointer - /// - /// Pointer to underlying model - /// The model type - public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null) - { - return new LLModel(handle, modelType, logger: logger); - } - - /// - /// Generate a response using the model - /// - /// The input promp - /// The context - /// A callback function for handling the processing of prompt - /// A callback function for handling the generated response - /// A callback function for handling recalculation requests - /// - public void Prompt( - string text, - LLModelPromptContext context, - Func? promptCallback = null, - Func? responseCallback = null, - Func? recalculateCallback = null, - CancellationToken cancellationToken = default) - { - GC.KeepAlive(promptCallback); - GC.KeepAlive(responseCallback); - GC.KeepAlive(recalculateCallback); - GC.KeepAlive(cancellationToken); - - _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump()); - - NativeMethods.llmodel_prompt( - _handle, - text, - (tokenId) => - { - if (cancellationToken.IsCancellationRequested) return false; - if (promptCallback == null) return true; - var args = new ModelPromptEventArgs(tokenId); - return promptCallback(args); - }, - (tokenId, response) => - { - if (cancellationToken.IsCancellationRequested) - { - _logger.LogDebug("ResponseCallback evt=CancellationRequested"); - return false; - } - - if (responseCallback == null) return true; - var args = new ModelResponseEventArgs(tokenId, response); - return responseCallback(args); - }, - (isRecalculating) => - { - if (cancellationToken.IsCancellationRequested) return false; - if (recalculateCallback == null) return true; - var args = new ModelRecalculatingEventArgs(isRecalculating); - return recalculateCallback(args); - }, - ref context.UnderlyingContext - ); - } - - /// - /// Set the number of threads to be used by the model. - /// - /// The new thread count - public void SetThreadCount(int threadCount) - { - NativeMethods.llmodel_setThreadCount(_handle, threadCount); - } - - /// - /// Get the number of threads used by the model. - /// - /// the number of threads used by the model - public int GetThreadCount() - { - return NativeMethods.llmodel_threadCount(_handle); - } - - /// - /// Get the size of the internal state of the model. - /// - /// - /// This state data is specific to the type of model you have created. - /// - /// the size in bytes of the internal state of the model - public ulong GetStateSizeBytes() - { - return NativeMethods.llmodel_get_state_size(_handle); - } - - /// - /// Saves the internal state of the model to the specified destination address. 
- /// - /// A pointer to the src - /// The number of bytes copied - public unsafe ulong SaveStateData(byte* source) - { - return NativeMethods.llmodel_save_state_data(_handle, source); - } - - /// - /// Restores the internal state of the model using data from the specified address. - /// - /// A pointer to destination - /// the number of bytes read - public unsafe ulong RestoreStateData(byte* destination) - { - return NativeMethods.llmodel_restore_state_data(_handle, destination); - } - - /// - /// Check if the model is loaded. - /// - /// true if the model was loaded successfully, false otherwise. - public bool IsLoaded() - { - return NativeMethods.llmodel_isModelLoaded(_handle); - } - - /// - /// Load the model from a file. - /// - /// The path to the model file. - /// true if the model was loaded successfully, false otherwise. - public bool Load(string modelPath) - { - return NativeMethods.llmodel_loadModel(_handle, modelPath); - } - - protected void Destroy() - { - NativeMethods.llmodel_model_destroy(_handle); - } - - protected void DestroyLLama() - { - NativeMethods.llmodel_llama_destroy(_handle); - } - - protected void DestroyGptj() - { - NativeMethods.llmodel_gptj_destroy(_handle); - } - - protected void DestroyMtp() - { - NativeMethods.llmodel_mpt_destroy(_handle); - } - - protected virtual void Dispose(bool disposing) - { - if (_disposed) return; - - if (disposing) - { - // dispose managed state - } - - switch (_modelType) - { - case ModelType.LLAMA: - DestroyLLama(); - break; - case ModelType.GPTJ: - DestroyGptj(); - break; - case ModelType.MPT: - DestroyMtp(); - break; - default: - Destroy(); - break; - } - - _disposed = true; - } - - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } -} +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Gpt4All.Bindings; + +/// +/// Arguments for the response processing callback +/// +/// The token id of the response +/// The response string. NOTE: a token_id of -1 indicates the string is an error string +/// +/// A bool indicating whether the model should keep generating +/// +public record ModelResponseEventArgs(int TokenId, string Response) +{ + public bool IsError => TokenId == -1; +} + +/// +/// Arguments for the prompt processing callback +/// +/// The token id of the prompt +/// +/// A bool indicating whether the model should keep processing +/// +public record ModelPromptEventArgs(int TokenId) +{ +} + +/// +/// Arguments for the recalculating callback +/// +/// whether the model is recalculating the context. +/// +/// A bool indicating whether the model should keep generating +/// +public record ModelRecalculatingEventArgs(bool IsRecalculating); + +/// +/// Base class and universal wrapper for GPT4All language models built around llmodel C-API. +/// +public class LLModel : ILLModel +{ + protected readonly IntPtr _handle; + private readonly ModelType _modelType; + private readonly ILogger _logger; + private bool _disposed; + + public ModelType ModelType => _modelType; + + internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null) + { + _handle = handle; + _modelType = modelType; + _logger = logger ?? NullLogger.Instance; + } + + /// + /// Create a new model from a pointer + /// + /// Pointer to underlying model + /// The model type + public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? 
logger = null) + { + return new LLModel(handle, modelType, logger: logger); + } + + /// + /// Generate a response using the model + /// + /// The input promp + /// The context + /// A callback function for handling the processing of prompt + /// A callback function for handling the generated response + /// A callback function for handling recalculation requests + /// + public void Prompt( + string text, + LLModelPromptContext context, + Func? promptCallback = null, + Func? responseCallback = null, + Func? recalculateCallback = null, + CancellationToken cancellationToken = default) + { + GC.KeepAlive(promptCallback); + GC.KeepAlive(responseCallback); + GC.KeepAlive(recalculateCallback); + GC.KeepAlive(cancellationToken); + + _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump()); + + NativeMethods.llmodel_prompt( + _handle, + text, + (tokenId) => + { + if (cancellationToken.IsCancellationRequested) return false; + if (promptCallback == null) return true; + var args = new ModelPromptEventArgs(tokenId); + return promptCallback(args); + }, + (tokenId, response) => + { + if (cancellationToken.IsCancellationRequested) + { + _logger.LogDebug("ResponseCallback evt=CancellationRequested"); + return false; + } + + if (responseCallback == null) return true; + var args = new ModelResponseEventArgs(tokenId, response); + return responseCallback(args); + }, + (isRecalculating) => + { + if (cancellationToken.IsCancellationRequested) return false; + if (recalculateCallback == null) return true; + var args = new ModelRecalculatingEventArgs(isRecalculating); + return recalculateCallback(args); + }, + ref context.UnderlyingContext + ); + } + + /// + /// Set the number of threads to be used by the model. + /// + /// The new thread count + public void SetThreadCount(int threadCount) + { + NativeMethods.llmodel_setThreadCount(_handle, threadCount); + } + + /// + /// Get the number of threads used by the model. + /// + /// the number of threads used by the model + public int GetThreadCount() + { + return NativeMethods.llmodel_threadCount(_handle); + } + + /// + /// Get the size of the internal state of the model. + /// + /// + /// This state data is specific to the type of model you have created. + /// + /// the size in bytes of the internal state of the model + public ulong GetStateSizeBytes() + { + return NativeMethods.llmodel_get_state_size(_handle); + } + + /// + /// Saves the internal state of the model to the specified destination address. + /// + /// A pointer to the src + /// The number of bytes copied + public unsafe ulong SaveStateData(byte* source) + { + return NativeMethods.llmodel_save_state_data(_handle, source); + } + + /// + /// Restores the internal state of the model using data from the specified address. + /// + /// A pointer to destination + /// the number of bytes read + public unsafe ulong RestoreStateData(byte* destination) + { + return NativeMethods.llmodel_restore_state_data(_handle, destination); + } + + /// + /// Check if the model is loaded. + /// + /// true if the model was loaded successfully, false otherwise. + public bool IsLoaded() + { + return NativeMethods.llmodel_isModelLoaded(_handle); + } + + /// + /// Load the model from a file. + /// + /// The path to the model file. + /// true if the model was loaded successfully, false otherwise. 
+ public bool Load(string modelPath) + { + return NativeMethods.llmodel_loadModel(_handle, modelPath); + } + + protected void Destroy() + { + NativeMethods.llmodel_model_destroy(_handle); + } + protected virtual void Dispose(bool disposing) + { + if (_disposed) return; + + if (disposing) + { + // dispose managed state + } + + switch (_modelType) + { + default: + Destroy(); + break; + } + + _disposed = true; + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs index eeec504d..cec6948e 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs @@ -1,138 +1,138 @@ -namespace Gpt4All.Bindings; - -/// -/// Wrapper around the llmodel_prompt_context structure for holding the prompt context. -/// -/// -/// The implementation takes care of all the memory handling of the raw logits pointer and the -/// raw tokens pointer.Attempting to resize them or modify them in any way can lead to undefined behavior -/// -public unsafe class LLModelPromptContext -{ - private llmodel_prompt_context _ctx; - - internal ref llmodel_prompt_context UnderlyingContext => ref _ctx; - - public LLModelPromptContext() - { - _ctx = new(); - } - - /// - /// logits of current context - /// - public Span Logits => new(_ctx.logits, (int)_ctx.logits_size); - - /// - /// the size of the raw logits vector - /// - public nuint LogitsSize - { - get => _ctx.logits_size; - set => _ctx.logits_size = value; - } - - /// - /// current tokens in the context window - /// - public Span Tokens => new(_ctx.tokens, (int)_ctx.tokens_size); - - /// - /// the size of the raw tokens vector - /// - public nuint TokensSize - { - get => _ctx.tokens_size; - set => _ctx.tokens_size = value; - } - - /// - /// top k logits to sample from - /// - public int TopK - { - get => _ctx.top_k; - set => _ctx.top_k = value; - } - - /// - /// nucleus sampling probability threshold - /// - public float TopP - { - get => _ctx.top_p; - set => _ctx.top_p = value; - } - - /// - /// temperature to adjust model's output distribution - /// - public float Temperature - { - get => _ctx.temp; - set => _ctx.temp = value; - } - - /// - /// number of tokens in past conversation - /// - public int PastNum - { - get => _ctx.n_past; - set => _ctx.n_past = value; - } - - /// - /// number of predictions to generate in parallel - /// - public int Batches - { - get => _ctx.n_batch; - set => _ctx.n_batch = value; - } - - /// - /// number of tokens to predict - /// - public int TokensToPredict - { - get => _ctx.n_predict; - set => _ctx.n_predict = value; - } - - /// - /// penalty factor for repeated tokens - /// - public float RepeatPenalty - { - get => _ctx.repeat_penalty; - set => _ctx.repeat_penalty = value; - } - - /// - /// last n tokens to penalize - /// - public int RepeatLastN - { - get => _ctx.repeat_last_n; - set => _ctx.repeat_last_n = value; - } - - /// - /// number of tokens possible in context window - /// - public int ContextSize - { - get => _ctx.n_ctx; - set => _ctx.n_ctx = value; - } - - /// - /// percent of context to erase if we exceed the context window - /// - public float ContextErase - { - get => _ctx.context_erase; - set => _ctx.context_erase = value; - } -} +namespace Gpt4All.Bindings; + +/// +/// Wrapper around the llmodel_prompt_context structure for holding the prompt context. 
+/// +/// +/// The implementation takes care of all the memory handling of the raw logits pointer and the +/// raw tokens pointer.Attempting to resize them or modify them in any way can lead to undefined behavior +/// +public unsafe class LLModelPromptContext +{ + private llmodel_prompt_context _ctx; + + internal ref llmodel_prompt_context UnderlyingContext => ref _ctx; + + public LLModelPromptContext() + { + _ctx = new(); + } + + /// + /// logits of current context + /// + public Span Logits => new(_ctx.logits, (int)_ctx.logits_size); + + /// + /// the size of the raw logits vector + /// + public nuint LogitsSize + { + get => _ctx.logits_size; + set => _ctx.logits_size = value; + } + + /// + /// current tokens in the context window + /// + public Span Tokens => new(_ctx.tokens, (int)_ctx.tokens_size); + + /// + /// the size of the raw tokens vector + /// + public nuint TokensSize + { + get => _ctx.tokens_size; + set => _ctx.tokens_size = value; + } + + /// + /// top k logits to sample from + /// + public int TopK + { + get => _ctx.top_k; + set => _ctx.top_k = value; + } + + /// + /// nucleus sampling probability threshold + /// + public float TopP + { + get => _ctx.top_p; + set => _ctx.top_p = value; + } + + /// + /// temperature to adjust model's output distribution + /// + public float Temperature + { + get => _ctx.temp; + set => _ctx.temp = value; + } + + /// + /// number of tokens in past conversation + /// + public int PastNum + { + get => _ctx.n_past; + set => _ctx.n_past = value; + } + + /// + /// number of predictions to generate in parallel + /// + public int Batches + { + get => _ctx.n_batch; + set => _ctx.n_batch = value; + } + + /// + /// number of tokens to predict + /// + public int TokensToPredict + { + get => _ctx.n_predict; + set => _ctx.n_predict = value; + } + + /// + /// penalty factor for repeated tokens + /// + public float RepeatPenalty + { + get => _ctx.repeat_penalty; + set => _ctx.repeat_penalty = value; + } + + /// + /// last n tokens to penalize + /// + public int RepeatLastN + { + get => _ctx.repeat_last_n; + set => _ctx.repeat_last_n = value; + } + + /// + /// number of tokens possible in context window + /// + public int ContextSize + { + get => _ctx.n_ctx; + set => _ctx.n_ctx = value; + } + + /// + /// percent of context to erase if we exceed the context window + /// + public float ContextErase + { + get => _ctx.context_erase; + set => _ctx.context_erase = value; + } +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs index c77212ca..0a606009 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs @@ -1,126 +1,107 @@ -using System.Runtime.InteropServices; - -namespace Gpt4All.Bindings; - -public unsafe partial struct llmodel_prompt_context -{ - public float* logits; - - [NativeTypeName("size_t")] - public nuint logits_size; - - [NativeTypeName("int32_t *")] - public int* tokens; - - [NativeTypeName("size_t")] - public nuint tokens_size; - - [NativeTypeName("int32_t")] - public int n_past; - - [NativeTypeName("int32_t")] - public int n_ctx; - - [NativeTypeName("int32_t")] - public int n_predict; - - [NativeTypeName("int32_t")] - public int top_k; - - public float top_p; - - public float temp; - - [NativeTypeName("int32_t")] - public int n_batch; - - public float repeat_penalty; - - [NativeTypeName("int32_t")] - public int repeat_last_n; - - public float context_erase; -} - -internal static unsafe 
partial class NativeMethods -{ - [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.I1)] - public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response); - - [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.I1)] - public delegate bool LlmodelPromptCallback(int token_id); - - [UnmanagedFunctionPointer(CallingConvention.Cdecl)] - [return: MarshalAs(UnmanagedType.I1)] - public delegate bool LlmodelRecalculateCallback(bool isRecalculating); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("llmodel_model")] - public static extern IntPtr llmodel_gptj_create(); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("llmodel_model")] - public static extern IntPtr llmodel_mpt_create(); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("llmodel_model")] - public static extern IntPtr llmodel_llama_create(); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)] - [return: NativeTypeName("llmodel_model")] - public static extern IntPtr llmodel_model_create( - [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)] - [return: MarshalAs(UnmanagedType.I1)] - public static extern bool llmodel_loadModel( - [NativeTypeName("llmodel_model")] IntPtr model, - [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - - [return: MarshalAs(UnmanagedType.I1)] - public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("uint64_t")] - public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("uint64_t")] - public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("uint64_t")] - public static 
extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)] - public static extern void llmodel_prompt( - [NativeTypeName("llmodel_model")] IntPtr model, - [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt, - LlmodelPromptCallback prompt_callback, - LlmodelResponseCallback response_callback, - LlmodelRecalculateCallback recalculate_callback, - ref llmodel_prompt_context ctx); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads); - - [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - [return: NativeTypeName("int32_t")] - public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model); -} +using System.Runtime.InteropServices; + +namespace Gpt4All.Bindings; + +public unsafe partial struct llmodel_prompt_context +{ + public float* logits; + + [NativeTypeName("size_t")] + public nuint logits_size; + + [NativeTypeName("int32_t *")] + public int* tokens; + + [NativeTypeName("size_t")] + public nuint tokens_size; + + [NativeTypeName("int32_t")] + public int n_past; + + [NativeTypeName("int32_t")] + public int n_ctx; + + [NativeTypeName("int32_t")] + public int n_predict; + + [NativeTypeName("int32_t")] + public int top_k; + + public float top_p; + + public float temp; + + [NativeTypeName("int32_t")] + public int n_batch; + + public float repeat_penalty; + + [NativeTypeName("int32_t")] + public int repeat_last_n; + + public float context_erase; +} + +internal static unsafe partial class NativeMethods +{ + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + public delegate bool LlmodelPromptCallback(int token_id); + + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + [return: MarshalAs(UnmanagedType.I1)] + public delegate bool LlmodelRecalculateCallback(bool isRecalculating); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)] + [return: NativeTypeName("llmodel_model")] + public static extern IntPtr llmodel_model_create2( + [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path, + [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string build_variant, + out IntPtr error); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)] + [return: MarshalAs(UnmanagedType.I1)] + public static extern bool llmodel_loadModel( + [NativeTypeName("llmodel_model")] IntPtr model, + [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path); + + [DllImport("libllmodel", CallingConvention = 
CallingConvention.Cdecl, ExactSpelling = true)] + + [return: MarshalAs(UnmanagedType.I1)] + public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + [return: NativeTypeName("uint64_t")] + public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + [return: NativeTypeName("uint64_t")] + public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + [return: NativeTypeName("uint64_t")] + public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)] + public static extern void llmodel_prompt( + [NativeTypeName("llmodel_model")] IntPtr model, + [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt, + LlmodelPromptCallback prompt_callback, + LlmodelResponseCallback response_callback, + LlmodelRecalculateCallback recalculate_callback, + ref llmodel_prompt_context ctx); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads); + + [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + [return: NativeTypeName("int32_t")] + public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model); +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj index db6780fe..72885512 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj +++ b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj @@ -1,27 +1,11 @@  - - - net6.0 - enable - enable - true - - - - - - - - - - - - - - - - - - - + + net6.0 + enable + enable + true + + + + diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/ILibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/ILibraryLoader.cs new file mode 100644 index 00000000..c4e462f8 --- /dev/null +++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/ILibraryLoader.cs @@ -0,0 +1,6 @@ +namespace Gpt4All.LibraryLoader; + +public interface ILibraryLoader +{ + LoadResult OpenLibrary(string? fileName); +} diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LinuxLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LinuxLibraryLoader.cs new file mode 100644 index 00000000..d7f6834a --- /dev/null +++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LinuxLibraryLoader.cs @@ -0,0 +1,53 @@ +using System.Runtime.InteropServices; + +namespace Gpt4All.LibraryLoader; + +internal class LinuxLibraryLoader : ILibraryLoader +{ +#pragma warning disable CA2101 + [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")] +#pragma warning restore CA2101 + public static extern IntPtr NativeOpenLibraryLibdl(string? 
filename, int flags); + +#pragma warning disable CA2101 + [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")] +#pragma warning restore CA2101 + public static extern IntPtr NativeOpenLibraryLibdl2(string? filename, int flags); + + [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")] + public static extern IntPtr GetLoadError(); + + [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")] + public static extern IntPtr GetLoadError2(); + + public LoadResult OpenLibrary(string? fileName) + { + IntPtr loadedLib; + try + { + // open with rtls lazy flag + loadedLib = NativeOpenLibraryLibdl2(fileName, 0x00001); + } + catch (DllNotFoundException) + { + loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001); + } + + if (loadedLib == IntPtr.Zero) + { + string errorMessage; + try + { + errorMessage = Marshal.PtrToStringAnsi(GetLoadError2()) ?? "Unknown error"; + } + catch (DllNotFoundException) + { + errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error"; + } + + return LoadResult.Failure(errorMessage); + } + + return LoadResult.Success; + } +} diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs new file mode 100644 index 00000000..3dccf358 --- /dev/null +++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs @@ -0,0 +1,20 @@ +namespace Gpt4All.LibraryLoader; + +public class LoadResult +{ + private LoadResult(bool isSuccess, string? errorMessage) + { + IsSuccess = isSuccess; + ErrorMessage = errorMessage; + } + + public static LoadResult Success { get; } = new(true, null); + + public static LoadResult Failure(string errorMessage) + { + return new(false, errorMessage); + } + + public bool IsSuccess { get; } + public string? ErrorMessage { get; } +} diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/MacOsLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/MacOsLibraryLoader.cs new file mode 100644 index 00000000..6577d979 --- /dev/null +++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/MacOsLibraryLoader.cs @@ -0,0 +1,28 @@ +using System.Runtime.InteropServices; + +namespace Gpt4All.LibraryLoader; + +internal class MacOsLibraryLoader : ILibraryLoader +{ +#pragma warning disable CA2101 + [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")] +#pragma warning restore CA2101 + public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags); + + [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")] + public static extern IntPtr GetLoadError(); + + public LoadResult OpenLibrary(string? fileName) + { + var loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001); + + if (loadedLib == IntPtr.Zero) + { + var errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? 
"Unknown error"; + + return LoadResult.Failure(errorMessage); + } + + return LoadResult.Success; + } +} diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/NativeLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/NativeLibraryLoader.cs new file mode 100644 index 00000000..85353738 --- /dev/null +++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/NativeLibraryLoader.cs @@ -0,0 +1,81 @@ +#if !IOS && !MACCATALYST && !TVOS && !ANDROID +using System.Runtime.InteropServices; +#endif + +namespace Gpt4All.LibraryLoader; + +public static class NativeLibraryLoader +{ + private static ILibraryLoader? defaultLibraryLoader; + + /// + /// Sets the library loader used to load the native libraries. Overwrite this only if you want some custom loading. + /// + /// The library loader to be used. + public static void SetLibraryLoader(ILibraryLoader libraryLoader) + { + defaultLibraryLoader = libraryLoader; + } + + internal static LoadResult LoadNativeLibrary(string? path = default, bool bypassLoading = true) + { + // If the user has handled loading the library themselves, we don't need to do anything. + if (bypassLoading) + { + return LoadResult.Success; + } + + var architecture = RuntimeInformation.OSArchitecture switch + { + Architecture.X64 => "x64", + Architecture.X86 => "x86", + Architecture.Arm => "arm", + Architecture.Arm64 => "arm64", + _ => throw new PlatformNotSupportedException( + $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}") + }; + + var (platform, extension) = Environment.OSVersion.Platform switch + { + _ when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ("win", "dll"), + _ when RuntimeInformation.IsOSPlatform(OSPlatform.Linux) => ("linux", "so"), + _ when RuntimeInformation.IsOSPlatform(OSPlatform.OSX) => ("osx", "dylib"), + _ => throw new PlatformNotSupportedException( + $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}") + }; + + // If the user hasn't set the path, we'll try to find it ourselves. + if (string.IsNullOrEmpty(path)) + { + var libraryName = "libllmodel"; + var assemblySearchPath = new[] + { + AppDomain.CurrentDomain.RelativeSearchPath, + Path.GetDirectoryName(typeof(NativeLibraryLoader).Assembly.Location), + Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]) + }.FirstOrDefault(it => !string.IsNullOrEmpty(it)); + // Search for the library dll within the assembly search path. If it doesn't exist, for whatever reason, use the default path. + path = Directory.EnumerateFiles(assemblySearchPath ?? string.Empty, $"{libraryName}.{extension}", SearchOption.AllDirectories).FirstOrDefault() ?? Path.Combine("runtimes", $"{platform}-{architecture}", $"{libraryName}.{extension}"); + } + + if (defaultLibraryLoader != null) + { + return defaultLibraryLoader.OpenLibrary(path); + } + + if (!File.Exists(path)) + { + throw new FileNotFoundException($"Native Library not found in path {path}. 
" + + $"Verify you have have included the native Gpt4All library in your application."); + } + + ILibraryLoader libraryLoader = platform switch + { + "win" => new WindowsLibraryLoader(), + "osx" => new MacOsLibraryLoader(), + "linux" => new LinuxLibraryLoader(), + _ => throw new PlatformNotSupportedException($"Currently {platform} platform is not supported") + }; + return libraryLoader.OpenLibrary(path); + } +} diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/WindowsLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/WindowsLibraryLoader.cs new file mode 100644 index 00000000..d2479aa4 --- /dev/null +++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/WindowsLibraryLoader.cs @@ -0,0 +1,24 @@ +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Gpt4All.LibraryLoader; + +internal class WindowsLibraryLoader : ILibraryLoader +{ + public LoadResult OpenLibrary(string? fileName) + { + var loadedLib = LoadLibrary(fileName); + + if (loadedLib == IntPtr.Zero) + { + var errorCode = Marshal.GetLastWin32Error(); + var errorMessage = new Win32Exception(errorCode).Message; + return LoadResult.Failure(errorMessage); + } + + return LoadResult.Success; + } + + [DllImport("kernel32", SetLastError = true, CharSet = CharSet.Auto)] + private static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPWStr)] string? lpFileName); +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs index 3c36ac26..02c5c588 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs @@ -1,61 +1,58 @@ -using System.Diagnostics; -using Microsoft.Extensions.Logging; -using Gpt4All.Bindings; -using Microsoft.Extensions.Logging.Abstractions; - -namespace Gpt4All; - -public class Gpt4AllModelFactory : IGpt4AllModelFactory -{ - private readonly ILoggerFactory _loggerFactory; - private readonly ILogger _logger; - - public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null) - { - _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; - _logger = _loggerFactory.CreateLogger(); - } - - private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null) - { - var modelType_ = modelType ?? 
ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath); - - _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_); - - var handle = modelType_ switch - { - ModelType.LLAMA => NativeMethods.llmodel_llama_create(), - ModelType.GPTJ => NativeMethods.llmodel_gptj_create(), - ModelType.MPT => NativeMethods.llmodel_mpt_create(), - _ => NativeMethods.llmodel_model_create(modelPath), - }; - - _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle); - _logger.LogInformation("Model loading started"); - - var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath); - - _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully); - - if (loadedSuccessfully == false) - { - throw new Exception($"Failed to load model: '{modelPath}'"); - } - - var logger = _loggerFactory.CreateLogger(); - - var underlyingModel = LLModel.Create(handle, modelType_, logger: logger); - - Debug.Assert(underlyingModel.IsLoaded()); - - return new Gpt4All(underlyingModel, logger: logger); - } - - public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null); - - public IGpt4AllModel LoadMptModel(string modelPath) => CreateModel(modelPath, ModelType.MPT); - - public IGpt4AllModel LoadGptjModel(string modelPath) => CreateModel(modelPath, ModelType.GPTJ); - - public IGpt4AllModel LoadLlamaModel(string modelPath) => CreateModel(modelPath, ModelType.LLAMA); -} +using System.Diagnostics; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging; +using Gpt4All.Bindings; +using Gpt4All.LibraryLoader; + +namespace Gpt4All; + +public class Gpt4AllModelFactory : IGpt4AllModelFactory +{ + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; + private static bool bypassLoading; + private static string? libraryPath; + + private static readonly Lazy libraryLoaded = new(() => + { + return NativeLibraryLoader.LoadNativeLibrary(Gpt4AllModelFactory.libraryPath, Gpt4AllModelFactory.bypassLoading); + }, true); + + public Gpt4AllModelFactory(string? libraryPath = default, bool bypassLoading = true, ILoggerFactory? loggerFactory = null) + { + _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + _logger = _loggerFactory.CreateLogger(); + Gpt4AllModelFactory.libraryPath = libraryPath; + Gpt4AllModelFactory.bypassLoading = bypassLoading; + + if (!libraryLoaded.Value.IsSuccess) + { + throw new Exception($"Failed to load native gpt4all library. 
Error: {libraryLoaded.Value.ErrorMessage}"); + } + } + + private IGpt4AllModel CreateModel(string modelPath) + { + var modelType_ = ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath); + _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_); + IntPtr error; + var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error); + _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle); + _logger.LogInformation("Model loading started"); + var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath); + _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully); + if (!loadedSuccessfully) + { + throw new Exception($"Failed to load model: '{modelPath}'"); + } + + var logger = _loggerFactory.CreateLogger(); + var underlyingModel = LLModel.Create(handle, modelType_, logger: logger); + + Debug.Assert(underlyingModel.IsLoaded()); + + return new Gpt4All(underlyingModel, logger: logger); + } + + public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath); +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs index 3a5208aa..90c54d5e 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs @@ -1,12 +1,6 @@ -namespace Gpt4All; - -public interface IGpt4AllModelFactory -{ - IGpt4AllModel LoadGptjModel(string modelPath); - - IGpt4AllModel LoadLlamaModel(string modelPath); - - IGpt4AllModel LoadModel(string modelPath); - - IGpt4AllModel LoadMptModel(string modelPath); -} +namespace Gpt4All; + +public interface IGpt4AllModelFactory +{ + IGpt4AllModel LoadModel(string modelPath); +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs b/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs index 4aced85a..c490d8b1 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs @@ -1,11 +1,11 @@ -namespace Gpt4All; - -/// -/// The supported model types -/// -public enum ModelType -{ - LLAMA = 0, - GPTJ, - MPT -} +namespace Gpt4All; + +/// +/// The supported model types +/// +public enum ModelType +{ + LLAMA = 0, + GPTJ, + MPT +} diff --git a/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs b/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs index c446feef..47ed3847 100644 --- a/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs +++ b/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs @@ -1,31 +1,31 @@ -namespace Gpt4All; - -/// -/// Interface for text prediction services -/// -public interface ITextPrediction -{ - /// - /// Get prediction results for the prompt and provided options. - /// - /// The text to complete - /// The prediction settings - /// The for cancellation requests. The default is . - /// The prediction result generated by the model - Task GetPredictionAsync( - string text, - PredictRequestOptions opts, - CancellationToken cancellation = default); - - /// - /// Get streaming prediction results for the prompt and provided options. - /// - /// The text to complete - /// The prediction settings - /// The for cancellation requests. The default is . 
- /// The prediction result generated by the model - Task GetStreamingPredictionAsync( - string text, - PredictRequestOptions opts, - CancellationToken cancellationToken = default); -} \ No newline at end of file +namespace Gpt4All; + +/// +/// Interface for text prediction services +/// +public interface ITextPrediction +{ + /// + /// Get prediction results for the prompt and provided options. + /// + /// The text to complete + /// The prediction settings + /// The for cancellation requests. The default is . + /// The prediction result generated by the model + Task GetPredictionAsync( + string text, + PredictRequestOptions opts, + CancellationToken cancellation = default); + + /// + /// Get streaming prediction results for the prompt and provided options. + /// + /// The text to complete + /// The prediction settings + /// The for cancellation requests. The default is . + /// The prediction result generated by the model + Task GetStreamingPredictionAsync( + string text, + PredictRequestOptions opts, + CancellationToken cancellationToken = default); +} diff --git a/gpt4all-bindings/csharp/build_linux.sh b/gpt4all-bindings/csharp/build_linux.sh index a89969e2..b747c35f 100755 --- a/gpt4all-bindings/csharp/build_linux.sh +++ b/gpt4all-bindings/csharp/build_linux.sh @@ -5,4 +5,6 @@ mkdir runtimes/linux-x64/build cmake -S ../../gpt4all-backend -B runtimes/linux-x64/build cmake --build runtimes/linux-x64/build --parallel --config Release cp runtimes/linux-x64/build/libllmodel.so runtimes/linux-x64/native/libllmodel.so -cp runtimes/linux-x64/build/llama.cpp/libllama.so runtimes/linux-x64/native/libllama.so +cp runtimes/linux-x64/build/libgptj*.so runtimes/linux-x64/native/ +cp runtimes/linux-x64/build/libllama*.so runtimes/linux-x64/native/ +cp runtimes/linux-x64/build/libmpt*.so runtimes/linux-x64/native/ diff --git a/gpt4all-bindings/csharp/build_win-mingw.ps1 b/gpt4all-bindings/csharp/build_win-mingw.ps1 index f3d17a3c..1e3dd8ef 100644 --- a/gpt4all-bindings/csharp/build_win-mingw.ps1 +++ b/gpt4all-bindings/csharp/build_win-mingw.ps1 @@ -13,4 +13,5 @@ cmake --build $BUILD_DIR --parallel --config Release # copy native dlls cp "C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw64\bin\*dll" $LIBS_DIR -cp "$BUILD_DIR\*.dll" $LIBS_DIR \ No newline at end of file +cp "$BUILD_DIR\bin\*.dll" $LIBS_DIR +mv $LIBS_DIR\llmodel.dll $LIBS_DIR\libllmodel.dll \ No newline at end of file diff --git a/gpt4all-bindings/csharp/build_win-msvc.ps1 b/gpt4all-bindings/csharp/build_win-msvc.ps1 index 01a65886..8d44f3a7 100644 --- a/gpt4all-bindings/csharp/build_win-msvc.ps1 +++ b/gpt4all-bindings/csharp/build_win-msvc.ps1 @@ -2,4 +2,5 @@ Remove-Item -Force -Recurse .\runtimes\win-x64\msvc -ErrorAction SilentlyContinu mkdir .\runtimes\win-x64\msvc\build | Out-Null cmake -G "Visual Studio 17 2022" -A X64 -S ..\..\gpt4all-backend -B .\runtimes\win-x64\msvc\build cmake --build .\runtimes\win-x64\msvc\build --parallel --config Release -cp .\runtimes\win-x64\msvc\build\bin\Release\*.dll .\runtimes\win-x64 \ No newline at end of file +cp .\runtimes\win-x64\msvc\build\bin\Release\*.dll .\runtimes\win-x64 +mv .\runtimes\win-x64\llmodel.dll .\runtimes\win-x64\libllmodel.dll \ No newline at end of file
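Usage sketch (not part of the patch): a minimal consumer-side example of the consolidated factory API introduced above. The types and signatures are taken from this diff; PredictRequestOptions.Defaults, the model file name, and the assumption that IGpt4AllModel exposes the ITextPrediction members are illustrative only and not confirmed by the patch.

using Gpt4All;

// Default constructor: bypassLoading defaults to true, so the managed NativeLibraryLoader is
// skipped and libllmodel is resolved by the OS loader (e.g. from runtimes/<rid>/native).
var factory = new Gpt4AllModelFactory();

// LoadModel replaces the removed LoadLlamaModel/LoadGptjModel/LoadMptModel methods; the backend
// implementation is selected internally via llmodel_model_create2(modelPath, "auto", out error).
using var model = factory.LoadModel("ggml-gpt4all-j-v1.3-groovy.bin"); // hypothetical model path

var result = await model.GetPredictionAsync(
    "Name three colors.",
    PredictRequestOptions.Defaults); // Defaults is an assumed options instance, not shown in this diff

// How the generated text is read back depends on the prediction result type, which is outside this diff.

To load the native library from a non-default location, construct the factory with an explicit path, e.g. new Gpt4AllModelFactory(libraryPath: "/opt/gpt4all/libllmodel.so", bypassLoading: false), or register a custom ILibraryLoader via NativeLibraryLoader.SetLibraryLoader before the first factory is created; the custom loader is only consulted when bypassLoading is false.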