Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-24 04:53:02 +00:00)
Initial Library Loader for .NET Bindings / Update bindings to support newest changes (#763)
* Initial Library Loader
* Load library as part of Model factory
* Dynamically search and find the dlls
* Update tests to use locally built runtimes
* Fix dylib loading, add macos runtime support for sample/tests
* Bypass automatic loading by default.
* Only set CMAKE_OSX_ARCHITECTURES if not already set, allow cross-compile
* Switch Loading again
* Update build scripts for mac/linux
* Update bindings to support newest breaking changes
* Fix build
* Use llmodel for Windows
* Actually, it does need to be libllmodel
* Name
* Remove TFMs, bypass loading by default
* Fix script
* Delete mac script

---------

Co-authored-by: Tim Miller <innerlogic4321@ghmail.com>
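The bullet points above describe a native library loader and a model factory, but the loader itself is not part of the diff below. As a rough illustration of the loading side only (not the code shipped in this commit), a resolver along these lines can point the "libllmodel" DllImports that follow at a locally built binary; the search directory and file names are assumptions:

// Illustrative sketch only: one way a runtime library loader might resolve the
// native "libllmodel" binary. The "runtimes" directory layout and file names are
// assumptions, not the loader implemented in this PR.
using System;
using System.IO;
using System.Reflection;
using System.Runtime.InteropServices;

internal static class IllustrativeLibraryLoader
{
    public static void Register()
    {
        // Hook the resolver for every [DllImport] in this assembly.
        NativeLibrary.SetDllImportResolver(typeof(IllustrativeLibraryLoader).Assembly, Resolve);
    }

    private static IntPtr Resolve(string libraryName, Assembly assembly, DllImportSearchPath? searchPath)
    {
        if (libraryName != "libllmodel")
            return IntPtr.Zero; // fall back to the default resolver

        // Pick a platform-specific file name (assumed naming of locally built runtimes).
        string fileName =
            RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "libllmodel.dll" :
            RuntimeInformation.IsOSPlatform(OSPlatform.OSX) ? "libllmodel.dylib" :
            "libllmodel.so";

        string candidate = Path.Combine(AppContext.BaseDirectory, "runtimes", fileName);
        return File.Exists(candidate) ? NativeLibrary.Load(candidate) : IntPtr.Zero;
    }
}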
@@ -1,247 +1,222 @@
Before:

using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Gpt4All.Bindings;

/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    public bool IsError => TokenId == -1;
}

/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <return>
/// A bool indicating whether the model should keep processing
/// </return>
public record ModelPromptEventArgs(int TokenId)
{
}

/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating">whether the model is recalculating the context.</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelRecalculatingEventArgs(bool IsRecalculating);

/// <summary>
/// Base class and universal wrapper for GPT4All language models built around the llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private readonly ILogger _logger;
    private bool _disposed;

    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        _handle = handle;
        _modelType = modelType;
        _logger = logger ?? NullLogger.Instance;
    }

    /// <summary>
    /// Create a new model from a pointer
    /// </summary>
    /// <param name="handle">Pointer to underlying model</param>
    /// <param name="modelType">The model type</param>
    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        return new LLModel(handle, modelType, logger: logger);
    }

    /// <summary>
    /// Generate a response using the model
    /// </summary>
    /// <param name="text">The input prompt</param>
    /// <param name="context">The context</param>
    /// <param name="promptCallback">A callback function for handling the processing of the prompt</param>
    /// <param name="responseCallback">A callback function for handling the generated response</param>
    /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
    /// <param name="cancellationToken"></param>
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        GC.KeepAlive(promptCallback);
        GC.KeepAlive(responseCallback);
        GC.KeepAlive(recalculateCallback);
        GC.KeepAlive(cancellationToken);

        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                    return false;
                }

                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count</param>
    public void SetThreadCount(int threadCount)
    {
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>the number of threads used by the model</returns>
    public int GetThreadCount()
    {
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>the size in bytes of the internal state of the model</returns>
    public ulong GetStateSizeBytes()
    {
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="source">A pointer to the source</param>
    /// <returns>The number of bytes copied</returns>
    public unsafe ulong SaveStateData(byte* source)
    {
        return NativeMethods.llmodel_save_state_data(_handle, source);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="destination">A pointer to the destination</param>
    /// <returns>the number of bytes read</returns>
    public unsafe ulong RestoreStateData(byte* destination)
    {
        return NativeMethods.llmodel_restore_state_data(_handle, destination);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool IsLoaded()
    {
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool Load(string modelPath)
    {
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

    protected void DestroyLLama()
    {
        NativeMethods.llmodel_llama_destroy(_handle);
    }

    protected void DestroyGptj()
    {
        NativeMethods.llmodel_gptj_destroy(_handle);
    }

    protected void DestroyMtp()
    {
        NativeMethods.llmodel_mpt_destroy(_handle);
    }

    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        if (disposing)
        {
            // dispose managed state
        }

        switch (_modelType)
        {
            case ModelType.LLAMA:
                DestroyLLama();
                break;
            case ModelType.GPTJ:
                DestroyGptj();
                break;
            case ModelType.MPT:
                DestroyMtp();
                break;
            default:
                Destroy();
                break;
        }

        _disposed = true;
    }

    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
After:

using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Gpt4All.Bindings;

/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    public bool IsError => TokenId == -1;
}

/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <return>
/// A bool indicating whether the model should keep processing
/// </return>
public record ModelPromptEventArgs(int TokenId)
{
}

/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating">whether the model is recalculating the context.</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelRecalculatingEventArgs(bool IsRecalculating);

/// <summary>
/// Base class and universal wrapper for GPT4All language models built around the llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private readonly ILogger _logger;
    private bool _disposed;

    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        _handle = handle;
        _modelType = modelType;
        _logger = logger ?? NullLogger.Instance;
    }

    /// <summary>
    /// Create a new model from a pointer
    /// </summary>
    /// <param name="handle">Pointer to underlying model</param>
    /// <param name="modelType">The model type</param>
    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        return new LLModel(handle, modelType, logger: logger);
    }

    /// <summary>
    /// Generate a response using the model
    /// </summary>
    /// <param name="text">The input prompt</param>
    /// <param name="context">The context</param>
    /// <param name="promptCallback">A callback function for handling the processing of the prompt</param>
    /// <param name="responseCallback">A callback function for handling the generated response</param>
    /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
    /// <param name="cancellationToken"></param>
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        GC.KeepAlive(promptCallback);
        GC.KeepAlive(responseCallback);
        GC.KeepAlive(recalculateCallback);
        GC.KeepAlive(cancellationToken);

        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                    return false;
                }

                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count</param>
    public void SetThreadCount(int threadCount)
    {
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>the number of threads used by the model</returns>
    public int GetThreadCount()
    {
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>the size in bytes of the internal state of the model</returns>
    public ulong GetStateSizeBytes()
    {
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="source">A pointer to the source</param>
    /// <returns>The number of bytes copied</returns>
    public unsafe ulong SaveStateData(byte* source)
    {
        return NativeMethods.llmodel_save_state_data(_handle, source);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="destination">A pointer to the destination</param>
    /// <returns>the number of bytes read</returns>
    public unsafe ulong RestoreStateData(byte* destination)
    {
        return NativeMethods.llmodel_restore_state_data(_handle, destination);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool IsLoaded()
    {
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool Load(string modelPath)
    {
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        if (disposing)
        {
            // dispose managed state
        }

        switch (_modelType)
        {
            default:
                Destroy();
                break;
        }

        _disposed = true;
    }

    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
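A minimal consumption sketch for the wrapper above (an illustration, not code from this commit): the handle comes from llmodel_model_create2, which the updated NativeMethods further below declares, the model path and the "auto" build variant are placeholders, and the call site is assumed to live inside the bindings assembly since NativeMethods is internal.

using System;
using System.Threading;
using Gpt4All.Bindings;

// Placeholder path and assumed "auto" build variant; the error pointer is discarded here.
IntPtr handle = NativeMethods.llmodel_model_create2("./ggml-model.bin", "auto", out IntPtr _);
var model = LLModel.Create(handle, ModelType.LLAMA);
try
{
    if (!model.Load("./ggml-model.bin"))
        throw new InvalidOperationException("model failed to load");

    model.SetThreadCount(Environment.ProcessorCount);

    var ctx = new LLModelPromptContext { TokensToPredict = 128, Temperature = 0.7f };
    using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2));

    model.Prompt(
        "Explain what a P/Invoke binding is.",
        ctx,
        responseCallback: args =>
        {
            Console.Write(args.Response);
            return !args.IsError; // a TokenId of -1 signals an error string
        },
        cancellationToken: cts.Token);
}
finally
{
    model.Dispose();
}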
@@ -1,138 +1,138 @@
namespace Gpt4All.Bindings;

/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
/// </summary>
/// <remarks>
/// The implementation takes care of all the memory handling of the raw logits pointer and the
/// raw tokens pointer. Attempting to resize or modify them in any way can lead to undefined behavior.
/// </remarks>
public unsafe class LLModelPromptContext
{
    private llmodel_prompt_context _ctx;

    internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;

    public LLModelPromptContext()
    {
        _ctx = new();
    }

    /// <summary>
    /// logits of current context
    /// </summary>
    public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);

    /// <summary>
    /// the size of the raw logits vector
    /// </summary>
    public nuint LogitsSize
    {
        get => _ctx.logits_size;
        set => _ctx.logits_size = value;
    }

    /// <summary>
    /// current tokens in the context window
    /// </summary>
    public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);

    /// <summary>
    /// the size of the raw tokens vector
    /// </summary>
    public nuint TokensSize
    {
        get => _ctx.tokens_size;
        set => _ctx.tokens_size = value;
    }

    /// <summary>
    /// top k logits to sample from
    /// </summary>
    public int TopK
    {
        get => _ctx.top_k;
        set => _ctx.top_k = value;
    }

    /// <summary>
    /// nucleus sampling probability threshold
    /// </summary>
    public float TopP
    {
        get => _ctx.top_p;
        set => _ctx.top_p = value;
    }

    /// <summary>
    /// temperature to adjust model's output distribution
    /// </summary>
    public float Temperature
    {
        get => _ctx.temp;
        set => _ctx.temp = value;
    }

    /// <summary>
    /// number of tokens in past conversation
    /// </summary>
    public int PastNum
    {
        get => _ctx.n_past;
        set => _ctx.n_past = value;
    }

    /// <summary>
    /// number of predictions to generate in parallel
    /// </summary>
    public int Batches
    {
        get => _ctx.n_batch;
        set => _ctx.n_batch = value;
    }

    /// <summary>
    /// number of tokens to predict
    /// </summary>
    public int TokensToPredict
    {
        get => _ctx.n_predict;
        set => _ctx.n_predict = value;
    }

    /// <summary>
    /// penalty factor for repeated tokens
    /// </summary>
    public float RepeatPenalty
    {
        get => _ctx.repeat_penalty;
        set => _ctx.repeat_penalty = value;
    }

    /// <summary>
    /// last n tokens to penalize
    /// </summary>
    public int RepeatLastN
    {
        get => _ctx.repeat_last_n;
        set => _ctx.repeat_last_n = value;
    }

    /// <summary>
    /// number of tokens possible in context window
    /// </summary>
    public int ContextSize
    {
        get => _ctx.n_ctx;
        set => _ctx.n_ctx = value;
    }

    /// <summary>
    /// percent of context to erase if we exceed the context window
    /// </summary>
    public float ContextErase
    {
        get => _ctx.context_erase;
        set => _ctx.context_erase = value;
    }
}
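The managed properties above map one-to-one onto the native llmodel_prompt_context fields, so a context can be configured like this (the values are illustrative defaults, not ones prescribed by the bindings):

var ctx = new LLModelPromptContext
{
    ContextSize = 2048,       // n_ctx
    TokensToPredict = 256,    // n_predict
    Batches = 8,              // n_batch
    TopK = 40,                // top_k
    TopP = 0.95f,             // top_p
    Temperature = 0.28f,      // temp
    RepeatPenalty = 1.1f,     // repeat_penalty
    RepeatLastN = 64,         // repeat_last_n
    ContextErase = 0.5f       // context_erase
};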
@@ -1,126 +1,107 @@
Before:

using System.Runtime.InteropServices;

namespace Gpt4All.Bindings;

public unsafe partial struct llmodel_prompt_context
{
    public float* logits;

    [NativeTypeName("size_t")]
    public nuint logits_size;

    [NativeTypeName("int32_t *")]
    public int* tokens;

    [NativeTypeName("size_t")]
    public nuint tokens_size;

    [NativeTypeName("int32_t")]
    public int n_past;

    [NativeTypeName("int32_t")]
    public int n_ctx;

    [NativeTypeName("int32_t")]
    public int n_predict;

    [NativeTypeName("int32_t")]
    public int top_k;

    public float top_p;

    public float temp;

    [NativeTypeName("int32_t")]
    public int n_batch;

    public float repeat_penalty;

    [NativeTypeName("int32_t")]
    public int repeat_last_n;

    public float context_erase;
}

internal static unsafe partial class NativeMethods
{
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelPromptCallback(int token_id);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelRecalculateCallback(bool isRecalculating);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_gptj_create();

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_mpt_create();

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_llama_create();

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_model_create(
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_loadModel(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    public static extern void llmodel_prompt(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        LlmodelPromptCallback prompt_callback,
        LlmodelResponseCallback response_callback,
        LlmodelRecalculateCallback recalculate_callback,
        ref llmodel_prompt_context ctx);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("int32_t")]
    public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
}
After:

using System.Runtime.InteropServices;

namespace Gpt4All.Bindings;

public unsafe partial struct llmodel_prompt_context
{
    public float* logits;

    [NativeTypeName("size_t")]
    public nuint logits_size;

    [NativeTypeName("int32_t *")]
    public int* tokens;

    [NativeTypeName("size_t")]
    public nuint tokens_size;

    [NativeTypeName("int32_t")]
    public int n_past;

    [NativeTypeName("int32_t")]
    public int n_ctx;

    [NativeTypeName("int32_t")]
    public int n_predict;

    [NativeTypeName("int32_t")]
    public int top_k;

    public float top_p;

    public float temp;

    [NativeTypeName("int32_t")]
    public int n_batch;

    public float repeat_penalty;

    [NativeTypeName("int32_t")]
    public int repeat_last_n;

    public float context_erase;
}

internal static unsafe partial class NativeMethods
{
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelPromptCallback(int token_id);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelRecalculateCallback(bool isRecalculating);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: NativeTypeName("llmodel_model")]
    public static extern IntPtr llmodel_model_create2(
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string build_variant,
        out IntPtr error);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_loadModel(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    public static extern void llmodel_prompt(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        LlmodelPromptCallback prompt_callback,
        LlmodelResponseCallback response_callback,
        LlmodelRecalculateCallback recalculate_callback,
        ref llmodel_prompt_context ctx);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("int32_t")]
    public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
}
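The updated declarations can also be exercised directly, without the LLModel wrapper. A minimal sketch under a few stated assumptions: the model path is a placeholder, "auto" is assumed to be an accepted build_variant value, the error pointer is assumed to reference a UTF-8 message, and the call site must see the internal NativeMethods class.

using System;
using System.Runtime.InteropServices;
using Gpt4All.Bindings;

// Create the model via the new create2 entry point and inspect the error pointer on failure.
IntPtr model = NativeMethods.llmodel_model_create2("./ggml-model.bin", "auto", out IntPtr error);
if (model == IntPtr.Zero)
    throw new InvalidOperationException($"llmodel_model_create2 failed: {Marshal.PtrToStringUTF8(error)}");

try
{
    if (!NativeMethods.llmodel_loadModel(model, "./ggml-model.bin"))
        throw new InvalidOperationException("llmodel_loadModel failed");

    NativeMethods.llmodel_setThreadCount(model, 4);

    // Illustrative sampling defaults; tune as needed.
    var ctx = new llmodel_prompt_context
    {
        n_ctx = 2048,
        n_predict = 128,
        top_k = 40,
        top_p = 0.95f,
        temp = 0.7f,
        n_batch = 8,
        repeat_penalty = 1.1f,
        repeat_last_n = 64,
        context_erase = 0.5f
    };

    // The lambdas are converted to the Cdecl delegate types declared above; keep them
    // referenced for the duration of the native call.
    NativeMethods.llmodel_prompt(
        model,
        "Hello!",
        token_id => true,
        (token_id, response) => { Console.Write(response); return true; },
        isRecalculating => true,
        ref ctx);
}
finally
{
    NativeMethods.llmodel_model_destroy(model);
}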