Initial Library Loader for .NET Bindings / Update bindings to support newest changes (#763)

* Initial Library Loader

* Load library as part of Model factory

* Dynamically search and find the dlls

* Update tests to use locally built runtimes

* Fix dylib loading, add macos runtime support for sample/tests

* Bypass automatic loading by default.

* Only set CMAKE_OSX_ARCHITECTURES if not already set, allow cross-compile

* Switch Loading again

* Update build scripts for mac/linux

* Update bindings to support newest breaking changes

* Fix build

* Use llmodel for Windows

* Actually, it does need to be libllmodel

* Name

* Remove TFMs, bypass loading by default

* Fix script

* Delete mac script

---------

Co-authored-by: Tim Miller <innerlogic4321@ghmail.com>
Tim Miller 2023-06-13 21:05:34 +09:00 committed by GitHub
parent 88616fde7f
commit 797891c995
21 changed files with 850 additions and 671 deletions

View File

@@ -9,7 +9,9 @@ if(APPLE)
    set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE)
else()
    # Build for the host architecture on macOS
-    set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
+    if(NOT CMAKE_OSX_ARCHITECTURES)
+        set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
+    endif()
endif()
endif()

View File

@@ -1,18 +1,31 @@
<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net7.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <ProjectReference Include="..\Gpt4All\Gpt4All.csproj" />
  </ItemGroup>

  <ItemGroup>
-    <Folder Include="Properties\" />
+    <!-- Windows -->
+    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
+    <!-- Linux -->
+    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
+    <!-- MacOS -->
+    <None Include="..\runtimes\osx\native\*.dylib" Pack="true" PackagePath="runtimes\osx\native\%(Filename)%(Extension)" />
+  </ItemGroup>
+  <ItemGroup>
+    <!-- Windows -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- Linux -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- MacOS -->
+    <None Condition="$([MSBuild]::IsOSPlatform('OSX'))" Include="..\runtimes\osx\native\*.dylib" Visible="False" CopyToOutputDirectory="PreserveNewest" />
  </ItemGroup>
</Project>

View File

@@ -21,7 +21,24 @@
  </ItemGroup>
  <ItemGroup>
    <ProjectReference Include="..\Gpt4All\Gpt4All.csproj" />
  </ItemGroup>
+  <ItemGroup>
+    <!-- Windows -->
+    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
+    <!-- Linux -->
+    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
+    <!-- MacOS -->
+    <None Include="..\runtimes\osx\native\*.dylib" Pack="true" PackagePath="runtimes\osx\native\%(Filename)%(Extension)" />
+  </ItemGroup>
+  <ItemGroup>
+    <!-- Windows -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- Linux -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- MacOS -->
+    <None Condition="$([MSBuild]::IsOSPlatform('OSX'))" Include="..\runtimes\osx\native\*.dylib" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+  </ItemGroup>
</Project>

View File

@@ -14,18 +14,18 @@ public class ModelFactoryTests
    [Fact]
    public void CanLoadLlamaModel()
    {
-        using var model = _modelFactory.LoadLlamaModel(Constants.LLAMA_MODEL_PATH);
+        using var model = _modelFactory.LoadModel(Constants.LLAMA_MODEL_PATH);
    }

    [Fact]
    public void CanLoadGptjModel()
    {
-        using var model = _modelFactory.LoadGptjModel(Constants.GPTJ_MODEL_PATH);
+        using var model = _modelFactory.LoadModel(Constants.GPTJ_MODEL_PATH);
    }

    [Fact]
    public void CanLoadMptModel()
    {
-        using var model = _modelFactory.LoadMptModel(Constants.MPT_MODEL_PATH);
+        using var model = _modelFactory.LoadModel(Constants.MPT_MODEL_PATH);
    }
}

View File

@@ -1,247 +1,222 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Gpt4All.Bindings;

/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response"> The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    public bool IsError => TokenId == -1;
}

/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <return>
/// A bool indicating whether the model should keep processing
/// </return>
public record ModelPromptEventArgs(int TokenId)
{
}

/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating"> whether the model is recalculating the context.</param>
/// <return>
/// A bool indicating whether the model should keep generating
/// </return>
public record ModelRecalculatingEventArgs(bool IsRecalculating);

/// <summary>
/// Base class and universal wrapper for GPT4All language models built around llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private readonly ILogger _logger;
    private bool _disposed;

    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        _handle = handle;
        _modelType = modelType;
        _logger = logger ?? NullLogger.Instance;
    }

    /// <summary>
    /// Create a new model from a pointer
    /// </summary>
    /// <param name="handle">Pointer to underlying model</param>
    /// <param name="modelType">The model type</param>
    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        return new LLModel(handle, modelType, logger: logger);
    }

    /// <summary>
    /// Generate a response using the model
    /// </summary>
    /// <param name="text">The input promp</param>
    /// <param name="context">The context</param>
    /// <param name="promptCallback">A callback function for handling the processing of prompt</param>
    /// <param name="responseCallback">A callback function for handling the generated response</param>
    /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
    /// <param name="cancellationToken"></param>
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        GC.KeepAlive(promptCallback);
        GC.KeepAlive(responseCallback);
        GC.KeepAlive(recalculateCallback);
        GC.KeepAlive(cancellationToken);

        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                    return false;
                }
                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count</param>
    public void SetThreadCount(int threadCount)
    {
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>the number of threads used by the model</returns>
    public int GetThreadCount()
    {
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>the size in bytes of the internal state of the model</returns>
    public ulong GetStateSizeBytes()
    {
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="source">A pointer to the src</param>
    /// <returns>The number of bytes copied</returns>
    public unsafe ulong SaveStateData(byte* source)
    {
        return NativeMethods.llmodel_save_state_data(_handle, source);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="destination">A pointer to destination</param>
    /// <returns>the number of bytes read</returns>
    public unsafe ulong RestoreStateData(byte* destination)
    {
        return NativeMethods.llmodel_restore_state_data(_handle, destination);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool IsLoaded()
    {
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool Load(string modelPath)
    {
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

-    protected void DestroyLLama()
-    {
-        NativeMethods.llmodel_llama_destroy(_handle);
-    }
-
-    protected void DestroyGptj()
-    {
-        NativeMethods.llmodel_gptj_destroy(_handle);
-    }
-
-    protected void DestroyMtp()
-    {
-        NativeMethods.llmodel_mpt_destroy(_handle);
-    }
-
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        if (disposing)
        {
            // dispose managed state
        }

        switch (_modelType)
        {
-            case ModelType.LLAMA:
-                DestroyLLama();
-                break;
-            case ModelType.GPTJ:
-                DestroyGptj();
-                break;
-            case ModelType.MPT:
-                DestroyMtp();
-                break;
            default:
                Destroy();
                break;
        }

        _disposed = true;
    }

    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
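For orientation, here is a minimal usage sketch for the wrapper above (not part of this commit): it assumes `model` is an already-loaded LLModel, uses a default LLModelPromptContext, streams generated tokens to the console, and honors cancellation.

// Sketch only: `model` is an already-loaded LLModel; the context keeps its default values.
using var cts = new CancellationTokenSource();
var ctx = new LLModelPromptContext();

model.Prompt(
    "Name three colors.",
    ctx,
    responseCallback: args =>
    {
        if (args.IsError) return false;   // a token id of -1 signals an error string
        Console.Write(args.Response);     // print each generated token as it arrives
        return true;                      // keep generating
    },
    cancellationToken: cts.Token);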

View File

@@ -1,138 +1,138 @@
namespace Gpt4All.Bindings;

/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
/// </summary>
/// <remarks>
/// The implementation takes care of all the memory handling of the raw logits pointer and the
/// raw tokens pointer.Attempting to resize them or modify them in any way can lead to undefined behavior
/// </remarks>
public unsafe class LLModelPromptContext
{
    private llmodel_prompt_context _ctx;

    internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;

    public LLModelPromptContext()
    {
        _ctx = new();
    }

    /// <summary>
    /// logits of current context
    /// </summary>
    public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);

    /// <summary>
    /// the size of the raw logits vector
    /// </summary>
    public nuint LogitsSize
    {
        get => _ctx.logits_size;
        set => _ctx.logits_size = value;
    }

    /// <summary>
    /// current tokens in the context window
    /// </summary>
    public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);

    /// <summary>
    /// the size of the raw tokens vector
    /// </summary>
    public nuint TokensSize
    {
        get => _ctx.tokens_size;
        set => _ctx.tokens_size = value;
    }

    /// <summary>
    /// top k logits to sample from
    /// </summary>
    public int TopK
    {
        get => _ctx.top_k;
        set => _ctx.top_k = value;
    }

    /// <summary>
    /// nucleus sampling probability threshold
    /// </summary>
    public float TopP
    {
        get => _ctx.top_p;
        set => _ctx.top_p = value;
    }

    /// <summary>
    /// temperature to adjust model's output distribution
    /// </summary>
    public float Temperature
    {
        get => _ctx.temp;
        set => _ctx.temp = value;
    }

    /// <summary>
    /// number of tokens in past conversation
    /// </summary>
    public int PastNum
    {
        get => _ctx.n_past;
        set => _ctx.n_past = value;
    }

    /// <summary>
    /// number of predictions to generate in parallel
    /// </summary>
    public int Batches
    {
        get => _ctx.n_batch;
        set => _ctx.n_batch = value;
    }

    /// <summary>
    /// number of tokens to predict
    /// </summary>
    public int TokensToPredict
    {
        get => _ctx.n_predict;
        set => _ctx.n_predict = value;
    }

    /// <summary>
    /// penalty factor for repeated tokens
    /// </summary>
    public float RepeatPenalty
    {
        get => _ctx.repeat_penalty;
        set => _ctx.repeat_penalty = value;
    }

    /// <summary>
    /// last n tokens to penalize
    /// </summary>
    public int RepeatLastN
    {
        get => _ctx.repeat_last_n;
        set => _ctx.repeat_last_n = value;
    }

    /// <summary>
    /// number of tokens possible in context window
    /// </summary>
    public int ContextSize
    {
        get => _ctx.n_ctx;
        set => _ctx.n_ctx = value;
    }

    /// <summary>
    /// percent of context to erase if we exceed the context window
    /// </summary>
    public float ContextErase
    {
        get => _ctx.context_erase;
        set => _ctx.context_erase = value;
    }
}
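The properties above are thin views over the native llmodel_prompt_context fields. A short configuration sketch follows; the values are illustrative examples, not defaults shipped by this commit, and each property maps to the native field noted in its comment.

// Illustrative sampling configuration for LLModelPromptContext.
var ctx = new LLModelPromptContext
{
    ContextSize = 2048,     // n_ctx
    TokensToPredict = 256,  // n_predict
    Batches = 8,            // n_batch
    TopK = 40,              // top_k
    TopP = 0.9f,            // top_p
    Temperature = 0.7f,     // temp
    RepeatPenalty = 1.1f,   // repeat_penalty
    RepeatLastN = 64,       // repeat_last_n
    ContextErase = 0.5f     // context_erase
};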

View File

@@ -1,126 +1,107 @@
using System.Runtime.InteropServices;

namespace Gpt4All.Bindings;

public unsafe partial struct llmodel_prompt_context
{
    public float* logits;

    [NativeTypeName("size_t")]
    public nuint logits_size;

    [NativeTypeName("int32_t *")]
    public int* tokens;

    [NativeTypeName("size_t")]
    public nuint tokens_size;

    [NativeTypeName("int32_t")]
    public int n_past;

    [NativeTypeName("int32_t")]
    public int n_ctx;

    [NativeTypeName("int32_t")]
    public int n_predict;

    [NativeTypeName("int32_t")]
    public int top_k;

    public float top_p;

    public float temp;

    [NativeTypeName("int32_t")]
    public int n_batch;

    public float repeat_penalty;

    [NativeTypeName("int32_t")]
    public int repeat_last_n;

    public float context_erase;
}

internal static unsafe partial class NativeMethods
{
    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelPromptCallback(int token_id);

    [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LlmodelRecalculateCallback(bool isRecalculating);

-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_gptj_create();
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj);
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_mpt_create();
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt);
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_llama_create();
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama);
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_model_create(
-        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
+    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+    [return: NativeTypeName("llmodel_model")]
+    public static extern IntPtr llmodel_model_create2(
+        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
+        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string build_variant,
+        out IntPtr error);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_loadModel(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: MarshalAs(UnmanagedType.I1)]
    public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("uint64_t")]
    public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
    public static extern void llmodel_prompt(
        [NativeTypeName("llmodel_model")] IntPtr model,
        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
        LlmodelPromptCallback prompt_callback,
        LlmodelResponseCallback response_callback,
        LlmodelRecalculateCallback recalculate_callback,
        ref llmodel_prompt_context ctx);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);

    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
    [return: NativeTypeName("int32_t")]
    public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
}
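At the P/Invoke level, the new create/destroy pair replaces the per-architecture entry points. A rough internal sketch of how these imports chain together follows; it assumes the native library is already resolvable and treats the `out IntPtr error` as a pointer to a native error string, which is an assumption not shown in this diff.

// Raw binding sketch; the factory further below wraps this same sequence.
var handle = NativeMethods.llmodel_model_create2("path/to/model.bin", "auto", out IntPtr error);
if (handle == IntPtr.Zero)
{
    // Assumption: `error` points at a native error message when creation fails.
    throw new Exception(Marshal.PtrToStringAnsi(error) ?? "llmodel_model_create2 failed");
}

if (!NativeMethods.llmodel_loadModel(handle, "path/to/model.bin"))
{
    NativeMethods.llmodel_model_destroy(handle);
    throw new Exception("llmodel_loadModel failed");
}

NativeMethods.llmodel_setThreadCount(handle, Environment.ProcessorCount);
// ... llmodel_prompt(...) would be called here ...
NativeMethods.llmodel_model_destroy(handle);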

View File

@@ -1,27 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
-    <TargetFrameworks>net6.0</TargetFrameworks>
+    <TargetFramework>net6.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
  </PropertyGroup>

-  <ItemGroup>
-    <!-- Windows -->
-    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
-    <!-- Linux -->
-    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
-  </ItemGroup>
-  <ItemGroup>
-    <!-- Windows -->
-    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
-    <!-- Linux -->
-    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
-  </ItemGroup>
  <ItemGroup>
    <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.0" />
  </ItemGroup>
</Project>

View File

@@ -0,0 +1,6 @@
namespace Gpt4All.LibraryLoader;

public interface ILibraryLoader
{
    LoadResult OpenLibrary(string? fileName);
}
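Because the interface is public, callers can supply their own loading strategy instead of the built-in platform loaders. A minimal sketch built on .NET's NativeLibrary API follows; the class name is hypothetical and not part of this commit.

// Hypothetical custom loader using System.Runtime.InteropServices.NativeLibrary.
using System.Runtime.InteropServices;

public class DotNetNativeLibraryLoader : ILibraryLoader
{
    public LoadResult OpenLibrary(string? fileName)
    {
        if (fileName is not null && NativeLibrary.TryLoad(fileName, out _))
        {
            return LoadResult.Success;
        }

        return LoadResult.Failure($"NativeLibrary.TryLoad failed for '{fileName}'");
    }
}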

View File

@@ -0,0 +1,53 @@
using System.Runtime.InteropServices;

namespace Gpt4All.LibraryLoader;

internal class LinuxLibraryLoader : ILibraryLoader
{
#pragma warning disable CA2101
    [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
#pragma warning restore CA2101
    public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags);

#pragma warning disable CA2101
    [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
#pragma warning restore CA2101
    public static extern IntPtr NativeOpenLibraryLibdl2(string? filename, int flags);

    [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
    public static extern IntPtr GetLoadError();

    [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
    public static extern IntPtr GetLoadError2();

    public LoadResult OpenLibrary(string? fileName)
    {
        IntPtr loadedLib;
        try
        {
            // open with rtls lazy flag
            loadedLib = NativeOpenLibraryLibdl2(fileName, 0x00001);
        }
        catch (DllNotFoundException)
        {
            loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001);
        }

        if (loadedLib == IntPtr.Zero)
        {
            string errorMessage;
            try
            {
                errorMessage = Marshal.PtrToStringAnsi(GetLoadError2()) ?? "Unknown error";
            }
            catch (DllNotFoundException)
            {
                errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error";
            }

            return LoadResult.Failure(errorMessage);
        }

        return LoadResult.Success;
    }
}

View File

@@ -0,0 +1,20 @@
namespace Gpt4All.LibraryLoader;

public class LoadResult
{
    private LoadResult(bool isSuccess, string? errorMessage)
    {
        IsSuccess = isSuccess;
        ErrorMessage = errorMessage;
    }

    public static LoadResult Success { get; } = new(true, null);

    public static LoadResult Failure(string errorMessage)
    {
        return new(false, errorMessage);
    }

    public bool IsSuccess { get; }

    public string? ErrorMessage { get; }
}

View File

@@ -0,0 +1,28 @@
using System.Runtime.InteropServices;

namespace Gpt4All.LibraryLoader;

internal class MacOsLibraryLoader : ILibraryLoader
{
#pragma warning disable CA2101
    [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
#pragma warning restore CA2101
    public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags);

    [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
    public static extern IntPtr GetLoadError();

    public LoadResult OpenLibrary(string? fileName)
    {
        var loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001);

        if (loadedLib == IntPtr.Zero)
        {
            var errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error";
            return LoadResult.Failure(errorMessage);
        }

        return LoadResult.Success;
    }
}

View File

@@ -0,0 +1,81 @@
#if !IOS && !MACCATALYST && !TVOS && !ANDROID
using System.Runtime.InteropServices;
#endif

namespace Gpt4All.LibraryLoader;

public static class NativeLibraryLoader
{
    private static ILibraryLoader? defaultLibraryLoader;

    /// <summary>
    /// Sets the library loader used to load the native libraries. Overwrite this only if you want some custom loading.
    /// </summary>
    /// <param name="libraryLoader">The library loader to be used.</param>
    public static void SetLibraryLoader(ILibraryLoader libraryLoader)
    {
        defaultLibraryLoader = libraryLoader;
    }

    internal static LoadResult LoadNativeLibrary(string? path = default, bool bypassLoading = true)
    {
        // If the user has handled loading the library themselves, we don't need to do anything.
        if (bypassLoading)
        {
            return LoadResult.Success;
        }

        var architecture = RuntimeInformation.OSArchitecture switch
        {
            Architecture.X64 => "x64",
            Architecture.X86 => "x86",
            Architecture.Arm => "arm",
            Architecture.Arm64 => "arm64",
            _ => throw new PlatformNotSupportedException(
                $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}")
        };

        var (platform, extension) = Environment.OSVersion.Platform switch
        {
            _ when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ("win", "dll"),
            _ when RuntimeInformation.IsOSPlatform(OSPlatform.Linux) => ("linux", "so"),
            _ when RuntimeInformation.IsOSPlatform(OSPlatform.OSX) => ("osx", "dylib"),
            _ => throw new PlatformNotSupportedException(
                $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}")
        };

        // If the user hasn't set the path, we'll try to find it ourselves.
        if (string.IsNullOrEmpty(path))
        {
            var libraryName = "libllmodel";
            var assemblySearchPath = new[]
            {
                AppDomain.CurrentDomain.RelativeSearchPath,
                Path.GetDirectoryName(typeof(NativeLibraryLoader).Assembly.Location),
                Path.GetDirectoryName(Environment.GetCommandLineArgs()[0])
            }.FirstOrDefault(it => !string.IsNullOrEmpty(it));

            // Search for the library dll within the assembly search path. If it doesn't exist, for whatever reason, use the default path.
            path = Directory.EnumerateFiles(assemblySearchPath ?? string.Empty, $"{libraryName}.{extension}", SearchOption.AllDirectories).FirstOrDefault() ?? Path.Combine("runtimes", $"{platform}-{architecture}", $"{libraryName}.{extension}");
        }

        if (defaultLibraryLoader != null)
        {
            return defaultLibraryLoader.OpenLibrary(path);
        }

        if (!File.Exists(path))
        {
            throw new FileNotFoundException($"Native Library not found in path {path}. " +
                $"Verify you have have included the native Gpt4All library in your application.");
        }

        ILibraryLoader libraryLoader = platform switch
        {
            "win" => new WindowsLibraryLoader(),
            "osx" => new MacOsLibraryLoader(),
            "linux" => new LinuxLibraryLoader(),
            _ => throw new PlatformNotSupportedException($"Currently {platform} platform is not supported")
        };

        return libraryLoader.OpenLibrary(path);
    }
}
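In practice the loader is driven through the factory further below. A hedged sketch of wiring in a custom loader and opting into automatic loading (loading stays bypassed unless bypassLoading is set to false):

// Optional: replace the built-in platform loaders before the first factory is created.
// DotNetNativeLibraryLoader is the hypothetical loader sketched earlier.
NativeLibraryLoader.SetLibraryLoader(new DotNetNativeLibraryLoader());

// Passing null for libraryPath lets the loader search the assembly directory first,
// then fall back to runtimes/{platform}-{arch}/libllmodel.{ext}.
var factory = new Gpt4AllModelFactory(libraryPath: null, bypassLoading: false);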

View File

@@ -0,0 +1,24 @@
using System.ComponentModel;
using System.Runtime.InteropServices;

namespace Gpt4All.LibraryLoader;

internal class WindowsLibraryLoader : ILibraryLoader
{
    public LoadResult OpenLibrary(string? fileName)
    {
        var loadedLib = LoadLibrary(fileName);

        if (loadedLib == IntPtr.Zero)
        {
            var errorCode = Marshal.GetLastWin32Error();
            var errorMessage = new Win32Exception(errorCode).Message;
            return LoadResult.Failure(errorMessage);
        }

        return LoadResult.Success;
    }

    [DllImport("kernel32", SetLastError = true, CharSet = CharSet.Auto)]
    private static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPWStr)] string? lpFileName);
}

View File

@@ -1,61 +1,58 @@
using System.Diagnostics;
-using Microsoft.Extensions.Logging;
-using Gpt4All.Bindings;
-using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging;
+using Gpt4All.Bindings;
+using Gpt4All.LibraryLoader;

namespace Gpt4All;

public class Gpt4AllModelFactory : IGpt4AllModelFactory
{
    private readonly ILoggerFactory _loggerFactory;
    private readonly ILogger _logger;
+    private static bool bypassLoading;
+    private static string? libraryPath;
+
+    private static readonly Lazy<LoadResult> libraryLoaded = new(() =>
+    {
+        return NativeLibraryLoader.LoadNativeLibrary(Gpt4AllModelFactory.libraryPath, Gpt4AllModelFactory.bypassLoading);
+    }, true);

-    public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null)
+    public Gpt4AllModelFactory(string? libraryPath = default, bool bypassLoading = true, ILoggerFactory? loggerFactory = null)
    {
        _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
        _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
+        Gpt4AllModelFactory.libraryPath = libraryPath;
+        Gpt4AllModelFactory.bypassLoading = bypassLoading;
+
+        if (!libraryLoaded.Value.IsSuccess)
+        {
+            throw new Exception($"Failed to load native gpt4all library. Error: {libraryLoaded.Value.ErrorMessage}");
+        }
    }

-    private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
+    private IGpt4AllModel CreateModel(string modelPath)
    {
-        var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
-
+        var modelType_ = ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
        _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
-
-        var handle = modelType_ switch
-        {
-            ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
-            ModelType.GPTJ => NativeMethods.llmodel_gptj_create(),
-            ModelType.MPT => NativeMethods.llmodel_mpt_create(),
-            _ => NativeMethods.llmodel_model_create(modelPath),
-        };
-
+        IntPtr error;
+        var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
        _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
        _logger.LogInformation("Model loading started");
-
        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
-
        _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
-
-        if (loadedSuccessfully == false)
+        if (!loadedSuccessfully)
        {
            throw new Exception($"Failed to load model: '{modelPath}'");
        }

        var logger = _loggerFactory.CreateLogger<LLModel>();
-
        var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);

        Debug.Assert(underlyingModel.IsLoaded());

        return new Gpt4All(underlyingModel, logger: logger);
    }

-    public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);
-
-    public IGpt4AllModel LoadMptModel(string modelPath) => CreateModel(modelPath, ModelType.MPT);
-
-    public IGpt4AllModel LoadGptjModel(string modelPath) => CreateModel(modelPath, ModelType.GPTJ);
-
-    public IGpt4AllModel LoadLlamaModel(string modelPath) => CreateModel(modelPath, ModelType.LLAMA);
+    public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath);
}
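For completeness, an end-to-end consumer sketch against the new factory surface; the model file name is a placeholder, and prediction goes through the ITextPrediction interface shown further down.

// Opt into automatic native library loading, then load a model by path.
var factory = new Gpt4AllModelFactory(bypassLoading: false);
using var model = factory.LoadModel("path/to/ggml-model.bin"); // placeholder path

// The returned IGpt4AllModel exposes GetPredictionAsync / GetStreamingPredictionAsync
// (see ITextPrediction below); options construction is omitted here.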

View File

@@ -1,12 +1,6 @@
namespace Gpt4All;

public interface IGpt4AllModelFactory
{
-    IGpt4AllModel LoadGptjModel(string modelPath);
-
-    IGpt4AllModel LoadLlamaModel(string modelPath);
-
    IGpt4AllModel LoadModel(string modelPath);
-
-    IGpt4AllModel LoadMptModel(string modelPath);
}

View File

@@ -1,11 +1,11 @@
namespace Gpt4All;

/// <summary>
/// The supported model types
/// </summary>
public enum ModelType
{
    LLAMA = 0,
    GPTJ,
    MPT
}

View File

@@ -1,31 +1,31 @@
namespace Gpt4All;

/// <summary>
/// Interface for text prediction services
/// </summary>
public interface ITextPrediction
{
    /// <summary>
    /// Get prediction results for the prompt and provided options.
    /// </summary>
    /// <param name="text">The text to complete</param>
    /// <param name="opts">The prediction settings</param>
-    /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <param name="cancellation">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The prediction result generated by the model</returns>
    Task<ITextPredictionResult> GetPredictionAsync(
        string text,
        PredictRequestOptions opts,
        CancellationToken cancellation = default);

    /// <summary>
    /// Get streaming prediction results for the prompt and provided options.
    /// </summary>
    /// <param name="text">The text to complete</param>
    /// <param name="opts">The prediction settings</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The prediction result generated by the model</returns>
    Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(
        string text,
        PredictRequestOptions opts,
        CancellationToken cancellationToken = default);
}

View File

@@ -5,4 +5,6 @@ mkdir runtimes/linux-x64/build
cmake -S ../../gpt4all-backend -B runtimes/linux-x64/build
cmake --build runtimes/linux-x64/build --parallel --config Release
cp runtimes/linux-x64/build/libllmodel.so runtimes/linux-x64/native/libllmodel.so
-cp runtimes/linux-x64/build/llama.cpp/libllama.so runtimes/linux-x64/native/libllama.so
+cp runtimes/linux-x64/build/libgptj*.so runtimes/linux-x64/native/
+cp runtimes/linux-x64/build/libllama*.so runtimes/linux-x64/native/
+cp runtimes/linux-x64/build/libmpt*.so runtimes/linux-x64/native/

View File

@@ -13,4 +13,5 @@ cmake --build $BUILD_DIR --parallel --config Release
# copy native dlls
cp "C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw64\bin\*dll" $LIBS_DIR
-cp "$BUILD_DIR\*.dll" $LIBS_DIR
+cp "$BUILD_DIR\bin\*.dll" $LIBS_DIR
+mv $LIBS_DIR\llmodel.dll $LIBS_DIR\libllmodel.dll

View File

@@ -2,4 +2,5 @@ Remove-Item -Force -Recurse .\runtimes\win-x64\msvc -ErrorAction SilentlyContinu
mkdir .\runtimes\win-x64\msvc\build | Out-Null
cmake -G "Visual Studio 17 2022" -A X64 -S ..\..\gpt4all-backend -B .\runtimes\win-x64\msvc\build
cmake --build .\runtimes\win-x64\msvc\build --parallel --config Release
cp .\runtimes\win-x64\msvc\build\bin\Release\*.dll .\runtimes\win-x64
+mv .\runtimes\win-x64\llmodel.dll .\runtimes\win-x64\libllmodel.dll