csharp: update C# bindings to work with GGUF (#1651)

This commit is contained in:
Jared Van Bortel 2024-01-16 14:33:41 -05:00 committed by GitHub
parent f8564398fc
commit 03a9f0bedf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 17 additions and 57 deletions

View File

@ -41,6 +41,8 @@ insert_final_newline = true
# IDE0055: Fix formatting # IDE0055: Fix formatting
dotnet_diagnostic.IDE0055.severity = error dotnet_diagnostic.IDE0055.severity = error
dotnet_diagnostic.CS1573.severity = suggestion
dotnet_diagnostic.CS1591.severity = suggestion
# Sort using and Import directives with System.* appearing first # Sort using and Import directives with System.* appearing first
dotnet_sort_system_directives_first = true dotnet_sort_system_directives_first = true
@ -343,4 +345,4 @@ dotnet_diagnostic.IDE2004.severity = warning
[src/{VisualStudio}/**/*.{cs,vb}] [src/{VisualStudio}/**/*.{cs,vb}]
# CA1822: Make member static # CA1822: Make member static
# There is a risk of accidentally breaking an internal API that partners rely on though IVT. # There is a risk of accidentally breaking an internal API that partners rely on though IVT.
dotnet_code_quality.CA1822.api_surface = private dotnet_code_quality.CA1822.api_surface = private

View File

@ -5,6 +5,7 @@
<TargetFramework>net7.0</TargetFramework> <TargetFramework>net7.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>

View File

@ -5,6 +5,7 @@
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<IsPackable>false</IsPackable> <IsPackable>false</IsPackable>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>

View File

@ -5,8 +5,6 @@
/// </summary> /// </summary>
public interface ILLModel : IDisposable public interface ILLModel : IDisposable
{ {
ModelType ModelType { get; }
ulong GetStateSizeBytes(); ulong GetStateSizeBytes();
int GetThreadCount(); int GetThreadCount();

View File

@ -42,16 +42,12 @@ public record ModelRecalculatingEventArgs(bool IsRecalculating);
public class LLModel : ILLModel public class LLModel : ILLModel
{ {
protected readonly IntPtr _handle; protected readonly IntPtr _handle;
private readonly ModelType _modelType;
private readonly ILogger _logger; private readonly ILogger _logger;
private bool _disposed; private bool _disposed;
public ModelType ModelType => _modelType; internal LLModel(IntPtr handle, ILogger? logger = null)
internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
{ {
_handle = handle; _handle = handle;
_modelType = modelType;
_logger = logger ?? NullLogger.Instance; _logger = logger ?? NullLogger.Instance;
} }
@ -59,10 +55,9 @@ public class LLModel : ILLModel
/// Create a new model from a pointer /// Create a new model from a pointer
/// </summary> /// </summary>
/// <param name="handle">Pointer to underlying model</param> /// <param name="handle">Pointer to underlying model</param>
/// <param name="modelType">The model type</param> public static LLModel Create(IntPtr handle, ILogger? logger = null)
public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
{ {
return new LLModel(handle, modelType, logger: logger); return new LLModel(handle, logger: logger);
} }
/// <summary> /// <summary>
@ -204,12 +199,7 @@ public class LLModel : ILLModel
// dispose managed state // dispose managed state
} }
switch (_modelType) Destroy();
{
default:
Destroy();
break;
}
_disposed = true; _disposed = true;
} }

View File

@ -4,6 +4,7 @@
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<!-- Windows --> <!-- Windows -->

View File

@ -3,6 +3,7 @@ using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Gpt4All.Bindings; using Gpt4All.Bindings;
using Gpt4All.LibraryLoader; using Gpt4All.LibraryLoader;
using System.Runtime.InteropServices;
namespace Gpt4All; namespace Gpt4All;
@ -33,10 +34,13 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
private IGpt4AllModel CreateModel(string modelPath) private IGpt4AllModel CreateModel(string modelPath)
{ {
var modelType_ = ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath); _logger.LogInformation("Creating model path={ModelPath}", modelPath);
_logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
IntPtr error; IntPtr error;
var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error); var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
if (error != IntPtr.Zero)
{
throw new Exception(Marshal.PtrToStringAnsi(error));
}
_logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle); _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
_logger.LogInformation("Model loading started"); _logger.LogInformation("Model loading started");
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048); var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
@ -47,7 +51,7 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
} }
var logger = _loggerFactory.CreateLogger<LLModel>(); var logger = _loggerFactory.CreateLogger<LLModel>();
var underlyingModel = LLModel.Create(handle, modelType_, logger: logger); var underlyingModel = LLModel.Create(handle, logger: logger);
Debug.Assert(underlyingModel.IsLoaded()); Debug.Assert(underlyingModel.IsLoaded());

View File

@ -1,24 +0,0 @@
namespace Gpt4All;
public static class ModelFileUtils
{
private const uint GPTJ_MAGIC = 0x67676d6c;
private const uint LLAMA_MAGIC = 0x67676a74;
private const uint MPT_MAGIC = 0x67676d6d;
public static ModelType GetModelTypeFromModelFileHeader(string modelPath)
{
using var fileStream = new FileStream(modelPath, FileMode.Open);
using var binReader = new BinaryReader(fileStream);
var magic = binReader.ReadUInt32();
return magic switch
{
GPTJ_MAGIC => ModelType.GPTJ,
LLAMA_MAGIC => ModelType.LLAMA,
MPT_MAGIC => ModelType.MPT,
_ => throw new ArgumentOutOfRangeException($"Invalid model file. magic=0x{magic:X8}"),
};
}
}

View File

@ -3,6 +3,4 @@
public record ModelOptions public record ModelOptions
{ {
public int Threads { get; init; } = 4; public int Threads { get; init; } = 4;
public ModelType ModelType { get; init; } = ModelType.GPTJ;
} }

View File

@ -1,11 +0,0 @@
namespace Gpt4All;
/// <summary>
/// The supported model types
/// </summary>
public enum ModelType
{
LLAMA = 0,
GPTJ,
MPT
}