diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt
index aab1e98d..cb1e2675 100644
--- a/gpt4all-backend/CMakeLists.txt
+++ b/gpt4all-backend/CMakeLists.txt
@@ -9,7 +9,9 @@ if(APPLE)
set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE)
else()
# Build for the host architecture on macOS
- set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
+ if(NOT CMAKE_OSX_ARCHITECTURES)
+ set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
+ endif()
endif()
endif()
diff --git a/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj b/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj
index 6fd881b0..9eb01e14 100644
--- a/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj
+++ b/gpt4all-bindings/csharp/Gpt4All.Samples/Gpt4All.Samples.csproj
@@ -1,18 +1,31 @@
-
- Exe
- net7.0
- enable
- enable
-
+
+ Exe
+ net7.0
+ enable
+ enable
+
-
-
-
+
+
+
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj b/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj
index 56211651..a2918628 100644
--- a/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj
+++ b/gpt4all-bindings/csharp/Gpt4All.Tests/Gpt4All.Tests.csproj
@@ -21,7 +21,24 @@
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs b/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs
index 6465c8df..19d91488 100644
--- a/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs
+++ b/gpt4all-bindings/csharp/Gpt4All.Tests/ModelFactoryTests.cs
@@ -14,18 +14,18 @@ public class ModelFactoryTests
[Fact]
public void CanLoadLlamaModel()
{
- using var model = _modelFactory.LoadLlamaModel(Constants.LLAMA_MODEL_PATH);
+ using var model = _modelFactory.LoadModel(Constants.LLAMA_MODEL_PATH);
}
[Fact]
public void CanLoadGptjModel()
{
- using var model = _modelFactory.LoadGptjModel(Constants.GPTJ_MODEL_PATH);
+ using var model = _modelFactory.LoadModel(Constants.GPTJ_MODEL_PATH);
}
[Fact]
public void CanLoadMptModel()
{
- using var model = _modelFactory.LoadMptModel(Constants.MPT_MODEL_PATH);
+ using var model = _modelFactory.LoadModel(Constants.MPT_MODEL_PATH);
}
}
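
The type-specific factory methods are gone; the single LoadModel entry point infers the model type from the file header via ModelFileUtils.GetModelTypeFromModelFileHeader. A minimal usage sketch (the model path here is illustrative, not part of this change):

```csharp
using Gpt4All;

// Hypothetical model path; any supported ggml model file is loaded the same way.
var factory = new Gpt4AllModelFactory();
using var model = factory.LoadModel("./ggml-gpt4all-j-v1.3-groovy.bin");
```
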
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
index 206b00cf..55defe09 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
@@ -1,247 +1,222 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
-
-namespace Gpt4All.Bindings;
-
-/// <summary>
-/// Arguments for the response processing callback
-/// </summary>
-/// <param name="TokenId">The token id of the response</param>
-/// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
-/// <remarks>
-/// A bool indicating whether the model should keep generating
-/// </remarks>
-public record ModelResponseEventArgs(int TokenId, string Response)
-{
- public bool IsError => TokenId == -1;
-}
-
-/// <summary>
-/// Arguments for the prompt processing callback
-/// </summary>
-/// <param name="TokenId">The token id of the prompt</param>
-/// <remarks>
-/// A bool indicating whether the model should keep processing
-/// </remarks>
-public record ModelPromptEventArgs(int TokenId)
-{
-}
-
-/// <summary>
-/// Arguments for the recalculating callback
-/// </summary>
-/// <param name="IsRecalculating">whether the model is recalculating the context.</param>
-/// <remarks>
-/// A bool indicating whether the model should keep generating
-/// </remarks>
-public record ModelRecalculatingEventArgs(bool IsRecalculating);
-
-/// <summary>
-/// Base class and universal wrapper for GPT4All language models built around llmodel C-API.
-/// </summary>
-public class LLModel : ILLModel
-{
- protected readonly IntPtr _handle;
- private readonly ModelType _modelType;
- private readonly ILogger _logger;
- private bool _disposed;
-
- public ModelType ModelType => _modelType;
-
- internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
- {
- _handle = handle;
- _modelType = modelType;
- _logger = logger ?? NullLogger.Instance;
- }
-
- /// <summary>
- /// Create a new model from a pointer
- /// </summary>
- /// <param name="handle">Pointer to underlying model</param>
- /// <param name="modelType">The model type</param>
- public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
- {
- return new LLModel(handle, modelType, logger: logger);
- }
-
- /// <summary>
- /// Generate a response using the model
- /// </summary>
- /// <param name="text">The input prompt</param>
- /// <param name="context">The context</param>
- /// <param name="promptCallback">A callback function for handling the processing of prompt</param>
- /// <param name="responseCallback">A callback function for handling the generated response</param>
- /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
- /// <param name="cancellationToken"></param>
- public void Prompt(
- string text,
- LLModelPromptContext context,
- Func<ModelPromptEventArgs, bool>? promptCallback = null,
- Func<ModelResponseEventArgs, bool>? responseCallback = null,
- Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
- CancellationToken cancellationToken = default)
- {
- GC.KeepAlive(promptCallback);
- GC.KeepAlive(responseCallback);
- GC.KeepAlive(recalculateCallback);
- GC.KeepAlive(cancellationToken);
-
- _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
-
- NativeMethods.llmodel_prompt(
- _handle,
- text,
- (tokenId) =>
- {
- if (cancellationToken.IsCancellationRequested) return false;
- if (promptCallback == null) return true;
- var args = new ModelPromptEventArgs(tokenId);
- return promptCallback(args);
- },
- (tokenId, response) =>
- {
- if (cancellationToken.IsCancellationRequested)
- {
- _logger.LogDebug("ResponseCallback evt=CancellationRequested");
- return false;
- }
-
- if (responseCallback == null) return true;
- var args = new ModelResponseEventArgs(tokenId, response);
- return responseCallback(args);
- },
- (isRecalculating) =>
- {
- if (cancellationToken.IsCancellationRequested) return false;
- if (recalculateCallback == null) return true;
- var args = new ModelRecalculatingEventArgs(isRecalculating);
- return recalculateCallback(args);
- },
- ref context.UnderlyingContext
- );
- }
-
- /// <summary>
- /// Set the number of threads to be used by the model.
- /// </summary>
- /// <param name="threadCount">The new thread count</param>
- public void SetThreadCount(int threadCount)
- {
- NativeMethods.llmodel_setThreadCount(_handle, threadCount);
- }
-
- /// <summary>
- /// Get the number of threads used by the model.
- /// </summary>
- /// <returns>the number of threads used by the model</returns>
- public int GetThreadCount()
- {
- return NativeMethods.llmodel_threadCount(_handle);
- }
-
- /// <summary>
- /// Get the size of the internal state of the model.
- /// </summary>
- /// <remarks>
- /// This state data is specific to the type of model you have created.
- /// </remarks>
- /// <returns>the size in bytes of the internal state of the model</returns>
- public ulong GetStateSizeBytes()
- {
- return NativeMethods.llmodel_get_state_size(_handle);
- }
-
- /// <summary>
- /// Saves the internal state of the model to the specified destination address.
- /// </summary>
- /// <param name="source">A pointer to the src</param>
- /// <returns>The number of bytes copied</returns>
- public unsafe ulong SaveStateData(byte* source)
- {
- return NativeMethods.llmodel_save_state_data(_handle, source);
- }
-
- /// <summary>
- /// Restores the internal state of the model using data from the specified address.
- /// </summary>
- /// <param name="destination">A pointer to destination</param>
- /// <returns>the number of bytes read</returns>
- public unsafe ulong RestoreStateData(byte* destination)
- {
- return NativeMethods.llmodel_restore_state_data(_handle, destination);
- }
-
- /// <summary>
- /// Check if the model is loaded.
- /// </summary>
- /// <returns>true if the model was loaded successfully, false otherwise.</returns>
- public bool IsLoaded()
- {
- return NativeMethods.llmodel_isModelLoaded(_handle);
- }
-
- /// <summary>
- /// Load the model from a file.
- /// </summary>
- /// <param name="modelPath">The path to the model file.</param>
- /// <returns>true if the model was loaded successfully, false otherwise.</returns>
- public bool Load(string modelPath)
- {
- return NativeMethods.llmodel_loadModel(_handle, modelPath);
- }
-
- protected void Destroy()
- {
- NativeMethods.llmodel_model_destroy(_handle);
- }
-
- protected void DestroyLLama()
- {
- NativeMethods.llmodel_llama_destroy(_handle);
- }
-
- protected void DestroyGptj()
- {
- NativeMethods.llmodel_gptj_destroy(_handle);
- }
-
- protected void DestroyMtp()
- {
- NativeMethods.llmodel_mpt_destroy(_handle);
- }
-
- protected virtual void Dispose(bool disposing)
- {
- if (_disposed) return;
-
- if (disposing)
- {
- // dispose managed state
- }
-
- switch (_modelType)
- {
- case ModelType.LLAMA:
- DestroyLLama();
- break;
- case ModelType.GPTJ:
- DestroyGptj();
- break;
- case ModelType.MPT:
- DestroyMtp();
- break;
- default:
- Destroy();
- break;
- }
-
- _disposed = true;
- }
-
- public void Dispose()
- {
- Dispose(disposing: true);
- GC.SuppressFinalize(this);
- }
-}
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+namespace Gpt4All.Bindings;
+
+/// <summary>
+/// Arguments for the response processing callback
+/// </summary>
+/// <param name="TokenId">The token id of the response</param>
+/// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
+/// <remarks>
+/// A bool indicating whether the model should keep generating
+/// </remarks>
+public record ModelResponseEventArgs(int TokenId, string Response)
+{
+ public bool IsError => TokenId == -1;
+}
+
+/// <summary>
+/// Arguments for the prompt processing callback
+/// </summary>
+/// <param name="TokenId">The token id of the prompt</param>
+/// <remarks>
+/// A bool indicating whether the model should keep processing
+/// </remarks>
+public record ModelPromptEventArgs(int TokenId)
+{
+}
+
+/// <summary>
+/// Arguments for the recalculating callback
+/// </summary>
+/// <param name="IsRecalculating">whether the model is recalculating the context.</param>
+/// <remarks>
+/// A bool indicating whether the model should keep generating
+/// </remarks>
+public record ModelRecalculatingEventArgs(bool IsRecalculating);
+
+/// <summary>
+/// Base class and universal wrapper for GPT4All language models built around llmodel C-API.
+/// </summary>
+public class LLModel : ILLModel
+{
+ protected readonly IntPtr _handle;
+ private readonly ModelType _modelType;
+ private readonly ILogger _logger;
+ private bool _disposed;
+
+ public ModelType ModelType => _modelType;
+
+ internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
+ {
+ _handle = handle;
+ _modelType = modelType;
+ _logger = logger ?? NullLogger.Instance;
+ }
+
+ /// <summary>
+ /// Create a new model from a pointer
+ /// </summary>
+ /// <param name="handle">Pointer to underlying model</param>
+ /// <param name="modelType">The model type</param>
+ public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
+ {
+ return new LLModel(handle, modelType, logger: logger);
+ }
+
+ /// <summary>
+ /// Generate a response using the model
+ /// </summary>
+ /// <param name="text">The input prompt</param>
+ /// <param name="context">The context</param>
+ /// <param name="promptCallback">A callback function for handling the processing of prompt</param>
+ /// <param name="responseCallback">A callback function for handling the generated response</param>
+ /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
+ /// <param name="cancellationToken"></param>
+ public void Prompt(
+ string text,
+ LLModelPromptContext context,
+ Func<ModelPromptEventArgs, bool>? promptCallback = null,
+ Func<ModelResponseEventArgs, bool>? responseCallback = null,
+ Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
+ CancellationToken cancellationToken = default)
+ {
+ GC.KeepAlive(promptCallback);
+ GC.KeepAlive(responseCallback);
+ GC.KeepAlive(recalculateCallback);
+ GC.KeepAlive(cancellationToken);
+
+ _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
+
+ NativeMethods.llmodel_prompt(
+ _handle,
+ text,
+ (tokenId) =>
+ {
+ if (cancellationToken.IsCancellationRequested) return false;
+ if (promptCallback == null) return true;
+ var args = new ModelPromptEventArgs(tokenId);
+ return promptCallback(args);
+ },
+ (tokenId, response) =>
+ {
+ if (cancellationToken.IsCancellationRequested)
+ {
+ _logger.LogDebug("ResponseCallback evt=CancellationRequested");
+ return false;
+ }
+
+ if (responseCallback == null) return true;
+ var args = new ModelResponseEventArgs(tokenId, response);
+ return responseCallback(args);
+ },
+ (isRecalculating) =>
+ {
+ if (cancellationToken.IsCancellationRequested) return false;
+ if (recalculateCallback == null) return true;
+ var args = new ModelRecalculatingEventArgs(isRecalculating);
+ return recalculateCallback(args);
+ },
+ ref context.UnderlyingContext
+ );
+ }
+
+ /// <summary>
+ /// Set the number of threads to be used by the model.
+ /// </summary>
+ /// <param name="threadCount">The new thread count</param>
+ public void SetThreadCount(int threadCount)
+ {
+ NativeMethods.llmodel_setThreadCount(_handle, threadCount);
+ }
+
+ /// <summary>
+ /// Get the number of threads used by the model.
+ /// </summary>
+ /// <returns>the number of threads used by the model</returns>
+ public int GetThreadCount()
+ {
+ return NativeMethods.llmodel_threadCount(_handle);
+ }
+
+ /// <summary>
+ /// Get the size of the internal state of the model.
+ /// </summary>
+ /// <remarks>
+ /// This state data is specific to the type of model you have created.
+ /// </remarks>
+ /// <returns>the size in bytes of the internal state of the model</returns>
+ public ulong GetStateSizeBytes()
+ {
+ return NativeMethods.llmodel_get_state_size(_handle);
+ }
+
+ /// <summary>
+ /// Saves the internal state of the model to the specified destination address.
+ /// </summary>
+ /// <param name="source">A pointer to the src</param>
+ /// <returns>The number of bytes copied</returns>
+ public unsafe ulong SaveStateData(byte* source)
+ {
+ return NativeMethods.llmodel_save_state_data(_handle, source);
+ }
+
+ /// <summary>
+ /// Restores the internal state of the model using data from the specified address.
+ /// </summary>
+ /// <param name="destination">A pointer to destination</param>
+ /// <returns>the number of bytes read</returns>
+ public unsafe ulong RestoreStateData(byte* destination)
+ {
+ return NativeMethods.llmodel_restore_state_data(_handle, destination);
+ }
+
+ /// <summary>
+ /// Check if the model is loaded.
+ /// </summary>
+ /// <returns>true if the model was loaded successfully, false otherwise.</returns>
+ public bool IsLoaded()
+ {
+ return NativeMethods.llmodel_isModelLoaded(_handle);
+ }
+
+ /// <summary>
+ /// Load the model from a file.
+ /// </summary>
+ /// <param name="modelPath">The path to the model file.</param>
+ /// <returns>true if the model was loaded successfully, false otherwise.</returns>
+ public bool Load(string modelPath)
+ {
+ return NativeMethods.llmodel_loadModel(_handle, modelPath);
+ }
+
+ protected void Destroy()
+ {
+ NativeMethods.llmodel_model_destroy(_handle);
+ }
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed) return;
+
+ if (disposing)
+ {
+ // dispose managed state
+ }
+
+ switch (_modelType)
+ {
+ default:
+ Destroy();
+ break;
+ }
+
+ _disposed = true;
+ }
+
+ public void Dispose()
+ {
+ Dispose(disposing: true);
+ GC.SuppressFinalize(this);
+ }
+}
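
For reference, a sketch of how the callback-based Prompt API is typically driven; `llmodel` stands in for an LLModel instance obtained elsewhere, and the sampling values are illustrative:

```csharp
// Assumes `llmodel` is an initialized Gpt4All.Bindings.LLModel.
var context = new LLModelPromptContext
{
    TokensToPredict = 256, // illustrative values
    Temperature = 0.7f
};

llmodel.Prompt(
    "What is the capital of France?",
    context,
    responseCallback: args =>
    {
        if (args.IsError) return false; // token id -1 signals an error string
        Console.Write(args.Response);   // stream tokens as they arrive
        return true;                    // keep generating
    });
```
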
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
index eeec504d..cec6948e 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
@@ -1,138 +1,138 @@
-namespace Gpt4All.Bindings;
-
-/// <summary>
-/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
-/// </summary>
-/// <remarks>
-/// The implementation takes care of all the memory handling of the raw logits pointer and the
-/// raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined behavior
-/// </remarks>
-public unsafe class LLModelPromptContext
-{
- private llmodel_prompt_context _ctx;
-
- internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;
-
- public LLModelPromptContext()
- {
- _ctx = new();
- }
-
- /// <summary>
- /// logits of current context
- /// </summary>
- public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);
-
- ///
- /// the size of the raw logits vector
- ///
- public nuint LogitsSize
- {
- get => _ctx.logits_size;
- set => _ctx.logits_size = value;
- }
-
- /// <summary>
- /// current tokens in the context window
- /// </summary>
- public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);
-
- ///
- /// the size of the raw tokens vector
- ///
- public nuint TokensSize
- {
- get => _ctx.tokens_size;
- set => _ctx.tokens_size = value;
- }
-
- ///
- /// top k logits to sample from
- ///
- public int TopK
- {
- get => _ctx.top_k;
- set => _ctx.top_k = value;
- }
-
- ///
- /// nucleus sampling probability threshold
- ///
- public float TopP
- {
- get => _ctx.top_p;
- set => _ctx.top_p = value;
- }
-
- ///
- /// temperature to adjust model's output distribution
- ///
- public float Temperature
- {
- get => _ctx.temp;
- set => _ctx.temp = value;
- }
-
- ///
- /// number of tokens in past conversation
- ///
- public int PastNum
- {
- get => _ctx.n_past;
- set => _ctx.n_past = value;
- }
-
- ///
- /// number of predictions to generate in parallel
- ///
- public int Batches
- {
- get => _ctx.n_batch;
- set => _ctx.n_batch = value;
- }
-
- ///
- /// number of tokens to predict
- ///
- public int TokensToPredict
- {
- get => _ctx.n_predict;
- set => _ctx.n_predict = value;
- }
-
- ///
- /// penalty factor for repeated tokens
- ///
- public float RepeatPenalty
- {
- get => _ctx.repeat_penalty;
- set => _ctx.repeat_penalty = value;
- }
-
- ///
- /// last n tokens to penalize
- ///
- public int RepeatLastN
- {
- get => _ctx.repeat_last_n;
- set => _ctx.repeat_last_n = value;
- }
-
- ///
- /// number of tokens possible in context window
- ///
- public int ContextSize
- {
- get => _ctx.n_ctx;
- set => _ctx.n_ctx = value;
- }
-
- ///
- /// percent of context to erase if we exceed the context window
- ///
- public float ContextErase
- {
- get => _ctx.context_erase;
- set => _ctx.context_erase = value;
- }
-}
+namespace Gpt4All.Bindings;
+
+/// <summary>
+/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
+/// </summary>
+/// <remarks>
+/// The implementation takes care of all the memory handling of the raw logits pointer and the
+/// raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined behavior
+/// </remarks>
+public unsafe class LLModelPromptContext
+{
+ private llmodel_prompt_context _ctx;
+
+ internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;
+
+ public LLModelPromptContext()
+ {
+ _ctx = new();
+ }
+
+ /// <summary>
+ /// logits of current context
+ /// </summary>
+ public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);
+
+ ///
+ /// the size of the raw logits vector
+ ///
+ public nuint LogitsSize
+ {
+ get => _ctx.logits_size;
+ set => _ctx.logits_size = value;
+ }
+
+ /// <summary>
+ /// current tokens in the context window
+ /// </summary>
+ public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);
+
+ ///
+ /// the size of the raw tokens vector
+ ///
+ public nuint TokensSize
+ {
+ get => _ctx.tokens_size;
+ set => _ctx.tokens_size = value;
+ }
+
+ ///
+ /// top k logits to sample from
+ ///
+ public int TopK
+ {
+ get => _ctx.top_k;
+ set => _ctx.top_k = value;
+ }
+
+ ///
+ /// nucleus sampling probability threshold
+ ///
+ public float TopP
+ {
+ get => _ctx.top_p;
+ set => _ctx.top_p = value;
+ }
+
+ ///
+ /// temperature to adjust model's output distribution
+ ///
+ public float Temperature
+ {
+ get => _ctx.temp;
+ set => _ctx.temp = value;
+ }
+
+ ///
+ /// number of tokens in past conversation
+ ///
+ public int PastNum
+ {
+ get => _ctx.n_past;
+ set => _ctx.n_past = value;
+ }
+
+ ///
+ /// number of predictions to generate in parallel
+ ///
+ public int Batches
+ {
+ get => _ctx.n_batch;
+ set => _ctx.n_batch = value;
+ }
+
+ ///
+ /// number of tokens to predict
+ ///
+ public int TokensToPredict
+ {
+ get => _ctx.n_predict;
+ set => _ctx.n_predict = value;
+ }
+
+ ///
+ /// penalty factor for repeated tokens
+ ///
+ public float RepeatPenalty
+ {
+ get => _ctx.repeat_penalty;
+ set => _ctx.repeat_penalty = value;
+ }
+
+ ///
+ /// last n tokens to penalize
+ ///
+ public int RepeatLastN
+ {
+ get => _ctx.repeat_last_n;
+ set => _ctx.repeat_last_n = value;
+ }
+
+ ///
+ /// number of tokens possible in context window
+ ///
+ public int ContextSize
+ {
+ get => _ctx.n_ctx;
+ set => _ctx.n_ctx = value;
+ }
+
+ ///
+ /// percent of context to erase if we exceed the context window
+ ///
+ public float ContextErase
+ {
+ get => _ctx.context_erase;
+ set => _ctx.context_erase = value;
+ }
+}
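
Each managed property forwards to a field of the native llmodel_prompt_context, so configuring sampling looks like this (the values are illustrative, not defaults prescribed by the bindings):

```csharp
var ctx = new LLModelPromptContext
{
    ContextSize = 1024,    // n_ctx
    TokensToPredict = 128, // n_predict
    TopK = 40,             // top_k
    TopP = 0.9f,           // top_p
    Temperature = 0.7f,    // temp
    Batches = 8,           // n_batch
    RepeatPenalty = 1.1f,  // repeat_penalty
    RepeatLastN = 64,      // repeat_last_n
    ContextErase = 0.5f    // context_erase
};
```
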
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
index c77212ca..0a606009 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs
@@ -1,126 +1,107 @@
-using System.Runtime.InteropServices;
-
-namespace Gpt4All.Bindings;
-
-public unsafe partial struct llmodel_prompt_context
-{
- public float* logits;
-
- [NativeTypeName("size_t")]
- public nuint logits_size;
-
- [NativeTypeName("int32_t *")]
- public int* tokens;
-
- [NativeTypeName("size_t")]
- public nuint tokens_size;
-
- [NativeTypeName("int32_t")]
- public int n_past;
-
- [NativeTypeName("int32_t")]
- public int n_ctx;
-
- [NativeTypeName("int32_t")]
- public int n_predict;
-
- [NativeTypeName("int32_t")]
- public int top_k;
-
- public float top_p;
-
- public float temp;
-
- [NativeTypeName("int32_t")]
- public int n_batch;
-
- public float repeat_penalty;
-
- [NativeTypeName("int32_t")]
- public int repeat_last_n;
-
- public float context_erase;
-}
-
-internal static unsafe partial class NativeMethods
-{
- [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- [return: MarshalAs(UnmanagedType.I1)]
- public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);
-
- [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- [return: MarshalAs(UnmanagedType.I1)]
- public delegate bool LlmodelPromptCallback(int token_id);
-
- [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
- [return: MarshalAs(UnmanagedType.I1)]
- public delegate bool LlmodelRecalculateCallback(bool isRecalculating);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("llmodel_model")]
- public static extern IntPtr llmodel_gptj_create();
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("llmodel_model")]
- public static extern IntPtr llmodel_mpt_create();
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("llmodel_model")]
- public static extern IntPtr llmodel_llama_create();
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
- [return: NativeTypeName("llmodel_model")]
- public static extern IntPtr llmodel_model_create(
- [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
- [return: MarshalAs(UnmanagedType.I1)]
- public static extern bool llmodel_loadModel(
- [NativeTypeName("llmodel_model")] IntPtr model,
- [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-
- [return: MarshalAs(UnmanagedType.I1)]
- public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("uint64_t")]
- public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("uint64_t")]
- public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("uint64_t")]
- public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
- public static extern void llmodel_prompt(
- [NativeTypeName("llmodel_model")] IntPtr model,
- [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
- LlmodelPromptCallback prompt_callback,
- LlmodelResponseCallback response_callback,
- LlmodelRecalculateCallback recalculate_callback,
- ref llmodel_prompt_context ctx);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);
-
- [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
- [return: NativeTypeName("int32_t")]
- public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
-}
+using System.Runtime.InteropServices;
+
+namespace Gpt4All.Bindings;
+
+public unsafe partial struct llmodel_prompt_context
+{
+ public float* logits;
+
+ [NativeTypeName("size_t")]
+ public nuint logits_size;
+
+ [NativeTypeName("int32_t *")]
+ public int* tokens;
+
+ [NativeTypeName("size_t")]
+ public nuint tokens_size;
+
+ [NativeTypeName("int32_t")]
+ public int n_past;
+
+ [NativeTypeName("int32_t")]
+ public int n_ctx;
+
+ [NativeTypeName("int32_t")]
+ public int n_predict;
+
+ [NativeTypeName("int32_t")]
+ public int top_k;
+
+ public float top_p;
+
+ public float temp;
+
+ [NativeTypeName("int32_t")]
+ public int n_batch;
+
+ public float repeat_penalty;
+
+ [NativeTypeName("int32_t")]
+ public int repeat_last_n;
+
+ public float context_erase;
+}
+
+internal static unsafe partial class NativeMethods
+{
+ [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);
+
+ [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ public delegate bool LlmodelPromptCallback(int token_id);
+
+ [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ public delegate bool LlmodelRecalculateCallback(bool isRecalculating);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+ [return: NativeTypeName("llmodel_model")]
+ public static extern IntPtr llmodel_model_create2(
+ [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
+ [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string build_variant,
+ out IntPtr error);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+ public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ public static extern bool llmodel_loadModel(
+ [NativeTypeName("llmodel_model")] IntPtr model,
+ [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+
+ [return: MarshalAs(UnmanagedType.I1)]
+ public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+ [return: NativeTypeName("uint64_t")]
+ public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+ [return: NativeTypeName("uint64_t")]
+ public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+ [return: NativeTypeName("uint64_t")]
+ public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+ public static extern void llmodel_prompt(
+ [NativeTypeName("llmodel_model")] IntPtr model,
+ [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
+ LlmodelPromptCallback prompt_callback,
+ LlmodelResponseCallback response_callback,
+ LlmodelRecalculateCallback recalculate_callback,
+ ref llmodel_prompt_context ctx);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+ public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);
+
+ [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
+ [return: NativeTypeName("int32_t")]
+ public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
+}
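
The per-architecture create/destroy pairs are replaced by llmodel_model_create2, which takes a build variant and an out error pointer. A sketch of direct use, assuming the error out-parameter points to a C error string as in the llmodel C API:

```csharp
using System.Runtime.InteropServices;
using Gpt4All.Bindings;

IntPtr error;
var handle = NativeMethods.llmodel_model_create2("./model.bin", "auto", out error);

if (handle == IntPtr.Zero)
{
    // Assumption: error points to a null-terminated message on failure.
    var message = Marshal.PtrToStringAnsi(error) ?? "unknown error";
    throw new Exception($"llmodel_model_create2 failed: {message}");
}

if (!NativeMethods.llmodel_loadModel(handle, "./model.bin"))
{
    NativeMethods.llmodel_model_destroy(handle);
    throw new Exception("llmodel_loadModel failed");
}
```
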
diff --git a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
index db6780fe..72885512 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
+++ b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
@@ -1,27 +1,11 @@
-
-
- net6.0
- enable
- enable
- true
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+ net6.0
+ enable
+ enable
+ true
+
+
+
+
diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/ILibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/ILibraryLoader.cs
new file mode 100644
index 00000000..c4e462f8
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/ILibraryLoader.cs
@@ -0,0 +1,6 @@
+namespace Gpt4All.LibraryLoader;
+
+public interface ILibraryLoader
+{
+ LoadResult OpenLibrary(string? fileName);
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LinuxLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LinuxLibraryLoader.cs
new file mode 100644
index 00000000..d7f6834a
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LinuxLibraryLoader.cs
@@ -0,0 +1,53 @@
+using System.Runtime.InteropServices;
+
+namespace Gpt4All.LibraryLoader;
+
+internal class LinuxLibraryLoader : ILibraryLoader
+{
+#pragma warning disable CA2101
+ [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
+#pragma warning restore CA2101
+ public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags);
+
+#pragma warning disable CA2101
+ [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
+#pragma warning restore CA2101
+ public static extern IntPtr NativeOpenLibraryLibdl2(string? filename, int flags);
+
+ [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
+ public static extern IntPtr GetLoadError();
+
+ [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
+ public static extern IntPtr GetLoadError2();
+
+ public LoadResult OpenLibrary(string? fileName)
+ {
+ IntPtr loadedLib;
+ try
+ {
+ // open with the RTLD_LAZY flag
+ loadedLib = NativeOpenLibraryLibdl2(fileName, 0x00001);
+ }
+ catch (DllNotFoundException)
+ {
+ loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001);
+ }
+
+ if (loadedLib == IntPtr.Zero)
+ {
+ string errorMessage;
+ try
+ {
+ errorMessage = Marshal.PtrToStringAnsi(GetLoadError2()) ?? "Unknown error";
+ }
+ catch (DllNotFoundException)
+ {
+ errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error";
+ }
+
+ return LoadResult.Failure(errorMessage);
+ }
+
+ return LoadResult.Success;
+ }
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs
new file mode 100644
index 00000000..3dccf358
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs
@@ -0,0 +1,20 @@
+namespace Gpt4All.LibraryLoader;
+
+public class LoadResult
+{
+ private LoadResult(bool isSuccess, string? errorMessage)
+ {
+ IsSuccess = isSuccess;
+ ErrorMessage = errorMessage;
+ }
+
+ public static LoadResult Success { get; } = new(true, null);
+
+ public static LoadResult Failure(string errorMessage)
+ {
+ return new(false, errorMessage);
+ }
+
+ public bool IsSuccess { get; }
+ public string? ErrorMessage { get; }
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/MacOsLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/MacOsLibraryLoader.cs
new file mode 100644
index 00000000..6577d979
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/MacOsLibraryLoader.cs
@@ -0,0 +1,28 @@
+using System.Runtime.InteropServices;
+
+namespace Gpt4All.LibraryLoader;
+
+internal class MacOsLibraryLoader : ILibraryLoader
+{
+#pragma warning disable CA2101
+ [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
+#pragma warning restore CA2101
+ public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags);
+
+ [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
+ public static extern IntPtr GetLoadError();
+
+ public LoadResult OpenLibrary(string? fileName)
+ {
+ var loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001);
+
+ if (loadedLib == IntPtr.Zero)
+ {
+ var errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error";
+
+ return LoadResult.Failure(errorMessage);
+ }
+
+ return LoadResult.Success;
+ }
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/NativeLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/NativeLibraryLoader.cs
new file mode 100644
index 00000000..85353738
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/NativeLibraryLoader.cs
@@ -0,0 +1,81 @@
+#if !IOS && !MACCATALYST && !TVOS && !ANDROID
+using System.Runtime.InteropServices;
+#endif
+
+namespace Gpt4All.LibraryLoader;
+
+public static class NativeLibraryLoader
+{
+ private static ILibraryLoader? defaultLibraryLoader;
+
+ /// <summary>
+ /// Sets the library loader used to load the native libraries. Overwrite this only if you want some custom loading.
+ /// </summary>
+ /// <param name="libraryLoader">The library loader to be used.</param>
+ public static void SetLibraryLoader(ILibraryLoader libraryLoader)
+ {
+ defaultLibraryLoader = libraryLoader;
+ }
+
+ internal static LoadResult LoadNativeLibrary(string? path = default, bool bypassLoading = true)
+ {
+ // If the user has handled loading the library themselves, we don't need to do anything.
+ if (bypassLoading)
+ {
+ return LoadResult.Success;
+ }
+
+ var architecture = RuntimeInformation.OSArchitecture switch
+ {
+ Architecture.X64 => "x64",
+ Architecture.X86 => "x86",
+ Architecture.Arm => "arm",
+ Architecture.Arm64 => "arm64",
+ _ => throw new PlatformNotSupportedException(
+ $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}")
+ };
+
+ var (platform, extension) = Environment.OSVersion.Platform switch
+ {
+ _ when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ("win", "dll"),
+ _ when RuntimeInformation.IsOSPlatform(OSPlatform.Linux) => ("linux", "so"),
+ _ when RuntimeInformation.IsOSPlatform(OSPlatform.OSX) => ("osx", "dylib"),
+ _ => throw new PlatformNotSupportedException(
+ $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}")
+ };
+
+ // If the user hasn't set the path, we'll try to find it ourselves.
+ if (string.IsNullOrEmpty(path))
+ {
+ var libraryName = "libllmodel";
+ var assemblySearchPath = new[]
+ {
+ AppDomain.CurrentDomain.RelativeSearchPath,
+ Path.GetDirectoryName(typeof(NativeLibraryLoader).Assembly.Location),
+ Path.GetDirectoryName(Environment.GetCommandLineArgs()[0])
+ }.FirstOrDefault(it => !string.IsNullOrEmpty(it));
+ // Search for the library dll within the assembly search path. If it doesn't exist, for whatever reason, use the default path.
+ path = Directory.EnumerateFiles(assemblySearchPath ?? string.Empty, $"{libraryName}.{extension}", SearchOption.AllDirectories).FirstOrDefault() ?? Path.Combine("runtimes", $"{platform}-{architecture}", $"{libraryName}.{extension}");
+ }
+
+ if (defaultLibraryLoader != null)
+ {
+ return defaultLibraryLoader.OpenLibrary(path);
+ }
+
+ if (!File.Exists(path))
+ {
+ throw new FileNotFoundException($"Native Library not found in path {path}. " +
+ $"Verify you have have included the native Gpt4All library in your application.");
+ }
+
+ ILibraryLoader libraryLoader = platform switch
+ {
+ "win" => new WindowsLibraryLoader(),
+ "osx" => new MacOsLibraryLoader(),
+ "linux" => new LinuxLibraryLoader(),
+ _ => throw new PlatformNotSupportedException($"Currently {platform} platform is not supported")
+ };
+ return libraryLoader.OpenLibrary(path);
+ }
+}
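
SetLibraryLoader lets callers replace the built-in per-OS loaders entirely. A sketch of a custom loader, assuming .NET's System.Runtime.InteropServices.NativeLibrary is an acceptable backing mechanism; it would be registered with NativeLibraryLoader.SetLibraryLoader(new CustomLoader()) before constructing a factory:

```csharp
using System.Runtime.InteropServices;
using Gpt4All.LibraryLoader;

// Hypothetical loader; only ILibraryLoader.OpenLibrary needs to be implemented.
public class CustomLoader : ILibraryLoader
{
    public LoadResult OpenLibrary(string? fileName)
    {
        return fileName is not null && NativeLibrary.TryLoad(fileName, out _)
            ? LoadResult.Success
            : LoadResult.Failure($"Could not load '{fileName}'");
    }
}
```
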
diff --git a/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/WindowsLibraryLoader.cs b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/WindowsLibraryLoader.cs
new file mode 100644
index 00000000..d2479aa4
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/LibraryLoader/WindowsLibraryLoader.cs
@@ -0,0 +1,24 @@
+using System.ComponentModel;
+using System.Runtime.InteropServices;
+
+namespace Gpt4All.LibraryLoader;
+
+internal class WindowsLibraryLoader : ILibraryLoader
+{
+ public LoadResult OpenLibrary(string? fileName)
+ {
+ var loadedLib = LoadLibrary(fileName);
+
+ if (loadedLib == IntPtr.Zero)
+ {
+ var errorCode = Marshal.GetLastWin32Error();
+ var errorMessage = new Win32Exception(errorCode).Message;
+ return LoadResult.Failure(errorMessage);
+ }
+
+ return LoadResult.Success;
+ }
+
+ [DllImport("kernel32", SetLastError = true, CharSet = CharSet.Auto)]
+ private static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPWStr)] string? lpFileName);
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
index 3c36ac26..02c5c588 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
@@ -1,61 +1,58 @@
-using System.Diagnostics;
-using Microsoft.Extensions.Logging;
-using Gpt4All.Bindings;
-using Microsoft.Extensions.Logging.Abstractions;
-
-namespace Gpt4All;
-
-public class Gpt4AllModelFactory : IGpt4AllModelFactory
-{
- private readonly ILoggerFactory _loggerFactory;
- private readonly ILogger _logger;
-
- public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null)
- {
- _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
- _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
- }
-
- private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
- {
- var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
-
- _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
-
- var handle = modelType_ switch
- {
- ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
- ModelType.GPTJ => NativeMethods.llmodel_gptj_create(),
- ModelType.MPT => NativeMethods.llmodel_mpt_create(),
- _ => NativeMethods.llmodel_model_create(modelPath),
- };
-
- _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
- _logger.LogInformation("Model loading started");
-
- var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
-
- _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
-
- if (loadedSuccessfully == false)
- {
- throw new Exception($"Failed to load model: '{modelPath}'");
- }
-
- var logger = _loggerFactory.CreateLogger<LLModel>();
-
- var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);
-
- Debug.Assert(underlyingModel.IsLoaded());
-
- return new Gpt4All(underlyingModel, logger: logger);
- }
-
- public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);
-
- public IGpt4AllModel LoadMptModel(string modelPath) => CreateModel(modelPath, ModelType.MPT);
-
- public IGpt4AllModel LoadGptjModel(string modelPath) => CreateModel(modelPath, ModelType.GPTJ);
-
- public IGpt4AllModel LoadLlamaModel(string modelPath) => CreateModel(modelPath, ModelType.LLAMA);
-}
+using System.Diagnostics;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging;
+using Gpt4All.Bindings;
+using Gpt4All.LibraryLoader;
+
+namespace Gpt4All;
+
+public class Gpt4AllModelFactory : IGpt4AllModelFactory
+{
+ private readonly ILoggerFactory _loggerFactory;
+ private readonly ILogger _logger;
+ private static bool bypassLoading;
+ private static string? libraryPath;
+
+ private static readonly Lazy<LoadResult> libraryLoaded = new(() =>
+ {
+ return NativeLibraryLoader.LoadNativeLibrary(Gpt4AllModelFactory.libraryPath, Gpt4AllModelFactory.bypassLoading);
+ }, true);
+
+ public Gpt4AllModelFactory(string? libraryPath = default, bool bypassLoading = true, ILoggerFactory? loggerFactory = null)
+ {
+ _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
+ _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
+ Gpt4AllModelFactory.libraryPath = libraryPath;
+ Gpt4AllModelFactory.bypassLoading = bypassLoading;
+
+ if (!libraryLoaded.Value.IsSuccess)
+ {
+ throw new Exception($"Failed to load native gpt4all library. Error: {libraryLoaded.Value.ErrorMessage}");
+ }
+ }
+
+ private IGpt4AllModel CreateModel(string modelPath)
+ {
+ var modelType_ = ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
+ _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
+ IntPtr error;
+ var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
+ _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
+ _logger.LogInformation("Model loading started");
+ var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+ _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
+ if (!loadedSuccessfully)
+ {
+ throw new Exception($"Failed to load model: '{modelPath}'");
+ }
+
+ var logger = _loggerFactory.CreateLogger<LLModel>();
+ var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);
+
+ Debug.Assert(underlyingModel.IsLoaded());
+
+ return new Gpt4All(underlyingModel, logger: logger);
+ }
+
+ public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath);
+}
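
The factory now owns native library loading: with the default bypassLoading: true it leaves resolution to the runtime's normal DllImport lookup, while bypassLoading: false makes it dlopen/LoadLibrary the file itself. A sketch with an illustrative path:

```csharp
using Gpt4All;

var factory = new Gpt4AllModelFactory(
    libraryPath: "runtimes/linux-x64/native/libllmodel.so", // illustrative
    bypassLoading: false);

using var model = factory.LoadModel("./model.bin");
```
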
diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs
index 3a5208aa..90c54d5e 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModelFactory.cs
@@ -1,12 +1,6 @@
-namespace Gpt4All;
-
-public interface IGpt4AllModelFactory
-{
- IGpt4AllModel LoadGptjModel(string modelPath);
-
- IGpt4AllModel LoadLlamaModel(string modelPath);
-
- IGpt4AllModel LoadModel(string modelPath);
-
- IGpt4AllModel LoadMptModel(string modelPath);
-}
+namespace Gpt4All;
+
+public interface IGpt4AllModelFactory
+{
+ IGpt4AllModel LoadModel(string modelPath);
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs b/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs
index 4aced85a..c490d8b1 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs
@@ -1,11 +1,11 @@
-namespace Gpt4All;
-
-///
-/// The supported model types
-///
-public enum ModelType
-{
- LLAMA = 0,
- GPTJ,
- MPT
-}
+namespace Gpt4All;
+
+///
+/// The supported model types
+///
+public enum ModelType
+{
+ LLAMA = 0,
+ GPTJ,
+ MPT
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs b/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs
index c446feef..47ed3847 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Prediction/ITextPrediction.cs
@@ -1,31 +1,31 @@
-namespace Gpt4All;
-
-/// <summary>
-/// Interface for text prediction services
-/// </summary>
-public interface ITextPrediction
-{
- /// <summary>
- /// Get prediction results for the prompt and provided options.
- /// </summary>
- /// <param name="text">The text to complete</param>
- /// <param name="opts">The prediction settings</param>
- /// <param name="cancellation">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
- /// <returns>The prediction result generated by the model</returns>
- Task<ITextPredictionResult> GetPredictionAsync(
- string text,
- PredictRequestOptions opts,
- CancellationToken cancellation = default);
-
- /// <summary>
- /// Get streaming prediction results for the prompt and provided options.
- /// </summary>
- /// <param name="text">The text to complete</param>
- /// <param name="opts">The prediction settings</param>
- /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
- /// <returns>The prediction result generated by the model</returns>
- Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(
- string text,
- PredictRequestOptions opts,
- CancellationToken cancellationToken = default);
-}
\ No newline at end of file
+namespace Gpt4All;
+
+/// <summary>
+/// Interface for text prediction services
+/// </summary>
+public interface ITextPrediction
+{
+ /// <summary>
+ /// Get prediction results for the prompt and provided options.
+ /// </summary>
+ /// <param name="text">The text to complete</param>
+ /// <param name="opts">The prediction settings</param>
+ /// <param name="cancellation">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+ /// <returns>The prediction result generated by the model</returns>
+ Task<ITextPredictionResult> GetPredictionAsync(
+ string text,
+ PredictRequestOptions opts,
+ CancellationToken cancellation = default);
+
+ /// <summary>
+ /// Get streaming prediction results for the prompt and provided options.
+ /// </summary>
+ /// <param name="text">The text to complete</param>
+ /// <param name="opts">The prediction settings</param>
+ /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+ /// <returns>The prediction result generated by the model</returns>
+ Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(
+ string text,
+ PredictRequestOptions opts,
+ CancellationToken cancellationToken = default);
+}
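
For completeness, a sketch of consuming the streaming side of this interface, assuming PredictRequestOptions.Defaults and ITextPredictionStreamingResult.GetPredictionStreamingAsync() as exposed elsewhere in the bindings:

```csharp
// Assumes `model` is an IGpt4AllModel obtained from Gpt4AllModelFactory.LoadModel(...).
var result = await model.GetStreamingPredictionAsync(
    "Name three colors",
    PredictRequestOptions.Defaults);

await foreach (var token in result.GetPredictionStreamingAsync())
{
    Console.Write(token);
}
```
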
diff --git a/gpt4all-bindings/csharp/build_linux.sh b/gpt4all-bindings/csharp/build_linux.sh
index a89969e2..b747c35f 100755
--- a/gpt4all-bindings/csharp/build_linux.sh
+++ b/gpt4all-bindings/csharp/build_linux.sh
@@ -5,4 +5,6 @@ mkdir runtimes/linux-x64/build
cmake -S ../../gpt4all-backend -B runtimes/linux-x64/build
cmake --build runtimes/linux-x64/build --parallel --config Release
cp runtimes/linux-x64/build/libllmodel.so runtimes/linux-x64/native/libllmodel.so
-cp runtimes/linux-x64/build/llama.cpp/libllama.so runtimes/linux-x64/native/libllama.so
+cp runtimes/linux-x64/build/libgptj*.so runtimes/linux-x64/native/
+cp runtimes/linux-x64/build/libllama*.so runtimes/linux-x64/native/
+cp runtimes/linux-x64/build/libmpt*.so runtimes/linux-x64/native/
diff --git a/gpt4all-bindings/csharp/build_win-mingw.ps1 b/gpt4all-bindings/csharp/build_win-mingw.ps1
index f3d17a3c..1e3dd8ef 100644
--- a/gpt4all-bindings/csharp/build_win-mingw.ps1
+++ b/gpt4all-bindings/csharp/build_win-mingw.ps1
@@ -13,4 +13,5 @@ cmake --build $BUILD_DIR --parallel --config Release
# copy native dlls
cp "C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw64\bin\*dll" $LIBS_DIR
-cp "$BUILD_DIR\*.dll" $LIBS_DIR
\ No newline at end of file
+cp "$BUILD_DIR\bin\*.dll" $LIBS_DIR
+mv $LIBS_DIR\llmodel.dll $LIBS_DIR\libllmodel.dll
\ No newline at end of file
diff --git a/gpt4all-bindings/csharp/build_win-msvc.ps1 b/gpt4all-bindings/csharp/build_win-msvc.ps1
index 01a65886..8d44f3a7 100644
--- a/gpt4all-bindings/csharp/build_win-msvc.ps1
+++ b/gpt4all-bindings/csharp/build_win-msvc.ps1
@@ -2,4 +2,5 @@ Remove-Item -Force -Recurse .\runtimes\win-x64\msvc -ErrorAction SilentlyContinu
mkdir .\runtimes\win-x64\msvc\build | Out-Null
cmake -G "Visual Studio 17 2022" -A X64 -S ..\..\gpt4all-backend -B .\runtimes\win-x64\msvc\build
cmake --build .\runtimes\win-x64\msvc\build --parallel --config Release
-cp .\runtimes\win-x64\msvc\build\bin\Release\*.dll .\runtimes\win-x64
\ No newline at end of file
+cp .\runtimes\win-x64\msvc\build\bin\Release\*.dll .\runtimes\win-x64
+mv .\runtimes\win-x64\llmodel.dll .\runtimes\win-x64\libllmodel.dll
\ No newline at end of file