diff --git a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java index 2e51a245..6758ccdf 100644 --- a/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java +++ b/gpt4all-bindings/java/src/main/java/com/hexadevlabs/gpt4all/LLModel.java @@ -8,9 +8,8 @@ import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; +import java.util.stream.Collectors; public class LLModel implements AutoCloseable { @@ -306,6 +305,197 @@ public class LLModel implements AutoCloseable { }; } + /** + * The array of messages for the conversation. + */ + public static class Messages { + + private final List messages = new ArrayList<>(); + + public Messages(PromptMessage...messages) { + this.messages.addAll(Arrays.asList(messages)); + } + + public Messages(List messages) { + this.messages.addAll(messages); + } + + public Messages addPromptMessage(PromptMessage promptMessage) { + this.messages.add(promptMessage); + return this; + } + + List toList() { + return Collections.unmodifiableList(this.messages); + } + + List> toListMap() { + return messages.stream() + .map(PromptMessage::toMap).collect(Collectors.toList()); + } + + } + + /** + * A message in the conversation, identical to OpenAI's chat message. + */ + public static class PromptMessage { + + private static final String ROLE = "role"; + private static final String CONTENT = "content"; + + private final Map message = new HashMap<>(); + + public PromptMessage() { + } + + public PromptMessage(Role role, String content) { + addRole(role); + addContent(content); + } + + public PromptMessage addRole(Role role) { + return this.addParameter(ROLE, role.type()); + } + + public PromptMessage addContent(String content) { + return this.addParameter(CONTENT, content); + } + + public PromptMessage addParameter(String key, String value) { + this.message.put(key, value); + return this; + } + + public String content() { + return this.parameter(CONTENT); + } + + public Role role() { + String role = this.parameter(ROLE); + return Role.from(role); + } + + public String parameter(String key) { + return this.message.get(key); + } + + Map toMap() { + return Collections.unmodifiableMap(this.message); + } + + } + + public enum Role { + + SYSTEM("system"), ASSISTANT("assistant"), USER("user"); + + private final String type; + + String type() { + return this.type; + } + + static Role from(String type) { + + if (type == null) { + return null; + } + + switch (type) { + case "system": return SYSTEM; + case "assistant": return ASSISTANT; + case "user": return USER; + default: throw new IllegalArgumentException( + String.format("You passed %s type but only %s are supported", + type, Arrays.toString(Role.values()) + ) + ); + } + } + + Role(String type) { + this.type = type; + } + + @Override + public String toString() { + return type(); + } + } + + /** + * The result of the completion, similar to OpenAI's format. + */ + public static class CompletionReturn { + private String model; + private Usage usage; + private Choices choices; + + public CompletionReturn(String model, Usage usage, Choices choices) { + this.model = model; + this.usage = usage; + this.choices = choices; + } + + public Choices choices() { + return choices; + } + + public String model() { + return model; + } + + public Usage usage() { + return usage; + } + } + + /** + * The generated completions. + */ + public static class Choices { + + private final List choices = new ArrayList<>(); + + public Choices(List choices) { + this.choices.addAll(choices); + } + + public Choices(CompletionChoice...completionChoices){ + this.choices.addAll(Arrays.asList(completionChoices)); + } + + public Choices addCompletionChoice(CompletionChoice completionChoice) { + this.choices.add(completionChoice); + return this; + } + + public CompletionChoice first() { + return this.choices.get(0); + } + + public int totalChoices() { + return this.choices.size(); + } + + public CompletionChoice get(int index) { + return this.choices.get(index); + } + + public List choices() { + return Collections.unmodifiableList(choices); + } + } + + /** + * A completion choice, similar to OpenAI's format. + */ + public static class CompletionChoice extends PromptMessage { + public CompletionChoice(Role role, String content) { + super(role, content); + } + } public static class ChatCompletionResponse { public String model; @@ -323,6 +513,41 @@ public class LLModel implements AutoCloseable { // Getters and setters } + public CompletionReturn chatCompletionResponse(Messages messages, + GenerationConfig generationConfig) { + return chatCompletion(messages, generationConfig, false, false); + } + + /** + * chatCompletion formats the existing chat conversation into a template to be + * easier to process for chat UIs. It is not absolutely necessary as generate method + * may be directly used to make generations with gpt models. + * + * @param messages object to create theMessages to send to GPT model + * @param generationConfig How to decode/process the generation. + * @param streamToStdOut Send tokens as they are calculated Standard output. + * @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output. + * @return CompletionReturn contains stats and generated Text. + */ + public CompletionReturn chatCompletion(Messages messages, + GenerationConfig generationConfig, boolean streamToStdOut, + boolean outputFullPromptToStdOut) { + + String fullPrompt = buildPrompt(messages.toListMap()); + + if(outputFullPromptToStdOut) + System.out.print(fullPrompt); + + String generatedText = generate(fullPrompt, generationConfig, streamToStdOut); + + final CompletionChoice promptMessage = new CompletionChoice(Role.ASSISTANT, generatedText); + final Choices choices = new Choices(promptMessage); + + final Usage usage = getUsage(fullPrompt, generatedText); + return new CompletionReturn(this.modelName, usage, choices); + + } + public ChatCompletionResponse chatCompletion(List> messages, GenerationConfig generationConfig) { return chatCompletion(messages, generationConfig, false, false); @@ -352,19 +577,23 @@ public class LLModel implements AutoCloseable { ChatCompletionResponse response = new ChatCompletionResponse(); response.model = this.modelName; - Usage usage = new Usage(); - usage.promptTokens = fullPrompt.length(); - usage.completionTokens = generatedText.length(); - usage.totalTokens = fullPrompt.length() + generatedText.length(); - response.usage = usage; + response.usage = getUsage(fullPrompt, generatedText); Map message = new HashMap<>(); message.put("role", "assistant"); message.put("content", generatedText); response.choices = List.of(message); - return response; + + } + + private Usage getUsage(String fullPrompt, String generatedText) { + Usage usage = new Usage(); + usage.promptTokens = fullPrompt.length(); + usage.completionTokens = generatedText.length(); + usage.totalTokens = fullPrompt.length() + generatedText.length(); + return usage; } protected static String buildPrompt(List> messages) { diff --git a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java index 8bc7c914..6f2be894 100644 --- a/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java +++ b/gpt4all-bindings/java/src/test/java/com/hexadevlabs/gpt4all/BasicTests.java @@ -28,6 +28,33 @@ import static org.mockito.Mockito.*; @ExtendWith(MockitoExtension.class) public class BasicTests { + @Test + public void simplePromptWithObject(){ + + LLModel model = Mockito.spy(new LLModel()); + + LLModel.GenerationConfig config = + LLModel.config() + .withNPredict(20) + .build(); + + // The generate method will return "4" + doReturn("4").when( model ).generate(anyString(), eq(config), eq(true)); + + LLModel.PromptMessage promptMessage1 = new LLModel.PromptMessage(LLModel.Role.SYSTEM, "You are a helpful assistant"); + LLModel.PromptMessage promptMessage2 = new LLModel.PromptMessage(LLModel.Role.USER, "Add 2+2"); + + LLModel.Messages messages = new LLModel.Messages(promptMessage1, promptMessage2); + + LLModel.CompletionReturn response = model.chatCompletion( + messages, config, true, true); + + assertTrue( response.choices().first().content().contains("4") ); + + // Verifies the prompt and response are certain length. + assertEquals( 224 , response.usage().totalTokens ); + } + @Test public void simplePrompt(){