Initial Library Loader for .NET Bindings / Update bindings to support newest changes (#763)

* Initial Library Loader

* Load library as part of Model factory

* Dynamically search and find the dlls

* Update tests to use locally built runtimes

* Fix dylib loading, add macos runtime support for sample/tests

* Bypass automatic loading by default.

* Only set CMAKE_OSX_ARCHITECTURES if not already set, allow cross-compile

* Switch Loading again

* Update build scripts for mac/linux

* Update bindings to support newest breaking changes

* Fix build

* Use llmodel for Windows

* Actually, it does need to be libllmodel

* Name

* Remove TFMs, bypass loading by default

* Fix script

* Delete mac script

---------

Co-authored-by: Tim Miller <innerlogic4321@ghmail.com>
This commit is contained in:
Tim Miller
2023-06-13 21:05:34 +09:00
committed by GitHub
parent 88616fde7f
commit 797891c995
21 changed files with 850 additions and 671 deletions

View File

@@ -1,247 +1,222 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
namespace Gpt4All.Bindings;
/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response"> The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelResponseEventArgs(int TokenId, string Response)
{
public bool IsError => TokenId == -1;
}
/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <return>
/// A bool indicating whether the model should keep processing
/// </return>
public record ModelPromptEventArgs(int TokenId)
{
}
/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating"> whether the model is recalculating the context.</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelRecalculatingEventArgs(bool IsRecalculating);
/// <summary>
/// Base class and universal wrapper for GPT4All language models built around llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
protected readonly IntPtr _handle;
private readonly ModelType _modelType;
private readonly ILogger _logger;
private bool _disposed;
public ModelType ModelType => _modelType;
internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
{
_handle = handle;
_modelType = modelType;
_logger = logger ?? NullLogger.Instance;
}
/// <summary>
/// Create a new model from a pointer
/// </summary>
/// <param name="handle">Pointer to underlying model</param>
/// <param name="modelType">The model type</param>
public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
{
return new LLModel(handle, modelType, logger: logger);
}
/// <summary>
/// Generate a response using the model
/// </summary>
/// <param name="text">The input promp</param>
/// <param name="context">The context</param>
/// <param name="promptCallback">A callback function for handling the processing of prompt</param>
/// <param name="responseCallback">A callback function for handling the generated response</param>
/// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
/// <param name="cancellationToken"></param>
public void Prompt(
string text,
LLModelPromptContext context,
Func<ModelPromptEventArgs, bool>? promptCallback = null,
Func<ModelResponseEventArgs, bool>? responseCallback = null,
Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
CancellationToken cancellationToken = default)
{
GC.KeepAlive(promptCallback);
GC.KeepAlive(responseCallback);
GC.KeepAlive(recalculateCallback);
GC.KeepAlive(cancellationToken);
_logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
NativeMethods.llmodel_prompt(
_handle,
text,
(tokenId) =>
{
if (cancellationToken.IsCancellationRequested) return false;
if (promptCallback == null) return true;
var args = new ModelPromptEventArgs(tokenId);
return promptCallback(args);
},
(tokenId, response) =>
{
if (cancellationToken.IsCancellationRequested)
{
_logger.LogDebug("ResponseCallback evt=CancellationRequested");
return false;
}
if (responseCallback == null) return true;
var args = new ModelResponseEventArgs(tokenId, response);
return responseCallback(args);
},
(isRecalculating) =>
{
if (cancellationToken.IsCancellationRequested) return false;
if (recalculateCallback == null) return true;
var args = new ModelRecalculatingEventArgs(isRecalculating);
return recalculateCallback(args);
},
ref context.UnderlyingContext
);
}
/// <summary>
/// Set the number of threads to be used by the model.
/// </summary>
/// <param name="threadCount">The new thread count</param>
public void SetThreadCount(int threadCount)
{
NativeMethods.llmodel_setThreadCount(_handle, threadCount);
}
/// <summary>
/// Get the number of threads used by the model.
/// </summary>
/// <returns>the number of threads used by the model</returns>
public int GetThreadCount()
{
return NativeMethods.llmodel_threadCount(_handle);
}
/// <summary>
/// Get the size of the internal state of the model.
/// </summary>
/// <remarks>
/// This state data is specific to the type of model you have created.
/// </remarks>
/// <returns>the size in bytes of the internal state of the model</returns>
public ulong GetStateSizeBytes()
{
return NativeMethods.llmodel_get_state_size(_handle);
}
/// <summary>
/// Saves the internal state of the model to the specified destination address.
/// </summary>
/// <param name="source">A pointer to the src</param>
/// <returns>The number of bytes copied</returns>
public unsafe ulong SaveStateData(byte* source)
{
return NativeMethods.llmodel_save_state_data(_handle, source);
}
/// <summary>
/// Restores the internal state of the model using data from the specified address.
/// </summary>
/// <param name="destination">A pointer to destination</param>
/// <returns>the number of bytes read</returns>
public unsafe ulong RestoreStateData(byte* destination)
{
return NativeMethods.llmodel_restore_state_data(_handle, destination);
}
/// <summary>
/// Check if the model is loaded.
/// </summary>
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool IsLoaded()
{
return NativeMethods.llmodel_isModelLoaded(_handle);
}
/// <summary>
/// Load the model from a file.
/// </summary>
/// <param name="modelPath">The path to the model file.</param>
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool Load(string modelPath)
{
return NativeMethods.llmodel_loadModel(_handle, modelPath);
}
protected void Destroy()
{
NativeMethods.llmodel_model_destroy(_handle);
}
protected void DestroyLLama()
{
NativeMethods.llmodel_llama_destroy(_handle);
}
protected void DestroyGptj()
{
NativeMethods.llmodel_gptj_destroy(_handle);
}
protected void DestroyMtp()
{
NativeMethods.llmodel_mpt_destroy(_handle);
}
protected virtual void Dispose(bool disposing)
{
if (_disposed) return;
if (disposing)
{
// dispose managed state
}
switch (_modelType)
{
case ModelType.LLAMA:
DestroyLLama();
break;
case ModelType.GPTJ:
DestroyGptj();
break;
case ModelType.MPT:
DestroyMtp();
break;
default:
Destroy();
break;
}
_disposed = true;
}
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
namespace Gpt4All.Bindings;
/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response"> The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelResponseEventArgs(int TokenId, string Response)
{
public bool IsError => TokenId == -1;
}
/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <return>
/// A bool indicating whether the model should keep processing
/// </return>
public record ModelPromptEventArgs(int TokenId)
{
}
/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating"> whether the model is recalculating the context.</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelRecalculatingEventArgs(bool IsRecalculating);
/// <summary>
/// Base class and universal wrapper for GPT4All language models built around llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
protected readonly IntPtr _handle;
private readonly ModelType _modelType;
private readonly ILogger _logger;
private bool _disposed;
public ModelType ModelType => _modelType;
internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
{
_handle = handle;
_modelType = modelType;
_logger = logger ?? NullLogger.Instance;
}
/// <summary>
/// Create a new model from a pointer
/// </summary>
/// <param name="handle">Pointer to underlying model</param>
/// <param name="modelType">The model type</param>
public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
{
return new LLModel(handle, modelType, logger: logger);
}
/// <summary>
/// Generate a response using the model
/// </summary>
/// <param name="text">The input promp</param>
/// <param name="context">The context</param>
/// <param name="promptCallback">A callback function for handling the processing of prompt</param>
/// <param name="responseCallback">A callback function for handling the generated response</param>
/// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
/// <param name="cancellationToken"></param>
public void Prompt(
string text,
LLModelPromptContext context,
Func<ModelPromptEventArgs, bool>? promptCallback = null,
Func<ModelResponseEventArgs, bool>? responseCallback = null,
Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
CancellationToken cancellationToken = default)
{
GC.KeepAlive(promptCallback);
GC.KeepAlive(responseCallback);
GC.KeepAlive(recalculateCallback);
GC.KeepAlive(cancellationToken);
_logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
NativeMethods.llmodel_prompt(
_handle,
text,
(tokenId) =>
{
if (cancellationToken.IsCancellationRequested) return false;
if (promptCallback == null) return true;
var args = new ModelPromptEventArgs(tokenId);
return promptCallback(args);
},
(tokenId, response) =>
{
if (cancellationToken.IsCancellationRequested)
{
_logger.LogDebug("ResponseCallback evt=CancellationRequested");
return false;
}
if (responseCallback == null) return true;
var args = new ModelResponseEventArgs(tokenId, response);
return responseCallback(args);
},
(isRecalculating) =>
{
if (cancellationToken.IsCancellationRequested) return false;
if (recalculateCallback == null) return true;
var args = new ModelRecalculatingEventArgs(isRecalculating);
return recalculateCallback(args);
},
ref context.UnderlyingContext
);
}
/// <summary>
/// Set the number of threads to be used by the model.
/// </summary>
/// <param name="threadCount">The new thread count</param>
public void SetThreadCount(int threadCount)
{
NativeMethods.llmodel_setThreadCount(_handle, threadCount);
}
/// <summary>
/// Get the number of threads used by the model.
/// </summary>
/// <returns>the number of threads used by the model</returns>
public int GetThreadCount()
{
return NativeMethods.llmodel_threadCount(_handle);
}
/// <summary>
/// Get the size of the internal state of the model.
/// </summary>
/// <remarks>
/// This state data is specific to the type of model you have created.
/// </remarks>
/// <returns>the size in bytes of the internal state of the model</returns>
public ulong GetStateSizeBytes()
{
return NativeMethods.llmodel_get_state_size(_handle);
}
/// <summary>
/// Saves the internal state of the model to the specified destination address.
/// </summary>
/// <param name="source">A pointer to the src</param>
/// <returns>The number of bytes copied</returns>
public unsafe ulong SaveStateData(byte* source)
{
return NativeMethods.llmodel_save_state_data(_handle, source);
}
/// <summary>
/// Restores the internal state of the model using data from the specified address.
/// </summary>
/// <param name="destination">A pointer to destination</param>
/// <returns>the number of bytes read</returns>
public unsafe ulong RestoreStateData(byte* destination)
{
return NativeMethods.llmodel_restore_state_data(_handle, destination);
}
/// <summary>
/// Check if the model is loaded.
/// </summary>
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool IsLoaded()
{
return NativeMethods.llmodel_isModelLoaded(_handle);
}
/// <summary>
/// Load the model from a file.
/// </summary>
/// <param name="modelPath">The path to the model file.</param>
/// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool Load(string modelPath)
{
return NativeMethods.llmodel_loadModel(_handle, modelPath);
}
protected void Destroy()
{
NativeMethods.llmodel_model_destroy(_handle);
}
protected virtual void Dispose(bool disposing)
{
if (_disposed) return;
if (disposing)
{
// dispose managed state
}
switch (_modelType)
{
default:
Destroy();
break;
}
_disposed = true;
}
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}

View File

@@ -1,138 +1,138 @@
namespace Gpt4All.Bindings;
/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
/// </summary>
/// <remarks>
/// The implementation takes care of all the memory handling of the raw logits pointer and the
/// raw tokens pointer.Attempting to resize them or modify them in any way can lead to undefined behavior
/// </remarks>
public unsafe class LLModelPromptContext
{
private llmodel_prompt_context _ctx;
internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;
public LLModelPromptContext()
{
_ctx = new();
}
/// <summary>
/// logits of current context
/// </summary>
public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);
/// <summary>
/// the size of the raw logits vector
/// </summary>
public nuint LogitsSize
{
get => _ctx.logits_size;
set => _ctx.logits_size = value;
}
/// <summary>
/// current tokens in the context window
/// </summary>
public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);
/// <summary>
/// the size of the raw tokens vector
/// </summary>
public nuint TokensSize
{
get => _ctx.tokens_size;
set => _ctx.tokens_size = value;
}
/// <summary>
/// top k logits to sample from
/// </summary>
public int TopK
{
get => _ctx.top_k;
set => _ctx.top_k = value;
}
/// <summary>
/// nucleus sampling probability threshold
/// </summary>
public float TopP
{
get => _ctx.top_p;
set => _ctx.top_p = value;
}
/// <summary>
/// temperature to adjust model's output distribution
/// </summary>
public float Temperature
{
get => _ctx.temp;
set => _ctx.temp = value;
}
/// <summary>
/// number of tokens in past conversation
/// </summary>
public int PastNum
{
get => _ctx.n_past;
set => _ctx.n_past = value;
}
/// <summary>
/// number of predictions to generate in parallel
/// </summary>
public int Batches
{
get => _ctx.n_batch;
set => _ctx.n_batch = value;
}
/// <summary>
/// number of tokens to predict
/// </summary>
public int TokensToPredict
{
get => _ctx.n_predict;
set => _ctx.n_predict = value;
}
/// <summary>
/// penalty factor for repeated tokens
/// </summary>
public float RepeatPenalty
{
get => _ctx.repeat_penalty;
set => _ctx.repeat_penalty = value;
}
/// <summary>
/// last n tokens to penalize
/// </summary>
public int RepeatLastN
{
get => _ctx.repeat_last_n;
set => _ctx.repeat_last_n = value;
}
/// <summary>
/// number of tokens possible in context window
/// </summary>
public int ContextSize
{
get => _ctx.n_ctx;
set => _ctx.n_ctx = value;
}
/// <summary>
/// percent of context to erase if we exceed the context window
/// </summary>
public float ContextErase
{
get => _ctx.context_erase;
set => _ctx.context_erase = value;
}
}
namespace Gpt4All.Bindings;
/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
/// </summary>
/// <remarks>
/// The implementation takes care of all the memory handling of the raw logits pointer and the
/// raw tokens pointer.Attempting to resize them or modify them in any way can lead to undefined behavior
/// </remarks>
public unsafe class LLModelPromptContext
{
private llmodel_prompt_context _ctx;
internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;
public LLModelPromptContext()
{
_ctx = new();
}
/// <summary>
/// logits of current context
/// </summary>
public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);
/// <summary>
/// the size of the raw logits vector
/// </summary>
public nuint LogitsSize
{
get => _ctx.logits_size;
set => _ctx.logits_size = value;
}
/// <summary>
/// current tokens in the context window
/// </summary>
public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);
/// <summary>
/// the size of the raw tokens vector
/// </summary>
public nuint TokensSize
{
get => _ctx.tokens_size;
set => _ctx.tokens_size = value;
}
/// <summary>
/// top k logits to sample from
/// </summary>
public int TopK
{
get => _ctx.top_k;
set => _ctx.top_k = value;
}
/// <summary>
/// nucleus sampling probability threshold
/// </summary>
public float TopP
{
get => _ctx.top_p;
set => _ctx.top_p = value;
}
/// <summary>
/// temperature to adjust model's output distribution
/// </summary>
public float Temperature
{
get => _ctx.temp;
set => _ctx.temp = value;
}
/// <summary>
/// number of tokens in past conversation
/// </summary>
public int PastNum
{
get => _ctx.n_past;
set => _ctx.n_past = value;
}
/// <summary>
/// number of predictions to generate in parallel
/// </summary>
public int Batches
{
get => _ctx.n_batch;
set => _ctx.n_batch = value;
}
/// <summary>
/// number of tokens to predict
/// </summary>
public int TokensToPredict
{
get => _ctx.n_predict;
set => _ctx.n_predict = value;
}
/// <summary>
/// penalty factor for repeated tokens
/// </summary>
public float RepeatPenalty
{
get => _ctx.repeat_penalty;
set => _ctx.repeat_penalty = value;
}
/// <summary>
/// last n tokens to penalize
/// </summary>
public int RepeatLastN
{
get => _ctx.repeat_last_n;
set => _ctx.repeat_last_n = value;
}
/// <summary>
/// number of tokens possible in context window
/// </summary>
public int ContextSize
{
get => _ctx.n_ctx;
set => _ctx.n_ctx = value;
}
/// <summary>
/// percent of context to erase if we exceed the context window
/// </summary>
public float ContextErase
{
get => _ctx.context_erase;
set => _ctx.context_erase = value;
}
}

View File

@@ -1,126 +1,107 @@
using System.Runtime.InteropServices;
namespace Gpt4All.Bindings;
public unsafe partial struct llmodel_prompt_context
{
public float* logits;
[NativeTypeName("size_t")]
public nuint logits_size;
[NativeTypeName("int32_t *")]
public int* tokens;
[NativeTypeName("size_t")]
public nuint tokens_size;
[NativeTypeName("int32_t")]
public int n_past;
[NativeTypeName("int32_t")]
public int n_ctx;
[NativeTypeName("int32_t")]
public int n_predict;
[NativeTypeName("int32_t")]
public int top_k;
public float top_p;
public float temp;
[NativeTypeName("int32_t")]
public int n_batch;
public float repeat_penalty;
[NativeTypeName("int32_t")]
public int repeat_last_n;
public float context_erase;
}
internal static unsafe partial class NativeMethods
{
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool LlmodelPromptCallback(int token_id);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool LlmodelRecalculateCallback(bool isRecalculating);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("llmodel_model")]
public static extern IntPtr llmodel_gptj_create();
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("llmodel_model")]
public static extern IntPtr llmodel_mpt_create();
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("llmodel_model")]
public static extern IntPtr llmodel_llama_create();
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
[return: NativeTypeName("llmodel_model")]
public static extern IntPtr llmodel_model_create(
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_loadModel(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("uint64_t")]
public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("uint64_t")]
public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("uint64_t")]
public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
public static extern void llmodel_prompt(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
LlmodelPromptCallback prompt_callback,
LlmodelResponseCallback response_callback,
LlmodelRecalculateCallback recalculate_callback,
ref llmodel_prompt_context ctx);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("int32_t")]
public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
}
using System.Runtime.InteropServices;
namespace Gpt4All.Bindings;
public unsafe partial struct llmodel_prompt_context
{
public float* logits;
[NativeTypeName("size_t")]
public nuint logits_size;
[NativeTypeName("int32_t *")]
public int* tokens;
[NativeTypeName("size_t")]
public nuint tokens_size;
[NativeTypeName("int32_t")]
public int n_past;
[NativeTypeName("int32_t")]
public int n_ctx;
[NativeTypeName("int32_t")]
public int n_predict;
[NativeTypeName("int32_t")]
public int top_k;
public float top_p;
public float temp;
[NativeTypeName("int32_t")]
public int n_batch;
public float repeat_penalty;
[NativeTypeName("int32_t")]
public int repeat_last_n;
public float context_erase;
}
internal static unsafe partial class NativeMethods
{
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool LlmodelPromptCallback(int token_id);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool LlmodelRecalculateCallback(bool isRecalculating);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
[return: NativeTypeName("llmodel_model")]
public static extern IntPtr llmodel_model_create2(
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string build_variant,
out IntPtr error);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_loadModel(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("uint64_t")]
public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("uint64_t")]
public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("uint64_t")]
public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
public static extern void llmodel_prompt(
[NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
LlmodelPromptCallback prompt_callback,
LlmodelResponseCallback response_callback,
LlmodelRecalculateCallback recalculate_callback,
ref llmodel_prompt_context ctx);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
[return: NativeTypeName("int32_t")]
public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
}