mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 09:58:44 +00:00
Adding LLM wrapper for Kobold AI (#7560)
- Description: add wrapper that lets you use the KoboldAI API in langchain - Issue: n/a - Dependencies: none extra, just what exists in langchain - Tag maintainer: @baskaryan - Twitter handle: @zanzibased --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
603a0bea29
commit
50316f6477
@ -0,0 +1,88 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "FPF4vhdZyJ7S"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# KoboldAI API\n",
|
||||||
|
"\n",
|
||||||
|
"[KoboldAI](https://github.com/KoboldAI/KoboldAI-Client) is a \"a browser-based front-end for AI-assisted writing with multiple local & remote AI models...\". It has a public and local API that is able to be used in langchain.\n",
|
||||||
|
"\n",
|
||||||
|
"This example goes over how to use LangChain with that API.\n",
|
||||||
|
"\n",
|
||||||
|
"Documentation can be found in the browser by adding /api to the end of your endpoint (i.e. http://127.0.0.1:5000/api).\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"id": "lyzOsRRTf_Vr"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.llms import KoboldApiLLM"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "1a_H7mvfy51O"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"Replace the endpoint seen below with the one shown in the output after starting the webui with --api or --public-api\n",
|
||||||
|
"\n",
|
||||||
|
"Optionally, you can pass in parameters like temperature or max_length"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {
|
||||||
|
"id": "g3vGebq8f_Vr"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm = KoboldApiLLM(endpoint=\"http://192.168.1.144:5000\", max_length=80)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "sPxNGGiDf_Vr",
|
||||||
|
"outputId": "024a1d62-3cd7-49a8-c6a8-5278224d02ef"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = llm(\"### Instruction:\\nWhat is the first book of the bible?\\n### Response:\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "venv"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 1
|
||||||
|
}
|
@ -29,6 +29,7 @@ from langchain.llms.huggingface_hub import HuggingFaceHub
|
|||||||
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
||||||
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
|
||||||
from langchain.llms.human import HumanInputLLM
|
from langchain.llms.human import HumanInputLLM
|
||||||
|
from langchain.llms.koboldai import KoboldApiLLM
|
||||||
from langchain.llms.llamacpp import LlamaCpp
|
from langchain.llms.llamacpp import LlamaCpp
|
||||||
from langchain.llms.manifest import ManifestWrapper
|
from langchain.llms.manifest import ManifestWrapper
|
||||||
from langchain.llms.modal import Modal
|
from langchain.llms.modal import Modal
|
||||||
@ -81,6 +82,7 @@ __all__ = [
|
|||||||
"HuggingFacePipeline",
|
"HuggingFacePipeline",
|
||||||
"HuggingFaceTextGenInference",
|
"HuggingFaceTextGenInference",
|
||||||
"HumanInputLLM",
|
"HumanInputLLM",
|
||||||
|
"KoboldApiLLM",
|
||||||
"LlamaCpp",
|
"LlamaCpp",
|
||||||
"TextGen",
|
"TextGen",
|
||||||
"ManifestWrapper",
|
"ManifestWrapper",
|
||||||
@ -136,6 +138,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
|||||||
"huggingface_pipeline": HuggingFacePipeline,
|
"huggingface_pipeline": HuggingFacePipeline,
|
||||||
"huggingface_textgen_inference": HuggingFaceTextGenInference,
|
"huggingface_textgen_inference": HuggingFaceTextGenInference,
|
||||||
"human-input": HumanInputLLM,
|
"human-input": HumanInputLLM,
|
||||||
|
"koboldai": KoboldApiLLM,
|
||||||
"llamacpp": LlamaCpp,
|
"llamacpp": LlamaCpp,
|
||||||
"textgen": TextGen,
|
"textgen": TextGen,
|
||||||
"modal": Modal,
|
"modal": Modal,
|
||||||
|
200
langchain/llms/koboldai.py
Normal file
200
langchain/llms/koboldai.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
"""Wrapper around KoboldAI API."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def clean_url(url: str) -> str:
    """Remove trailing slash and /api from url if present.

    Normalizes a user-supplied KoboldAI endpoint so the caller can safely
    append ``/api/v1/generate`` to it.

    Args:
        url: The endpoint URL, e.g. ``http://localhost:5000``,
            ``http://localhost:5000/`` or ``http://localhost:5000/api``.

    Returns:
        The URL without any trailing slash or ``/api`` suffix.
    """
    # Strip trailing slash(es) first so that endpoints given as
    # "http://host:5000/api/" are also normalized; the original order of
    # checks left "/api" attached in that case, which produced a broken
    # ".../api/api/v1/generate" request URL.
    url = url.rstrip("/")
    if url.endswith("/api"):
        url = url[: -len("/api")]
    return url
|
||||||
|
|
||||||
|
|
||||||
|
class KoboldApiLLM(LLM):
    """
    A class that acts as a wrapper for the Kobold API language model.

    It includes several fields that can be used to control the text generation process.

    To use this class, instantiate it with the required parameters and call it with a
    prompt to generate text. For example:

        kobold = KoboldApiLLM(endpoint="http://localhost:5000")
        result = kobold("Write a story about a dragon.")

    This will send a POST request to the Kobold API with the provided prompt and
    generate text.
    """

    # --- Connection ---

    endpoint: str
    """The API endpoint to use for generating text."""

    # --- KoboldAI GUI context toggles (all off by default) ---

    use_story: Optional[bool] = False
    """ Whether or not to use the story from the KoboldAI GUI when generating text. """

    use_authors_note: Optional[bool] = False
    """Whether to use the author's note from the KoboldAI GUI when generating text.

    This has no effect unless use_story is also enabled.
    """

    use_world_info: Optional[bool] = False
    """Whether to use the world info from the KoboldAI GUI when generating text."""

    use_memory: Optional[bool] = False
    """Whether to use the memory from the KoboldAI GUI when generating text."""

    # --- Generation parameters (the min/max notes mirror the API schema) ---

    max_context_length: Optional[int] = 1600
    """Maximum number of tokens to send to the model.

    minimum: 1
    """

    max_length: Optional[int] = 80
    """Number of tokens to generate.

    maximum: 512
    minimum: 1
    """

    rep_pen: Optional[float] = 1.12
    """Base repetition penalty value.

    minimum: 1
    """

    rep_pen_range: Optional[int] = 1024
    """Repetition penalty range.

    minimum: 0
    """

    rep_pen_slope: Optional[float] = 0.9
    """Repetition penalty slope.

    minimum: 0
    """

    temperature: Optional[float] = 0.6
    """Temperature value.

    exclusiveMinimum: 0
    """

    tfs: Optional[float] = 0.9
    """Tail free sampling value.

    maximum: 1
    minimum: 0
    """

    top_a: Optional[float] = 0.9
    """Top-a sampling value.

    minimum: 0
    """

    top_p: Optional[float] = 0.95
    """Top-p sampling value.

    maximum: 1
    minimum: 0
    """

    top_k: Optional[int] = 0
    """Top-k sampling value.

    minimum: 0
    """

    typical: Optional[float] = 0.5
    """Typical sampling value.

    maximum: 1
    minimum: 0
    """

    @property
    def _llm_type(self) -> str:
        # Identifier used by LangChain's LLM registry/serialization.
        return "koboldai"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.
            run_manager: Callback manager for the run; currently unused here.
            **kwargs: Accepted for interface compatibility but not forwarded
                to the API.

        Returns:
            The generated text.

        Raises:
            requests.HTTPError: If the API responds with a non-2xx status.
            ValueError: If the response JSON lacks ``results[0]["text"]``.

        Example:
            .. code-block:: python

                from langchain.llms import KoboldApiLLM

                llm = KoboldApiLLM(endpoint="http://localhost:5000")
                llm("Write a story about dragons.")
        """
        # Request payload: the prompt plus every configured field, matching
        # the KoboldAI /api/v1/generate schema.
        data: Dict[str, Any] = {
            "prompt": prompt,
            "use_story": self.use_story,
            "use_authors_note": self.use_authors_note,
            "use_world_info": self.use_world_info,
            "use_memory": self.use_memory,
            "max_context_length": self.max_context_length,
            "max_length": self.max_length,
            "rep_pen": self.rep_pen,
            "rep_pen_range": self.rep_pen_range,
            "rep_pen_slope": self.rep_pen_slope,
            "temperature": self.temperature,
            "tfs": self.tfs,
            "top_a": self.top_a,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "typical": self.typical,
        }

        # Only include stop sequences when the caller provided them.
        if stop is not None:
            data["stop_sequence"] = stop

        # NOTE(review): no timeout is set on this request — a hung or
        # unreachable server will block indefinitely; consider adding one.
        response = requests.post(
            f"{clean_url(self.endpoint)}/api/v1/generate", json=data
        )

        response.raise_for_status()
        json_response = response.json()

        if (
            "results" in json_response
            and len(json_response["results"]) > 0
            and "text" in json_response["results"][0]
        ):
            text = json_response["results"][0]["text"].strip()

            # The API may echo the stop sequence at the end of the output;
            # trim it (only a trailing match is removed).
            if stop is not None:
                for sequence in stop:
                    if text.endswith(sequence):
                        text = text[: -len(sequence)].rstrip()

            return text
        else:
            raise ValueError(
                f"Unexpected response format from Kobold API: {json_response}"
            )
|
Loading…
Reference in New Issue
Block a user