Mirror of https://github.com/nomic-ai/gpt4all.git
C# bindings (#650)

* First working version of the C# bindings
* Update README.md
* Added more docs + fixed prompt callback signature
* build scripts revision
* Added .editorconfig + fixed style issues

Signed-off-by: mvenditto <venditto.matteo@gmail.com>
gpt4all-bindings/csharp/Gpt4All/Bindings/ILLModel.cs (new file, 31 lines)
@@ -0,0 +1,31 @@
namespace Gpt4All.Bindings;

/// <summary>
/// Represents the interface exposed by the universal wrapper for GPT4All language models built around the llmodel C-API.
/// </summary>
public interface ILLModel : IDisposable
{
    ModelType ModelType { get; }

    ulong GetStateSizeBytes();

    int GetThreadCount();

    void SetThreadCount(int threadCount);

    bool IsLoaded();

    bool Load(string modelPath);

    void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default);

    unsafe ulong RestoreStateData(byte* source);

    unsafe ulong SaveStateData(byte* destination);
}
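The interface is callback-driven: each callback returns a bool that tells the native generation loop whether to keep going. A minimal usage sketch (not part of the commit; it assumes `model` is an already-loaded ILLModel and `context` a configured LLModelPromptContext, both defined further down):

    // Sketch: driving ILLModel.Prompt with callbacks and cooperative cancellation.
    using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2));

    model.Prompt(
        "Tell me a joke about programmers.",
        context,
        promptCallback: e => true,            // keep feeding prompt tokens
        responseCallback: e =>
        {
            if (e.IsError) return false;      // a TokenId of -1 carries an error string
            Console.Write(e.Response);        // stream tokens as they arrive
            return true;                      // keep generating
        },
        recalculateCallback: e => true,       // allow context recalculation
        cancellationToken: cts.Token);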
gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs (new file, 235 lines)
@@ -0,0 +1,235 @@
namespace Gpt4All.Bindings;

/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    public bool IsError => TokenId == -1;
}

/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <return>
/// A bool indicating whether the model should keep processing
/// </return>
public record ModelPromptEventArgs(int TokenId)
{
}

/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating">whether the model is recalculating the context.</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelRecalculatingEventArgs(bool IsRecalculating);

/// <summary>
/// Base class and universal wrapper for GPT4All language models built around the llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private bool _disposed;

    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType)
    {
        _handle = handle;
        _modelType = modelType;
    }

    /// <summary>
    /// Create a new model from a pointer
    /// </summary>
    /// <param name="handle">Pointer to the underlying model</param>
    /// <param name="modelType">The model type</param>
    public static LLModel Create(IntPtr handle, ModelType modelType)
    {
        return new LLModel(handle, modelType);
    }

    /// <summary>
    /// Generate a response using the model
    /// </summary>
    /// <param name="text">The input prompt</param>
    /// <param name="context">The context</param>
    /// <param name="promptCallback">A callback function for handling the processing of the prompt</param>
    /// <param name="responseCallback">A callback function for handling the generated response</param>
    /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
    /// <param name="cancellationToken"></param>
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        GC.KeepAlive(promptCallback);
        GC.KeepAlive(responseCallback);
        GC.KeepAlive(recalculateCallback);
        GC.KeepAlive(cancellationToken);

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count</param>
    public void SetThreadCount(int threadCount)
    {
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>the number of threads used by the model</returns>
    public int GetThreadCount()
    {
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>the size in bytes of the internal state of the model</returns>
    public ulong GetStateSizeBytes()
    {
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="destination">A pointer to the destination buffer</param>
    /// <returns>The number of bytes copied</returns>
    public unsafe ulong SaveStateData(byte* destination)
    {
        return NativeMethods.llmodel_save_state_data(_handle, destination);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="source">A pointer to the source buffer</param>
    /// <returns>the number of bytes read</returns>
    public unsafe ulong RestoreStateData(byte* source)
    {
        return NativeMethods.llmodel_restore_state_data(_handle, source);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool IsLoaded()
    {
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool Load(string modelPath)
    {
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

    protected void DestroyLLama()
    {
        NativeMethods.llmodel_llama_destroy(_handle);
    }

    protected void DestroyGptj()
    {
        NativeMethods.llmodel_gptj_destroy(_handle);
    }

    protected void DestroyMpt()
    {
        NativeMethods.llmodel_mpt_destroy(_handle);
    }

    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        if (disposing)
        {
            // dispose managed state
        }

        switch (_modelType)
        {
            case ModelType.LLAMA:
                DestroyLLama();
                break;
            case ModelType.GPTJ:
                DestroyGptj();
                break;
            case ModelType.MPT:
                DestroyMpt();
                break;
            default:
                Destroy();
                break;
        }

        _disposed = true;
    }

    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
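The state APIs pair with GetStateSizeBytes: the caller owns the buffer and passes a raw pointer. A hedged sketch of a snapshot/rollback round trip (not from the commit; `model` is assumed to be a loaded LLModel):

    // Sketch: snapshot and restore the model's internal state via the unsafe state APIs.
    var stateSize = model.GetStateSizeBytes();
    var snapshot = new byte[stateSize];

    unsafe
    {
        fixed (byte* buffer = snapshot)
        {
            var written = model.SaveStateData(buffer);   // copy the current state into the buffer
            // ... run prompts that mutate the context ...
            var read = model.RestoreStateData(buffer);   // roll the model back to the snapshot
        }
    }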
gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs (new file, 140 lines)
@@ -0,0 +1,140 @@
using System.Reflection;

namespace Gpt4All.Bindings;

/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
/// </summary>
/// <remarks>
/// The implementation takes care of all the memory handling of the raw logits pointer and the
/// raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined behavior.
/// </remarks>
public unsafe class LLModelPromptContext
{
    private llmodel_prompt_context _ctx;

    internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;

    public LLModelPromptContext()
    {
        _ctx = new();
    }

    /// <summary>
    /// logits of current context
    /// </summary>
    public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);

    /// <summary>
    /// the size of the raw logits vector
    /// </summary>
    public nuint LogitsSize
    {
        get => _ctx.logits_size;
        set => _ctx.logits_size = value;
    }

    /// <summary>
    /// current tokens in the context window
    /// </summary>
    public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);

    /// <summary>
    /// the size of the raw tokens vector
    /// </summary>
    public nuint TokensSize
    {
        get => _ctx.tokens_size;
        set => _ctx.tokens_size = value;
    }

    /// <summary>
    /// top k logits to sample from
    /// </summary>
    public int TopK
    {
        get => _ctx.top_k;
        set => _ctx.top_k = value;
    }

    /// <summary>
    /// nucleus sampling probability threshold
    /// </summary>
    public float TopP
    {
        get => _ctx.top_p;
        set => _ctx.top_p = value;
    }

    /// <summary>
    /// temperature to adjust the model's output distribution
    /// </summary>
    public float Temperature
    {
        get => _ctx.temp;
        set => _ctx.temp = value;
    }

    /// <summary>
    /// number of tokens in the past conversation
    /// </summary>
    public int PastNum
    {
        get => _ctx.n_past;
        set => _ctx.n_past = value;
    }

    /// <summary>
    /// number of predictions to generate in parallel
    /// </summary>
    public int Batches
    {
        get => _ctx.n_batch;
        set => _ctx.n_batch = value;
    }

    /// <summary>
    /// number of tokens to predict
    /// </summary>
    public int TokensToPredict
    {
        get => _ctx.n_predict;
        set => _ctx.n_predict = value;
    }

    /// <summary>
    /// penalty factor for repeated tokens
    /// </summary>
    public float RepeatPenalty
    {
        get => _ctx.repeat_penalty;
        set => _ctx.repeat_penalty = value;
    }

    /// <summary>
    /// last n tokens to penalize
    /// </summary>
    public int RepeatLastN
    {
        get => _ctx.repeat_last_n;
        set => _ctx.repeat_last_n = value;
    }

    /// <summary>
    /// number of tokens possible in the context window
    /// </summary>
    public int ContextSize
    {
        get => _ctx.n_ctx;
        set => _ctx.n_ctx = value;
    }

    /// <summary>
    /// percent of context to erase if we exceed the context window
    /// </summary>
    public float ContextErase
    {
        get => _ctx.context_erase;
        set => _ctx.context_erase = value;
    }
}
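Because every native field is surfaced as a settable property, a context can be tuned directly before it is handed to LLModel.Prompt. A small sketch (not part of the commit; the values mirror the PredictRequestOptions defaults shown later):

    // Sketch: configuring sampling parameters on a fresh prompt context.
    var context = new LLModelPromptContext
    {
        ContextSize = 1024,      // n_ctx
        TokensToPredict = 128,   // n_predict
        TopK = 40,
        TopP = 0.9f,
        Temperature = 0.1f,
        Batches = 8,             // n_batch
        RepeatPenalty = 1.2f,
        RepeatLastN = 10,
        ContextErase = 0.5f
    };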
gpt4all-bindings/csharp/Gpt4All/Bindings/NativeMethods.cs (new file, 126 lines)
@@ -0,0 +1,126 @@
using System.Runtime.InteropServices;

namespace Gpt4All.Bindings;

public unsafe partial struct llmodel_prompt_context
{
    public float* logits;

    [NativeTypeName("size_t")]
    public nuint logits_size;

    [NativeTypeName("int32_t *")]
    public int* tokens;

    [NativeTypeName("size_t")]
    public nuint tokens_size;

    [NativeTypeName("int32_t")]
    public int n_past;

    [NativeTypeName("int32_t")]
    public int n_ctx;

    [NativeTypeName("int32_t")]
    public int n_predict;

    [NativeTypeName("int32_t")]
    public int top_k;

    public float top_p;

    public float temp;

    [NativeTypeName("int32_t")]
    public int n_batch;

    public float repeat_penalty;

    [NativeTypeName("int32_t")]
    public int repeat_last_n;

    public float context_erase;
}

internal static unsafe partial class NativeMethods
{
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelPromptCallback(int token_id);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelRecalculateCallback(bool isRecalculating);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_gptj_create();

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_mpt_create();

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_llama_create();

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_model_create(
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_loadModel(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    public static extern void llmodel_prompt(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        LlmodelPromptCallback prompt_callback,
        LlmodelResponseCallback response_callback,
        LlmodelRecalculateCallback recalculate_callback,
        ref llmodel_prompt_context ctx);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("int32_t")]
    public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
}
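The higher-level LLModel class wraps these entry points, but they can also be exercised directly. A hedged sketch (not part of the commit; `modelPath` is an assumed local GPT-J model file and the native libllmodel library must be resolvable at runtime):

    // Sketch: create -> load -> prompt -> destroy against the raw P/Invoke surface.
    var handle = NativeMethods.llmodel_gptj_create();
    try
    {
        if (!NativeMethods.llmodel_loadModel(handle, modelPath))
            throw new InvalidOperationException($"Failed to load '{modelPath}'");

        var ctx = new llmodel_prompt_context
        {
            n_ctx = 1024, n_predict = 128, top_k = 40, top_p = 0.9f, temp = 0.1f,
            n_batch = 8, repeat_penalty = 1.2f, repeat_last_n = 10, context_erase = 0.5f
        };

        NativeMethods.llmodel_prompt(
            handle,
            "Hello!",
            token_id => true,                                                    // prompt callback
            (token_id, response) => { Console.Write(response); return true; },  // response callback
            isRecalculating => true,                                             // recalculate callback
            ref ctx);
    }
    finally
    {
        NativeMethods.llmodel_gptj_destroy(handle);
    }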
@@ -0,0 +1,21 @@
using System.Diagnostics;

namespace Gpt4All.Bindings;

/// <summary>Defines the type of a member as it was used in the native signature.</summary>
[AttributeUsage(AttributeTargets.Struct | AttributeTargets.Enum | AttributeTargets.Property | AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.ReturnValue, AllowMultiple = false, Inherited = true)]
[Conditional("DEBUG")]
internal sealed partial class NativeTypeNameAttribute : Attribute
{
    private readonly string _name;

    /// <summary>Initializes a new instance of the <see cref="NativeTypeNameAttribute" /> class.</summary>
    /// <param name="name">The name of the type that was used in the native signature.</param>
    public NativeTypeNameAttribute(string name)
    {
        _name = name;
    }

    /// <summary>Gets the name of the type that was used in the native signature.</summary>
    public string Name => _name;
}
@@ -0,0 +1,25 @@
using Gpt4All.Bindings;

namespace Gpt4All.Extensions;

public static class PredictRequestOptionsExtensions
{
    public static LLModelPromptContext ToPromptContext(this PredictRequestOptions opts)
    {
        return new LLModelPromptContext
        {
            LogitsSize = opts.LogitsSize,
            TokensSize = opts.TokensSize,
            TopK = opts.TopK,
            TopP = opts.TopP,
            PastNum = opts.PastConversationTokensNum,
            RepeatPenalty = opts.RepeatPenalty,
            Temperature = opts.Temperature,
            RepeatLastN = opts.RepeatLastN,
            Batches = opts.Batches,
            ContextErase = opts.ContextErase,
            ContextSize = opts.ContextSize,
            TokensToPredict = opts.TokensToPredict
        };
    }
}
gpt4all-bindings/csharp/Gpt4All/GenLLModelBindings.rsp (new file, 21 lines)
@@ -0,0 +1,21 @@
--config
exclude-funcs-with-body
--with-access-specifier
*=Public
--include-directory
..\..\..\gpt4all-backend\
--file
..\..\..\gpt4all-backend\llmodel_c.h
--libraryPath
libllmodel
--remap
sbyte*=IntPtr
void*=IntPtr
--namespace
Gpt4All.Bindings
--methodClassName
NativeMethods
--output
.\Bindings\NativeMethods.cs
--output-mode
CSharp
gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs (new file, 82 lines)
@@ -0,0 +1,82 @@
using Gpt4All.Bindings;
using Gpt4All.Extensions;

namespace Gpt4All;

public class Gpt4All : IGpt4AllModel
{
    private readonly ILLModel _model;

    internal Gpt4All(ILLModel model)
    {
        _model = model;
    }

    public Task<ITextPredictionResult> GetPredictionAsync(string text, PredictRequestOptions opts, CancellationToken cancellationToken = default)
    {
        return Task.Run(() =>
        {
            var result = new TextPredictionResult();
            var context = opts.ToPromptContext();

            _model.Prompt(text, context, responseCallback: e =>
            {
                if (e.IsError)
                {
                    result.Success = false;
                    result.ErrorMessage = e.Response;
                    return false;
                }
                result.Append(e.Response);
                return true;
            }, cancellationToken: cancellationToken);

            return (ITextPredictionResult)result;
        }, CancellationToken.None);
    }

    public Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(string text, PredictRequestOptions opts, CancellationToken cancellationToken = default)
    {
        var result = new TextPredictionStreamingResult();

        _ = Task.Run(() =>
        {
            try
            {
                var context = opts.ToPromptContext();

                _model.Prompt(text, context, responseCallback: e =>
                {
                    if (e.IsError)
                    {
                        result.Success = false;
                        result.ErrorMessage = e.Response;
                        return false;
                    }
                    result.Append(e.Response);
                    return true;
                }, cancellationToken: cancellationToken);
            }
            finally
            {
                result.Complete();
            }
        }, CancellationToken.None);

        return Task.FromResult((ITextPredictionStreamingResult)result);
    }

    protected virtual void Dispose(bool disposing)
    {
        if (disposing)
        {
            _model.Dispose();
        }
    }

    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }
}
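GetStreamingPredictionAsync returns immediately with a result whose tokens are pushed through a channel while generation runs in the background. A consumption sketch (not part of the commit; `model` is an assumed IGpt4AllModel instance):

    // Sketch: print tokens as the model produces them.
    var result = await model.GetStreamingPredictionAsync(
        "Write a haiku about spring.",
        PredictRequestOptions.Defaults);

    await foreach (var token in result.GetPredictionStreamingAsync())
    {
        Console.Write(token);
    }

    if (!result.Success)
    {
        Console.Error.WriteLine(result.ErrorMessage);
    }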
gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj (new file, 23 lines)
@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFrameworks>net6.0</TargetFrameworks>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
  </PropertyGroup>

  <ItemGroup>
    <!-- Windows -->
    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
    <!-- Linux -->
    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
  </ItemGroup>

  <ItemGroup>
    <!-- Windows -->
    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
    <!-- Linux -->
    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
  </ItemGroup>

</Project>
gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs (new file, 41 lines)
@@ -0,0 +1,41 @@
using Gpt4All.Bindings;
using System.Diagnostics;

namespace Gpt4All;

public class Gpt4AllModelFactory : IGpt4AllModelFactory
{
    private static IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
    {
        var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);

        var handle = modelType_ switch
        {
            ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
            ModelType.GPTJ => NativeMethods.llmodel_gptj_create(),
            ModelType.MPT => NativeMethods.llmodel_mpt_create(),
            _ => NativeMethods.llmodel_model_create(modelPath),
        };

        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);

        if (loadedSuccessfully == false)
        {
            throw new Exception($"Failed to load model: '{modelPath}'");
        }

        var underlyingModel = LLModel.Create(handle, modelType_);

        Debug.Assert(underlyingModel.IsLoaded());

        return new Gpt4All(underlyingModel);
    }

    public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);

    public IGpt4AllModel LoadMptModel(string modelPath) => CreateModel(modelPath, ModelType.MPT);

    public IGpt4AllModel LoadGptjModel(string modelPath) => CreateModel(modelPath, ModelType.GPTJ);

    public IGpt4AllModel LoadLlamaModel(string modelPath) => CreateModel(modelPath, ModelType.LLAMA);
}
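End to end, the factory is the public entry point: it sniffs the model type from the file header (or uses the explicitly requested one), creates the native handle, and wraps it. A usage sketch (not part of the commit; the model file name is illustrative):

    // Sketch: load a model through the factory and run a one-shot prediction.
    var factory = new Gpt4AllModelFactory();
    using var model = factory.LoadModel("./ggml-gpt4all-j-v1.3-groovy.bin");

    var result = await model.GetPredictionAsync(
        "What is the capital of France?",
        PredictRequestOptions.Defaults);

    Console.WriteLine(await result.GetPredictionAsync());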
gpt4all-bindings/csharp/Gpt4All/Model/IGpt4AllModel.cs (new file, 5 lines)
@@ -0,0 +1,5 @@
namespace Gpt4All;

public interface IGpt4AllModel : ITextPrediction, IDisposable
{
}
@@ -0,0 +1,12 @@
namespace Gpt4All;

public interface IGpt4AllModelFactory
{
    IGpt4AllModel LoadGptjModel(string modelPath);

    IGpt4AllModel LoadLlamaModel(string modelPath);

    IGpt4AllModel LoadModel(string modelPath);

    IGpt4AllModel LoadMptModel(string modelPath);
}
gpt4all-bindings/csharp/Gpt4All/Model/ModelFileUtils.cs (new file, 24 lines)
@@ -0,0 +1,24 @@
namespace Gpt4All;

public static class ModelFileUtils
{
    private const uint GPTJ_MAGIC = 0x67676d6c;
    private const uint LLAMA_MAGIC = 0x67676a74;
    private const uint MPT_MAGIC = 0x67676d6d;

    public static ModelType GetModelTypeFromModelFileHeader(string modelPath)
    {
        using var fileStream = new FileStream(modelPath, FileMode.Open);
        using var binReader = new BinaryReader(fileStream);

        var magic = binReader.ReadUInt32();

        return magic switch
        {
            GPTJ_MAGIC => ModelType.GPTJ,
            LLAMA_MAGIC => ModelType.LLAMA,
            MPT_MAGIC => ModelType.MPT,
            _ => throw new ArgumentOutOfRangeException($"Invalid model file. magic=0x{magic:X8}"),
        };
    }
}
gpt4all-bindings/csharp/Gpt4All/Model/ModelOptions.cs (new file, 8 lines)
@@ -0,0 +1,8 @@
namespace Gpt4All;

public record ModelOptions
{
    public int Threads { get; init; } = 4;

    public ModelType ModelType { get; init; } = ModelType.GPTJ;
}
gpt4all-bindings/csharp/Gpt4All/Model/ModelType.cs (new file, 11 lines)
@@ -0,0 +1,11 @@
namespace Gpt4All;

/// <summary>
/// The supported model types
/// </summary>
public enum ModelType
{
    LLAMA = 0,
    GPTJ,
    MPT
}
@@ -0,0 +1,31 @@
namespace Gpt4All;

/// <summary>
/// Interface for text prediction services
/// </summary>
public interface ITextPrediction
{
    /// <summary>
    /// Get prediction results for the prompt and provided options.
    /// </summary>
    /// <param name="text">The text to complete</param>
    /// <param name="opts">The prediction settings</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The prediction result generated by the model</returns>
    Task<ITextPredictionResult> GetPredictionAsync(
        string text,
        PredictRequestOptions opts,
        CancellationToken cancellationToken = default);

    /// <summary>
    /// Get streaming prediction results for the prompt and provided options.
    /// </summary>
    /// <param name="text">The text to complete</param>
    /// <param name="opts">The prediction settings</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The prediction result generated by the model</returns>
    Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(
        string text,
        PredictRequestOptions opts,
        CancellationToken cancellationToken = default);
}
@@ -0,0 +1,10 @@
namespace Gpt4All;

public interface ITextPredictionResult
{
    bool Success { get; }

    string? ErrorMessage { get; }

    Task<string> GetPredictionAsync(CancellationToken cancellationToken = default);
}
@@ -0,0 +1,6 @@
namespace Gpt4All;

public interface ITextPredictionStreamingResult : ITextPredictionResult
{
    IAsyncEnumerable<string> GetPredictionStreamingAsync(CancellationToken cancellationToken = default);
}
@@ -0,0 +1,30 @@
namespace Gpt4All;

public record PredictRequestOptions
{
    public nuint LogitsSize { get; init; } = 0;

    public nuint TokensSize { get; init; } = 0;

    public int PastConversationTokensNum { get; init; } = 0;

    public int ContextSize { get; init; } = 1024;

    public int TokensToPredict { get; init; } = 128;

    public int TopK { get; init; } = 40;

    public float TopP { get; init; } = 0.9f;

    public float Temperature { get; init; } = 0.1f;

    public int Batches { get; init; } = 8;

    public float RepeatPenalty { get; init; } = 1.2f;

    public int RepeatLastN { get; init; } = 10;

    public float ContextErase { get; init; } = 0.5f;

    public static readonly PredictRequestOptions Defaults = new();
}
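Since PredictRequestOptions is an immutable record, callers can derive variants from the shared Defaults with a `with` expression instead of restating every field (sketch, not part of the commit):

    // Sketch: a more exploratory option set derived from the defaults.
    var creative = PredictRequestOptions.Defaults with
    {
        Temperature = 0.7f,
        TopP = 0.95f,
        TokensToPredict = 256
    };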
@@ -0,0 +1,27 @@
using System.Text;

namespace Gpt4All;

public record TextPredictionResult : ITextPredictionResult
{
    private readonly StringBuilder _result;

    public bool Success { get; internal set; } = true;

    public string? ErrorMessage { get; internal set; }

    internal TextPredictionResult()
    {
        _result = new StringBuilder();
    }

    internal void Append(string token)
    {
        _result.Append(token);
    }

    public Task<string> GetPredictionAsync(CancellationToken cancellationToken = default)
    {
        return Task.FromResult(_result.ToString());
    }
}
@@ -0,0 +1,49 @@
using System.Text;
using System.Threading.Channels;

namespace Gpt4All;

public record TextPredictionStreamingResult : ITextPredictionStreamingResult
{
    private readonly Channel<string> _channel;

    public bool Success { get; internal set; } = true;

    public string? ErrorMessage { get; internal set; }

    public Task Completion => _channel.Reader.Completion;

    internal TextPredictionStreamingResult()
    {
        _channel = Channel.CreateUnbounded<string>();
    }

    internal bool Append(string token)
    {
        return _channel.Writer.TryWrite(token);
    }

    internal void Complete()
    {
        _channel.Writer.Complete();
    }

    public async Task<string> GetPredictionAsync(CancellationToken cancellationToken = default)
    {
        var sb = new StringBuilder();

        var tokens = GetPredictionStreamingAsync(cancellationToken).ConfigureAwait(false);

        await foreach (var token in tokens)
        {
            sb.Append(token);
        }

        return sb.ToString();
    }

    public IAsyncEnumerable<string> GetPredictionStreamingAsync(CancellationToken cancellationToken = default)
    {
        return _channel.Reader.ReadAllAsync(cancellationToken);
    }
}
gpt4all-bindings/csharp/Gpt4All/gen_bindings.ps1 (new file, 1 line)
@@ -0,0 +1 @@
ClangSharpPInvokeGenerator @(Get-Content .\GenLLModelBindings.rsp)