Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-06 03:56:45 +00:00)
Initial Library Loader for .NET Bindings / Update bindings to support newest changes (#763)

* Initial Library Loader
* Load library as part of Model factory
* Dynamically search and find the dlls
* Update tests to use locally built runtimes
* Fix dylib loading, add macos runtime support for sample/tests
* Bypass automatic loading by default
* Only set CMAKE_OSX_ARCHITECTURES if not already set, allow cross-compile
* Switch Loading again
* Update build scripts for mac/linux
* Update bindings to support newest breaking changes
* Fix build
* Use llmodel for Windows
* Actually, it does need to be libllmodel
* Name
* Remove TFMs, bypass loading by default
* Fix script
* Delete mac script

Co-authored-by: Tim Miller <innerlogic4321@ghmail.com>
Parent: 88616fde7f
Commit: 797891c995
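At a glance, the reworked API after this commit looks like this: a minimal usage sketch, assuming a GGML model file on disk (the path is a placeholder).

    using Gpt4All;

    // The native library is now resolved lazily by the factory itself;
    // bypassLoading defaults to true, so opt in to automatic loading explicitly.
    var factory = new Gpt4AllModelFactory(bypassLoading: false);

    // The per-type LoadLlamaModel/LoadGptjModel/LoadMptModel methods are gone;
    // LoadModel detects the model type from the file header.
    using var model = factory.LoadModel("path/to/ggml-model.bin");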
@@ -9,7 +9,9 @@ if(APPLE)
     set(CMAKE_OSX_ARCHITECTURES "arm64;x86_64" CACHE STRING "" FORCE)
 else()
     # Build for the host architecture on macOS
-    set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
+    if(NOT CMAKE_OSX_ARCHITECTURES)
+        set(CMAKE_OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}" CACHE STRING "" FORCE)
+    endif()
 endif()
@@ -1,18 +1,31 @@
 <Project Sdk="Microsoft.NET.Sdk">

   <PropertyGroup>
     <OutputType>Exe</OutputType>
     <TargetFramework>net7.0</TargetFramework>
     <ImplicitUsings>enable</ImplicitUsings>
     <Nullable>enable</Nullable>
   </PropertyGroup>

   <ItemGroup>
     <ProjectReference Include="..\Gpt4All\Gpt4All.csproj" />
   </ItemGroup>

   <ItemGroup>
-    <Folder Include="Properties\" />
+    <!-- Windows -->
+    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
+    <!-- Linux -->
+    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
+    <!-- MacOS -->
+    <None Include="..\runtimes\osx\native\*.dylib" Pack="true" PackagePath="runtimes\osx\native\%(Filename)%(Extension)" />
+  </ItemGroup>
+
+  <ItemGroup>
+    <!-- Windows -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- Linux -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- MacOS -->
+    <None Condition="$([MSBuild]::IsOSPlatform('OSX'))" Include="..\runtimes\osx\native\*.dylib" Visible="False" CopyToOutputDirectory="PreserveNewest" />
   </ItemGroup>

 </Project>
@@ -21,7 +21,24 @@
   </ItemGroup>

   <ItemGroup>
     <ProjectReference Include="..\Gpt4All\Gpt4All.csproj" />
   </ItemGroup>

+  <ItemGroup>
+    <!-- Windows -->
+    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
+    <!-- Linux -->
+    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
+    <!-- MacOS -->
+    <None Include="..\runtimes\osx\native\*.dylib" Pack="true" PackagePath="runtimes\osx\native\%(Filename)%(Extension)" />
+  </ItemGroup>
+
+  <ItemGroup>
+    <!-- Windows -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- Linux -->
+    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+    <!-- MacOS -->
+    <None Condition="$([MSBuild]::IsOSPlatform('OSX'))" Include="..\runtimes\osx\native\*.dylib" Visible="False" CopyToOutputDirectory="PreserveNewest" />
+  </ItemGroup>
+
 </Project>
@@ -14,18 +14,18 @@ public class ModelFactoryTests
     [Fact]
     public void CanLoadLlamaModel()
     {
-        using var model = _modelFactory.LoadLlamaModel(Constants.LLAMA_MODEL_PATH);
+        using var model = _modelFactory.LoadModel(Constants.LLAMA_MODEL_PATH);
     }

     [Fact]
     public void CanLoadGptjModel()
     {
-        using var model = _modelFactory.LoadGptjModel(Constants.GPTJ_MODEL_PATH);
+        using var model = _modelFactory.LoadModel(Constants.GPTJ_MODEL_PATH);
     }

     [Fact]
     public void CanLoadMptModel()
     {
-        using var model = _modelFactory.LoadMptModel(Constants.MPT_MODEL_PATH);
+        using var model = _modelFactory.LoadModel(Constants.MPT_MODEL_PATH);
     }
 }
@@ -1,247 +1,222 @@
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Abstractions;

 namespace Gpt4All.Bindings;

 /// <summary>
 /// Arguments for the response processing callback
 /// </summary>
 /// <param name="TokenId">The token id of the response</param>
 /// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
 /// <return>
 /// A bool indicating whether the model should keep generating
 /// </return>
 public record ModelResponseEventArgs(int TokenId, string Response)
 {
     public bool IsError => TokenId == -1;
 }

 /// <summary>
 /// Arguments for the prompt processing callback
 /// </summary>
 /// <param name="TokenId">The token id of the prompt</param>
 /// <return>
 /// A bool indicating whether the model should keep processing
 /// </return>
 public record ModelPromptEventArgs(int TokenId)
 {
 }

 /// <summary>
 /// Arguments for the recalculating callback
 /// </summary>
 /// <param name="IsRecalculating">whether the model is recalculating the context.</param>
 /// <return>
 /// A bool indicating whether the model should keep generating
 /// </return>
 public record ModelRecalculatingEventArgs(bool IsRecalculating);

 /// <summary>
 /// Base class and universal wrapper for GPT4All language models built around the llmodel C-API.
 /// </summary>
 public class LLModel : ILLModel
 {
     protected readonly IntPtr _handle;
     private readonly ModelType _modelType;
     private readonly ILogger _logger;
     private bool _disposed;

     public ModelType ModelType => _modelType;

     internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
     {
         _handle = handle;
         _modelType = modelType;
         _logger = logger ?? NullLogger.Instance;
     }

     /// <summary>
     /// Create a new model from a pointer
     /// </summary>
     /// <param name="handle">Pointer to underlying model</param>
     /// <param name="modelType">The model type</param>
     public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
     {
         return new LLModel(handle, modelType, logger: logger);
     }

     /// <summary>
     /// Generate a response using the model
     /// </summary>
     /// <param name="text">The input prompt</param>
     /// <param name="context">The context</param>
     /// <param name="promptCallback">A callback function for handling the processing of the prompt</param>
     /// <param name="responseCallback">A callback function for handling the generated response</param>
     /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
     /// <param name="cancellationToken"></param>
     public void Prompt(
         string text,
         LLModelPromptContext context,
         Func<ModelPromptEventArgs, bool>? promptCallback = null,
         Func<ModelResponseEventArgs, bool>? responseCallback = null,
         Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
         CancellationToken cancellationToken = default)
     {
         GC.KeepAlive(promptCallback);
         GC.KeepAlive(responseCallback);
         GC.KeepAlive(recalculateCallback);
         GC.KeepAlive(cancellationToken);

         _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

         NativeMethods.llmodel_prompt(
             _handle,
             text,
             (tokenId) =>
             {
                 if (cancellationToken.IsCancellationRequested) return false;
                 if (promptCallback == null) return true;
                 var args = new ModelPromptEventArgs(tokenId);
                 return promptCallback(args);
             },
             (tokenId, response) =>
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                     return false;
                 }

                 if (responseCallback == null) return true;
                 var args = new ModelResponseEventArgs(tokenId, response);
                 return responseCallback(args);
             },
             (isRecalculating) =>
             {
                 if (cancellationToken.IsCancellationRequested) return false;
                 if (recalculateCallback == null) return true;
                 var args = new ModelRecalculatingEventArgs(isRecalculating);
                 return recalculateCallback(args);
             },
             ref context.UnderlyingContext
         );
     }

     /// <summary>
     /// Set the number of threads to be used by the model.
     /// </summary>
     /// <param name="threadCount">The new thread count</param>
     public void SetThreadCount(int threadCount)
     {
         NativeMethods.llmodel_setThreadCount(_handle, threadCount);
     }

     /// <summary>
     /// Get the number of threads used by the model.
     /// </summary>
     /// <returns>the number of threads used by the model</returns>
     public int GetThreadCount()
     {
         return NativeMethods.llmodel_threadCount(_handle);
     }

     /// <summary>
     /// Get the size of the internal state of the model.
     /// </summary>
     /// <remarks>
     /// This state data is specific to the type of model you have created.
     /// </remarks>
     /// <returns>the size in bytes of the internal state of the model</returns>
     public ulong GetStateSizeBytes()
     {
         return NativeMethods.llmodel_get_state_size(_handle);
     }

     /// <summary>
     /// Saves the internal state of the model to the specified destination address.
     /// </summary>
     /// <param name="source">A pointer to the src</param>
     /// <returns>The number of bytes copied</returns>
     public unsafe ulong SaveStateData(byte* source)
     {
         return NativeMethods.llmodel_save_state_data(_handle, source);
     }

     /// <summary>
     /// Restores the internal state of the model using data from the specified address.
     /// </summary>
     /// <param name="destination">A pointer to destination</param>
     /// <returns>the number of bytes read</returns>
     public unsafe ulong RestoreStateData(byte* destination)
     {
         return NativeMethods.llmodel_restore_state_data(_handle, destination);
     }

     /// <summary>
     /// Check if the model is loaded.
     /// </summary>
     /// <returns>true if the model was loaded successfully, false otherwise.</returns>
     public bool IsLoaded()
     {
         return NativeMethods.llmodel_isModelLoaded(_handle);
     }

     /// <summary>
     /// Load the model from a file.
     /// </summary>
     /// <param name="modelPath">The path to the model file.</param>
     /// <returns>true if the model was loaded successfully, false otherwise.</returns>
     public bool Load(string modelPath)
     {
         return NativeMethods.llmodel_loadModel(_handle, modelPath);
     }

     protected void Destroy()
     {
         NativeMethods.llmodel_model_destroy(_handle);
     }

-    protected void DestroyLLama()
-    {
-        NativeMethods.llmodel_llama_destroy(_handle);
-    }
-
-    protected void DestroyGptj()
-    {
-        NativeMethods.llmodel_gptj_destroy(_handle);
-    }
-
-    protected void DestroyMtp()
-    {
-        NativeMethods.llmodel_mpt_destroy(_handle);
-    }
-
     protected virtual void Dispose(bool disposing)
     {
         if (_disposed) return;

         if (disposing)
         {
             // dispose managed state
         }

         switch (_modelType)
         {
-            case ModelType.LLAMA:
-                DestroyLLama();
-                break;
-            case ModelType.GPTJ:
-                DestroyGptj();
-                break;
-            case ModelType.MPT:
-                DestroyMtp();
-                break;
             default:
                 Destroy();
                 break;
         }

         _disposed = true;
     }

     public void Dispose()
     {
         Dispose(disposing: true);
         GC.SuppressFinalize(this);
     }
 }
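For orientation, a sketch of how this wrapper's callback API is typically driven (the prompt text is illustrative; `model` is assumed to be an already-loaded `LLModel`):

    var context = new LLModelPromptContext();

    model.Prompt(
        "Tell me a joke.",
        context,
        responseCallback: args =>
        {
            if (args.IsError) return false; // TokenId == -1 marks an error string
            Console.Write(args.Response);   // tokens stream in as they are generated
            return true;                    // returning false stops generation
        });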
@@ -1,138 +1,138 @@
namespace Gpt4All.Bindings;

/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
/// </summary>
/// <remarks>
/// The implementation takes care of all the memory handling of the raw logits pointer and the
/// raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined behavior.
/// </remarks>
public unsafe class LLModelPromptContext
{
    private llmodel_prompt_context _ctx;

    internal ref llmodel_prompt_context UnderlyingContext => ref _ctx;

    public LLModelPromptContext()
    {
        _ctx = new();
    }

    /// <summary>
    /// logits of current context
    /// </summary>
    public Span<float> Logits => new(_ctx.logits, (int)_ctx.logits_size);

    /// <summary>
    /// the size of the raw logits vector
    /// </summary>
    public nuint LogitsSize
    {
        get => _ctx.logits_size;
        set => _ctx.logits_size = value;
    }

    /// <summary>
    /// current tokens in the context window
    /// </summary>
    public Span<int> Tokens => new(_ctx.tokens, (int)_ctx.tokens_size);

    /// <summary>
    /// the size of the raw tokens vector
    /// </summary>
    public nuint TokensSize
    {
        get => _ctx.tokens_size;
        set => _ctx.tokens_size = value;
    }

    /// <summary>
    /// top k logits to sample from
    /// </summary>
    public int TopK
    {
        get => _ctx.top_k;
        set => _ctx.top_k = value;
    }

    /// <summary>
    /// nucleus sampling probability threshold
    /// </summary>
    public float TopP
    {
        get => _ctx.top_p;
        set => _ctx.top_p = value;
    }

    /// <summary>
    /// temperature to adjust model's output distribution
    /// </summary>
    public float Temperature
    {
        get => _ctx.temp;
        set => _ctx.temp = value;
    }

    /// <summary>
    /// number of tokens in past conversation
    /// </summary>
    public int PastNum
    {
        get => _ctx.n_past;
        set => _ctx.n_past = value;
    }

    /// <summary>
    /// number of predictions to generate in parallel
    /// </summary>
    public int Batches
    {
        get => _ctx.n_batch;
        set => _ctx.n_batch = value;
    }

    /// <summary>
    /// number of tokens to predict
    /// </summary>
    public int TokensToPredict
    {
        get => _ctx.n_predict;
        set => _ctx.n_predict = value;
    }

    /// <summary>
    /// penalty factor for repeated tokens
    /// </summary>
    public float RepeatPenalty
    {
        get => _ctx.repeat_penalty;
        set => _ctx.repeat_penalty = value;
    }

    /// <summary>
    /// last n tokens to penalize
    /// </summary>
    public int RepeatLastN
    {
        get => _ctx.repeat_last_n;
        set => _ctx.repeat_last_n = value;
    }

    /// <summary>
    /// number of tokens possible in context window
    /// </summary>
    public int ContextSize
    {
        get => _ctx.n_ctx;
        set => _ctx.n_ctx = value;
    }

    /// <summary>
    /// percent of context to erase if we exceed the context window
    /// </summary>
    public float ContextErase
    {
        get => _ctx.context_erase;
        set => _ctx.context_erase = value;
    }
}
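The property bag maps one-to-one onto the native llmodel_prompt_context fields, so tuning sampling is a matter of setting properties before calling Prompt. A sketch with illustrative values:

    var ctx = new LLModelPromptContext
    {
        TopK = 40,              // sample from the 40 most likely tokens
        TopP = 0.9f,            // nucleus sampling threshold
        Temperature = 0.7f,     // sharpen or flatten the output distribution
        TokensToPredict = 256,  // n_predict: how many tokens to generate
        RepeatPenalty = 1.1f,   // penalty factor for repeated tokens
        RepeatLastN = 64        // how far back the penalty looks
    };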
@@ -1,126 +1,107 @@
 using System.Runtime.InteropServices;

 namespace Gpt4All.Bindings;

 public unsafe partial struct llmodel_prompt_context
 {
     public float* logits;

     [NativeTypeName("size_t")]
     public nuint logits_size;

     [NativeTypeName("int32_t *")]
     public int* tokens;

     [NativeTypeName("size_t")]
     public nuint tokens_size;

     [NativeTypeName("int32_t")]
     public int n_past;

     [NativeTypeName("int32_t")]
     public int n_ctx;

     [NativeTypeName("int32_t")]
     public int n_predict;

     [NativeTypeName("int32_t")]
     public int top_k;

     public float top_p;

     public float temp;

     [NativeTypeName("int32_t")]
     public int n_batch;

     public float repeat_penalty;

     [NativeTypeName("int32_t")]
     public int repeat_last_n;

     public float context_erase;
 }

 internal static unsafe partial class NativeMethods
 {
     [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
     [return: MarshalAs(UnmanagedType.I1)]
     public delegate bool LlmodelResponseCallback(int token_id, [MarshalAs(UnmanagedType.LPUTF8Str)] string response);

     [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
     [return: MarshalAs(UnmanagedType.I1)]
     public delegate bool LlmodelPromptCallback(int token_id);

     [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
     [return: MarshalAs(UnmanagedType.I1)]
     public delegate bool LlmodelRecalculateCallback(bool isRecalculating);

-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_gptj_create();
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    public static extern void llmodel_gptj_destroy([NativeTypeName("llmodel_model")] IntPtr gptj);
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_mpt_create();
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    public static extern void llmodel_mpt_destroy([NativeTypeName("llmodel_model")] IntPtr mpt);
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_llama_create();
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
-    public static extern void llmodel_llama_destroy([NativeTypeName("llmodel_model")] IntPtr llama);
-
-    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
-    [return: NativeTypeName("llmodel_model")]
-    public static extern IntPtr llmodel_model_create(
-        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
+    [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+    [return: NativeTypeName("llmodel_model")]
+    public static extern IntPtr llmodel_model_create2(
+        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
+        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string build_variant,
+        out IntPtr error);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     public static extern void llmodel_model_destroy([NativeTypeName("llmodel_model")] IntPtr model);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
     [return: MarshalAs(UnmanagedType.I1)]
     public static extern bool llmodel_loadModel(
         [NativeTypeName("llmodel_model")] IntPtr model,
         [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     [return: MarshalAs(UnmanagedType.I1)]
     public static extern bool llmodel_isModelLoaded([NativeTypeName("llmodel_model")] IntPtr model);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     [return: NativeTypeName("uint64_t")]
     public static extern ulong llmodel_get_state_size([NativeTypeName("llmodel_model")] IntPtr model);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     [return: NativeTypeName("uint64_t")]
     public static extern ulong llmodel_save_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("uint8_t *")] byte* dest);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     [return: NativeTypeName("uint64_t")]
     public static extern ulong llmodel_restore_state_data([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("const uint8_t *")] byte* src);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true, BestFitMapping = false, ThrowOnUnmappableChar = true)]
     public static extern void llmodel_prompt(
         [NativeTypeName("llmodel_model")] IntPtr model,
         [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
         LlmodelPromptCallback prompt_callback,
         LlmodelResponseCallback response_callback,
         LlmodelRecalculateCallback recalculate_callback,
         ref llmodel_prompt_context ctx);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     public static extern void llmodel_setThreadCount([NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("int32_t")] int n_threads);

     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
     [return: NativeTypeName("int32_t")]
     public static extern int llmodel_threadCount([NativeTypeName("llmodel_model")] IntPtr model);
 }
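The upshot for callers: the per-architecture create/destroy entry points are replaced by a single llmodel_model_create2, which selects the implementation from the model file itself. A sketch of the new call pattern (error handling simplified):

    // Before: NativeMethods.llmodel_gptj_create() / llmodel_llama_create() / llmodel_mpt_create()
    // After: one entry point; "auto" lets the backend choose the build variant.
    var handle = NativeMethods.llmodel_model_create2("path/to/model.bin", "auto", out IntPtr error);
    if (handle == IntPtr.Zero)
    {
        // On failure, `error` points to a C string describing what went wrong.
        throw new Exception(Marshal.PtrToStringAnsi(error) ?? "Unknown error");
    }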
@@ -1,27 +1,11 @@
 <Project Sdk="Microsoft.NET.Sdk">
-
   <PropertyGroup>
-    <TargetFrameworks>net6.0</TargetFrameworks>
+    <TargetFramework>net6.0</TargetFramework>
     <ImplicitUsings>enable</ImplicitUsings>
     <Nullable>enable</Nullable>
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>
-
-  <ItemGroup>
-    <!-- Windows -->
-    <None Include="..\runtimes\win-x64\native\*.dll" Pack="true" PackagePath="runtimes\win-x64\native\%(Filename)%(Extension)" />
-    <!-- Linux -->
-    <None Include="..\runtimes\linux-x64\native\*.so" Pack="true" PackagePath="runtimes\linux-x64\native\%(Filename)%(Extension)" />
-  </ItemGroup>
-
-  <ItemGroup>
-    <!-- Windows -->
-    <None Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="..\runtimes\win-x64\native\*.dll" Visible="False" CopyToOutputDirectory="PreserveNewest" />
-    <!-- Linux -->
-    <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
-  </ItemGroup>
-
   <ItemGroup>
     <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.0" />
   </ItemGroup>
 </Project>
@@ -0,0 +1,6 @@
namespace Gpt4All.LibraryLoader;

public interface ILibraryLoader
{
    LoadResult OpenLibrary(string? fileName);
}
@@ -0,0 +1,53 @@
using System.Runtime.InteropServices;

namespace Gpt4All.LibraryLoader;

internal class LinuxLibraryLoader : ILibraryLoader
{
#pragma warning disable CA2101
    [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
#pragma warning restore CA2101
    public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags);

#pragma warning disable CA2101
    [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
#pragma warning restore CA2101
    public static extern IntPtr NativeOpenLibraryLibdl2(string? filename, int flags);

    [DllImport("libdl.so", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
    public static extern IntPtr GetLoadError();

    [DllImport("libdl.so.2", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
    public static extern IntPtr GetLoadError2();

    public LoadResult OpenLibrary(string? fileName)
    {
        IntPtr loadedLib;
        try
        {
            // open with the RTLD_LAZY flag
            loadedLib = NativeOpenLibraryLibdl2(fileName, 0x00001);
        }
        catch (DllNotFoundException)
        {
            loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001);
        }

        if (loadedLib == IntPtr.Zero)
        {
            string errorMessage;
            try
            {
                errorMessage = Marshal.PtrToStringAnsi(GetLoadError2()) ?? "Unknown error";
            }
            catch (DllNotFoundException)
            {
                errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error";
            }

            return LoadResult.Failure(errorMessage);
        }

        return LoadResult.Success;
    }
}
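The magic 0x00001 passed to dlopen is RTLD_LAZY; naming the constant makes the intent self-documenting. A possible refinement, using the glibc values from <dlfcn.h>:

    // dlopen(3) mode flags, as defined by glibc's <dlfcn.h>
    private const int RTLD_LAZY = 0x00001; // resolve symbols on first use
    private const int RTLD_NOW  = 0x00002; // resolve every symbol at load time

    // e.g.: loadedLib = NativeOpenLibraryLibdl2(fileName, RTLD_LAZY);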
gpt4all-bindings/csharp/Gpt4All/LibraryLoader/LoadResult.cs (new file, 20 lines)

@@ -0,0 +1,20 @@
namespace Gpt4All.LibraryLoader;

public class LoadResult
{
    private LoadResult(bool isSuccess, string? errorMessage)
    {
        IsSuccess = isSuccess;
        ErrorMessage = errorMessage;
    }

    public static LoadResult Success { get; } = new(true, null);

    public static LoadResult Failure(string errorMessage)
    {
        return new(false, errorMessage);
    }

    public bool IsSuccess { get; }
    public string? ErrorMessage { get; }
}
@@ -0,0 +1,28 @@
using System.Runtime.InteropServices;

namespace Gpt4All.LibraryLoader;

internal class MacOsLibraryLoader : ILibraryLoader
{
#pragma warning disable CA2101
    [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlopen")]
#pragma warning restore CA2101
    public static extern IntPtr NativeOpenLibraryLibdl(string? filename, int flags);

    [DllImport("libdl.dylib", ExactSpelling = true, CharSet = CharSet.Auto, EntryPoint = "dlerror")]
    public static extern IntPtr GetLoadError();

    public LoadResult OpenLibrary(string? fileName)
    {
        var loadedLib = NativeOpenLibraryLibdl(fileName, 0x00001);

        if (loadedLib == IntPtr.Zero)
        {
            var errorMessage = Marshal.PtrToStringAnsi(GetLoadError()) ?? "Unknown error";

            return LoadResult.Failure(errorMessage);
        }

        return LoadResult.Success;
    }
}
@@ -0,0 +1,81 @@
#if !IOS && !MACCATALYST && !TVOS && !ANDROID
using System.Runtime.InteropServices;
#endif

namespace Gpt4All.LibraryLoader;

public static class NativeLibraryLoader
{
    private static ILibraryLoader? defaultLibraryLoader;

    /// <summary>
    /// Sets the library loader used to load the native libraries. Overwrite this only if you want some custom loading.
    /// </summary>
    /// <param name="libraryLoader">The library loader to be used.</param>
    public static void SetLibraryLoader(ILibraryLoader libraryLoader)
    {
        defaultLibraryLoader = libraryLoader;
    }

    internal static LoadResult LoadNativeLibrary(string? path = default, bool bypassLoading = true)
    {
        // If the user has handled loading the library themselves, we don't need to do anything.
        if (bypassLoading)
        {
            return LoadResult.Success;
        }

        var architecture = RuntimeInformation.OSArchitecture switch
        {
            Architecture.X64 => "x64",
            Architecture.X86 => "x86",
            Architecture.Arm => "arm",
            Architecture.Arm64 => "arm64",
            _ => throw new PlatformNotSupportedException(
                $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}")
        };

        var (platform, extension) = Environment.OSVersion.Platform switch
        {
            _ when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ("win", "dll"),
            _ when RuntimeInformation.IsOSPlatform(OSPlatform.Linux) => ("linux", "so"),
            _ when RuntimeInformation.IsOSPlatform(OSPlatform.OSX) => ("osx", "dylib"),
            _ => throw new PlatformNotSupportedException(
                $"Unsupported OS platform, architecture: {RuntimeInformation.OSArchitecture}")
        };

        // If the user hasn't set the path, we'll try to find it ourselves.
        if (string.IsNullOrEmpty(path))
        {
            var libraryName = "libllmodel";
            var assemblySearchPath = new[]
            {
                AppDomain.CurrentDomain.RelativeSearchPath,
                Path.GetDirectoryName(typeof(NativeLibraryLoader).Assembly.Location),
                Path.GetDirectoryName(Environment.GetCommandLineArgs()[0])
            }.FirstOrDefault(it => !string.IsNullOrEmpty(it));
            // Search for the library dll within the assembly search path. If it doesn't exist, for whatever reason, use the default path.
            path = Directory.EnumerateFiles(assemblySearchPath ?? string.Empty, $"{libraryName}.{extension}", SearchOption.AllDirectories).FirstOrDefault()
                ?? Path.Combine("runtimes", $"{platform}-{architecture}", $"{libraryName}.{extension}");
        }

        if (defaultLibraryLoader != null)
        {
            return defaultLibraryLoader.OpenLibrary(path);
        }

        if (!File.Exists(path))
        {
            throw new FileNotFoundException($"Native Library not found in path {path}. " +
                $"Verify you have included the native Gpt4All library in your application.");
        }

        ILibraryLoader libraryLoader = platform switch
        {
            "win" => new WindowsLibraryLoader(),
            "osx" => new MacOsLibraryLoader(),
            "linux" => new LinuxLibraryLoader(),
            _ => throw new PlatformNotSupportedException($"Currently {platform} platform is not supported")
        };
        return libraryLoader.OpenLibrary(path);
    }
}
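SetLibraryLoader leaves room for hosts with unusual probing rules. A sketch of a custom loader delegating to .NET's built-in NativeLibrary API (the class name here is hypothetical):

    using System.Runtime.InteropServices;
    using Gpt4All.LibraryLoader;

    public class ManagedLibraryLoader : ILibraryLoader
    {
        public LoadResult OpenLibrary(string? fileName)
        {
            // NativeLibrary.TryLoad is available on .NET Core 3.0 and later.
            return NativeLibrary.TryLoad(fileName!, out _)
                ? LoadResult.Success
                : LoadResult.Failure($"NativeLibrary could not load '{fileName}'");
        }
    }

    // Register it before the factory triggers the lazy load:
    // NativeLibraryLoader.SetLibraryLoader(new ManagedLibraryLoader());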
@@ -0,0 +1,24 @@
using System.ComponentModel;
using System.Runtime.InteropServices;

namespace Gpt4All.LibraryLoader;

internal class WindowsLibraryLoader : ILibraryLoader
{
    public LoadResult OpenLibrary(string? fileName)
    {
        var loadedLib = LoadLibrary(fileName);

        if (loadedLib == IntPtr.Zero)
        {
            var errorCode = Marshal.GetLastWin32Error();
            var errorMessage = new Win32Exception(errorCode).Message;
            return LoadResult.Failure(errorMessage);
        }

        return LoadResult.Success;
    }

    [DllImport("kernel32", SetLastError = true, CharSet = CharSet.Auto)]
    private static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPWStr)] string? lpFileName);
}
@@ -1,61 +1,58 @@
 using System.Diagnostics;
+using Microsoft.Extensions.Logging.Abstractions;
 using Microsoft.Extensions.Logging;
 using Gpt4All.Bindings;
-using Microsoft.Extensions.Logging.Abstractions;
+using Gpt4All.LibraryLoader;

 namespace Gpt4All;

 public class Gpt4AllModelFactory : IGpt4AllModelFactory
 {
     private readonly ILoggerFactory _loggerFactory;
     private readonly ILogger _logger;
+    private static bool bypassLoading;
+    private static string? libraryPath;
+
+    private static readonly Lazy<LoadResult> libraryLoaded = new(() =>
+    {
+        return NativeLibraryLoader.LoadNativeLibrary(Gpt4AllModelFactory.libraryPath, Gpt4AllModelFactory.bypassLoading);
+    }, true);

-    public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null)
+    public Gpt4AllModelFactory(string? libraryPath = default, bool bypassLoading = true, ILoggerFactory? loggerFactory = null)
     {
         _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
         _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
+        Gpt4AllModelFactory.libraryPath = libraryPath;
+        Gpt4AllModelFactory.bypassLoading = bypassLoading;
+
+        if (!libraryLoaded.Value.IsSuccess)
+        {
+            throw new Exception($"Failed to load native gpt4all library. Error: {libraryLoaded.Value.ErrorMessage}");
+        }
     }

-    private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
+    private IGpt4AllModel CreateModel(string modelPath)
     {
-        var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
+        var modelType_ = ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
         _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
-        var handle = modelType_ switch
-        {
-            ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
-            ModelType.GPTJ => NativeMethods.llmodel_gptj_create(),
-            ModelType.MPT => NativeMethods.llmodel_mpt_create(),
-            _ => NativeMethods.llmodel_model_create(modelPath),
-        };
+        IntPtr error;
+        var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
         _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
         _logger.LogInformation("Model loading started");
         var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
         _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
-        if (loadedSuccessfully == false)
+        if (!loadedSuccessfully)
         {
             throw new Exception($"Failed to load model: '{modelPath}'");
         }

         var logger = _loggerFactory.CreateLogger<LLModel>();
         var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);

         Debug.Assert(underlyingModel.IsLoaded());

         return new Gpt4All(underlyingModel, logger: logger);
     }

-    public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);
-
-    public IGpt4AllModel LoadMptModel(string modelPath) => CreateModel(modelPath, ModelType.MPT);
-
-    public IGpt4AllModel LoadGptjModel(string modelPath) => CreateModel(modelPath, ModelType.GPTJ);
-
-    public IGpt4AllModel LoadLlamaModel(string modelPath) => CreateModel(modelPath, ModelType.LLAMA);
+    public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath);
 }
@@ -1,12 +1,6 @@
 namespace Gpt4All;

 public interface IGpt4AllModelFactory
 {
-    IGpt4AllModel LoadGptjModel(string modelPath);
-
-    IGpt4AllModel LoadLlamaModel(string modelPath);
-
     IGpt4AllModel LoadModel(string modelPath);
-
-    IGpt4AllModel LoadMptModel(string modelPath);
 }
@@ -1,11 +1,11 @@
namespace Gpt4All;

/// <summary>
/// The supported model types
/// </summary>
public enum ModelType
{
    LLAMA = 0,
    GPTJ,
    MPT
}
@@ -1,31 +1,31 @@
 namespace Gpt4All;

 /// <summary>
 /// Interface for text prediction services
 /// </summary>
 public interface ITextPrediction
 {
     /// <summary>
     /// Get prediction results for the prompt and provided options.
     /// </summary>
     /// <param name="text">The text to complete</param>
     /// <param name="opts">The prediction settings</param>
-    /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
+    /// <param name="cancellation">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
     /// <returns>The prediction result generated by the model</returns>
     Task<ITextPredictionResult> GetPredictionAsync(
         string text,
         PredictRequestOptions opts,
         CancellationToken cancellation = default);

     /// <summary>
     /// Get streaming prediction results for the prompt and provided options.
     /// </summary>
     /// <param name="text">The text to complete</param>
     /// <param name="opts">The prediction settings</param>
     /// <param name="cancellationToken">The <see cref="CancellationToken"/> for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
     /// <returns>The prediction result generated by the model</returns>
     Task<ITextPredictionStreamingResult> GetStreamingPredictionAsync(
         string text,
         PredictRequestOptions opts,
         CancellationToken cancellationToken = default);
 }
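A sketch of consuming this interface (assumes `model` implements ITextPrediction and that PredictRequestOptions has usable defaults, which is an assumption; check the type for required settings):

    var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2));

    var result = await model.GetPredictionAsync(
        "Name three colors.",
        new PredictRequestOptions(),
        cancellation: cts.Token);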
@@ -5,4 +5,6 @@ mkdir runtimes/linux-x64/build
 cmake -S ../../gpt4all-backend -B runtimes/linux-x64/build
 cmake --build runtimes/linux-x64/build --parallel --config Release
 cp runtimes/linux-x64/build/libllmodel.so runtimes/linux-x64/native/libllmodel.so
-cp runtimes/linux-x64/build/llama.cpp/libllama.so runtimes/linux-x64/native/libllama.so
+cp runtimes/linux-x64/build/libgptj*.so runtimes/linux-x64/native/
+cp runtimes/linux-x64/build/libllama*.so runtimes/linux-x64/native/
+cp runtimes/linux-x64/build/libmpt*.so runtimes/linux-x64/native/
@@ -13,4 +13,5 @@ cmake --build $BUILD_DIR --parallel --config Release

 # copy native dlls
 cp "C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw64\bin\*dll" $LIBS_DIR
-cp "$BUILD_DIR\*.dll" $LIBS_DIR
+cp "$BUILD_DIR\bin\*.dll" $LIBS_DIR
+mv $LIBS_DIR\llmodel.dll $LIBS_DIR\libllmodel.dll
@@ -2,4 +2,5 @@ Remove-Item -Force -Recurse .\runtimes\win-x64\msvc -ErrorAction SilentlyContinue
 mkdir .\runtimes\win-x64\msvc\build | Out-Null
 cmake -G "Visual Studio 17 2022" -A X64 -S ..\..\gpt4all-backend -B .\runtimes\win-x64\msvc\build
 cmake --build .\runtimes\win-x64\msvc\build --parallel --config Release
 cp .\runtimes\win-x64\msvc\build\bin\Release\*.dll .\runtimes\win-x64
+mv .\runtimes\win-x64\llmodel.dll .\runtimes\win-x64\libllmodel.dll