community[minor]: Adds Llamafile as an LLM (#17431)
* **Description:** Adds a simple LLM implementation for interacting with [llamafile](https://github.com/Mozilla-Ocho/llamafile)-based models.
* **Dependencies:** N/A
* **Issue:** N/A

**Detail**

[llamafile](https://github.com/Mozilla-Ocho/llamafile) lets you run LLMs locally from a single file on most computers without installing any dependencies.

To use the llamafile LLM implementation, the user needs to:

1. Download a llamafile e.g. https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile?download=true
2. Make the file executable.
3. Run the llamafile in 'server mode'. (All llamafiles come packaged with a lightweight server; by default, the server listens at `http://localhost:8080`.)

```bash
wget https://url/of/model.llamafile
chmod +x model.llamafile
./model.llamafile --server --nobrowser
```

Now, the user can invoke the LLM via the LangChain client:

```python
from langchain_community.llms.llamafile import Llamafile

llm = Llamafile()

llm.invoke("Tell me a joke.")
```
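For readers who want to see what the client exchanges with the server, the following is a minimal, standard-library-only sketch of the two pieces of logic in this commit: merging per-call options into the default generation parameters (unknown keys are dropped), and extracting the `content` field from a streamed `data: {...}` line. The names `DEFAULT_PARAMS`, `build_payload`, and `parse_stream_line` are hypothetical and exist only for illustration; they are not part of the `Llamafile` class, and `DEFAULT_PARAMS` lists only a small subset of the real defaults.

```python
import json
from typing import Any, Dict, List, Optional

# Hypothetical subset of the defaults declared on the Llamafile class.
DEFAULT_PARAMS: Dict[str, Any] = {"temperature": 0.8, "seed": -1, "n_predict": -1}


def build_payload(
    prompt: str,
    stop: Optional[List[str]] = None,
    stream: bool = False,
    **overrides: Any,
) -> Dict[str, Any]:
    """Merge per-call overrides into the defaults, dropping unknown keys."""
    params = dict(DEFAULT_PARAMS)
    for key, value in overrides.items():
        if key in params:  # unknown options are never sent to the server
            params[key] = value
    if stop:
        params["stop"] = stop
    if stream:
        params["stream"] = True
    return {"prompt": prompt, **params}


def parse_stream_line(line: str) -> str:
    """Return the "content" field of one streamed line.

    Lines without the "data:" prefix (for example the blank separator
    lines between events) are returned unchanged.
    """
    prefix = "data:"
    if not line.startswith(prefix):
        return line
    # Slice off the exact prefix; json.loads tolerates the leading space.
    return json.loads(line[len(prefix):])["content"]
```

Under these assumptions, `build_payload("Test prompt", seed=0, unknown="x")` would produce a body with `seed` overridden and no `unknown` key, which is the behaviour the unit tests in this commit check for.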
This commit is contained in:
parent
5ce1827d31
commit
0bc4a9b3fc
133
docs/docs/integrations/llms/llamafile.ipynb
Normal file
@ -0,0 +1,133 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Llamafile\n",
    "\n",
    "[Llamafile](https://github.com/Mozilla-Ocho/llamafile) lets you distribute and run LLMs with a single file.\n",
    "\n",
    "Llamafile does this by combining [llama.cpp](https://github.com/ggerganov/llama.cpp) with [Cosmopolitan Libc](https://github.com/jart/cosmopolitan) into one framework that collapses all the complexity of LLMs down to a single-file executable (called a \"llamafile\") that runs locally on most computers, with no installation.\n",
    "\n",
    "## Setup\n",
    "\n",
    "1. Download a llamafile for the model you'd like to use. You can find many models in llamafile format on [HuggingFace](https://huggingface.co/models?other=llamafile). In this guide, we will download a small one, `TinyLlama-1.1B-Chat-v1.0.Q5_K_M`. Note: if you don't have `wget`, you can just download the model via this [link](https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile?download=true).\n",
    "\n",
    "```bash\n",
    "wget https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
    "```\n",
    "\n",
    "2. Make the llamafile executable. First, if you haven't done so already, open a terminal. **If you're using macOS, Linux, or BSD,** you'll need to grant permission for your computer to execute this new file using `chmod` (see below). **If you're on Windows,** rename the file by adding \".exe\" to the end (the model file should then be named `TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile.exe`).\n",
    "\n",
    "\n",
    "```bash\n",
    "chmod +x TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile  # run if you're on macOS, Linux, or BSD\n",
    "```\n",
    "\n",
    "3. Run the llamafile in \"server mode\":\n",
    "\n",
    "```bash\n",
    "./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser\n",
    "```\n",
    "\n",
    "Now you can make calls to the llamafile's REST API. By default, the llamafile server listens at http://localhost:8080. You can find full server documentation [here](https://github.com/Mozilla-Ocho/llamafile/blob/main/llama.cpp/server/README.md#api-endpoints). You can interact with the llamafile directly via the REST API, but here we'll show how to interact with it using LangChain.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'? \\nI\\'ve got a thing for pink, but you know that.\\n\"Can we not talk about work anymore?\" - What did she say?\\nI don\\'t want to be a burden on you.\\nIt\\'s hard to keep a good thing going.\\nYou can\\'t tell me what I want, I have a life too!'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_community.llms.llamafile import Llamafile\n",
    "\n",
    "llm = Llamafile()\n",
    "\n",
    "llm.invoke(\"Tell me a joke\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To stream tokens, use the `.stream(...)` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".\n",
      "- She said, \"I’m tired of my life. What should I do?\"\n",
      "- The man replied, \"I hear you. But don’t worry. Life is just like a joke. It has its funny parts too.\"\n",
      "- The woman looked at him, amazed and happy to hear his wise words. - \"Thank you for your wisdom,\" she said, smiling. - He replied, \"Any time. But it doesn't come easy. You have to laugh and keep moving forward in life.\"\n",
      "- She nodded, thanking him again. - The man smiled wryly. \"Life can be tough. Sometimes it seems like you’re never going to get out of your situation.\"\n",
      "- He said, \"I know that. But the key is not giving up. Life has many ups and downs, but in the end, it will turn out okay.\"\n",
      "- The woman's eyes softened. \"Thank you for your advice. It's so important to keep moving forward in life,\" she said. - He nodded once again. \"You’re welcome. I hope your journey is filled with laughter and joy.\"\n",
      "- They both smiled and left the bar, ready to embark on their respective adventures.\n"
     ]
    }
   ],
   "source": [
    "query = \"Tell me a joke\"\n",
    "\n",
    "for chunks in llm.stream(query):\n",
    "    print(chunks, end=\"\")\n",
    "\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To learn more about the LangChain Expression Language and the available methods on an LLM, see the [LCEL Interface](https://python.langchain.com/docs/expression_language/interface)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
318
libs/community/langchain_community/llms/llamafile.py
Normal file
@ -0,0 +1,318 @@
from __future__ import annotations

import json
from io import StringIO
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra
from langchain_core.utils import get_pydantic_field_names


class Llamafile(LLM):
    """Llamafile lets you distribute and run large language models with a
    single file.

    To get started, see: https://github.com/Mozilla-Ocho/llamafile

    To use this class, you will need to first:

    1. Download a llamafile.
    2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
    3. Start the llamafile in server mode:

        `./path/to/model.llamafile --server --nobrowser`

    Example:
        .. code-block:: python

            from langchain_community.llms import Llamafile
            llm = Llamafile()
            llm.invoke("Tell me a joke.")
    """

    base_url: str = "http://localhost:8080"
    """Base url where the llamafile server is listening."""

    request_timeout: Optional[int] = None
    """Timeout for server requests"""

    streaming: bool = False
    """Allows receiving each predicted token in real-time instead of
    waiting for the completion to finish. To enable this, set to true."""

    # Generation options

    seed: int = -1
    """Random Number Generator (RNG) seed. A random seed is used if this is
    less than zero. Default: -1"""

    temperature: float = 0.8
    """Temperature. Default: 0.8"""

    top_k: int = 40
    """Limit the next token selection to the K most probable tokens.
    Default: 40."""

    top_p: float = 0.95
    """Limit the next token selection to a subset of tokens with a cumulative
    probability above a threshold P. Default: 0.95."""

    min_p: float = 0.05
    """The minimum probability for a token to be considered, relative to
    the probability of the most likely token. Default: 0.05."""

    n_predict: int = -1
    """Set the maximum number of tokens to predict when generating text.
    Note: May exceed the set limit slightly if the last token is a partial
    multibyte character. When 0, no tokens will be generated but the prompt
    is evaluated into the cache. Default: -1 = infinity."""

    n_keep: int = 0
    """Specify the number of tokens from the prompt to retain when the
    context size is exceeded and tokens need to be discarded. By default,
    this value is set to 0 (meaning no tokens are kept). Use -1 to retain all
    tokens from the prompt."""

    tfs_z: float = 1.0
    """Enable tail free sampling with parameter z. Default: 1.0 = disabled."""

    typical_p: float = 1.0
    """Enable locally typical sampling with parameter p.
    Default: 1.0 = disabled."""

    repeat_penalty: float = 1.1
    """Control the repetition of token sequences in the generated text.
    Default: 1.1"""

    repeat_last_n: int = 64
    """Last n tokens to consider for penalizing repetition. Default: 64,
    0 = disabled, -1 = ctx-size."""

    penalize_nl: bool = True
    """Penalize newline tokens when applying the repeat penalty.
    Default: true."""

    presence_penalty: float = 0.0
    """Repeat alpha presence penalty. Default: 0.0 = disabled."""

    frequency_penalty: float = 0.0
    """Repeat alpha frequency penalty. Default: 0.0 = disabled"""

    mirostat: int = 0
    """Enable Mirostat sampling, controlling perplexity during text
    generation. 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0.
    Default: disabled."""

    mirostat_tau: float = 5.0
    """Set the Mirostat target entropy, parameter tau. Default: 5.0."""

    mirostat_eta: float = 0.1
    """Set the Mirostat learning rate, parameter eta. Default: 0.1."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        return "llamafile"

    @property
    def _param_fieldnames(self) -> List[str]:
        # Return the list of fieldnames that will be passed as configurable
        # generation options to the llamafile server. Exclude 'builtin' fields
        # from the BaseLLM class like 'metadata' as well as fields that should
        # not be passed in requests (base_url, request_timeout).
        ignore_keys = [
            "base_url",
            "cache",
            "callback_manager",
            "callbacks",
            "metadata",
            "name",
            "request_timeout",
            "streaming",
            "tags",
            "verbose",
        ]
        attrs = [
            k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
        ]
        return attrs

    @property
    def _default_params(self) -> Dict[str, Any]:
        params = {}
        for fieldname in self._param_fieldnames:
            params[fieldname] = getattr(self, fieldname)
        return params

    def _get_parameters(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        params = self._default_params

        # Only update keys that are already present in the default config.
        # This way, we don't accidentally post unknown/unhandled key/values
        # in the request to the llamafile server
        for k, v in kwargs.items():
            if k in params:
                params[k] = v

        if stop is not None and len(stop) > 0:
            params["stop"] = stop

        if self.streaming:
            params["stream"] = True

        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Request prompt completion from the llamafile server and return the
        output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.
            run_manager: Optional callback manager for the run.
            **kwargs: Any additional options to pass as part of the
            generation request.

        Returns:
            The string generated by the model.

        """

        if self.streaming:
            with StringIO() as buff:
                for chunk in self._stream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    buff.write(chunk.text)

                text = buff.getvalue()

            return text

        else:
            params = self._get_parameters(stop=stop, **kwargs)
            payload = {"prompt": prompt, **params}

            try:
                response = requests.post(
                    url=f"{self.base_url}/completion",
                    headers={
                        "Content-Type": "application/json",
                    },
                    json=payload,
                    stream=False,
                    timeout=self.request_timeout,
                )
            except requests.exceptions.ConnectionError:
                raise requests.exceptions.ConnectionError(
                    f"Could not connect to Llamafile server. Please make sure "
                    f"that a server is running at {self.base_url}."
                )

            response.raise_for_status()
            response.encoding = "utf-8"

            text = response.json()["content"]

            return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yields results objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager, notified of each new token.
            **kwargs: Any additional options to pass as part of the
            generation request.

        Yields:
            GenerationChunk objects, each containing the text of one
            streamed chunk.

        Example:
            .. code-block:: python

                from langchain_community.llms import Llamafile
                llm = Llamafile(
                    temperature = 0.0
                )
                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
                        stop=["'","\\n"]):
                    # the public `stream` method yields plain strings
                    print(chunk, end='', flush=True)

        """
        params = self._get_parameters(stop=stop, **kwargs)
        if "stream" not in params:
            params["stream"] = True

        payload = {"prompt": prompt, **params}

        try:
            response = requests.post(
                url=f"{self.base_url}/completion",
                headers={
                    "Content-Type": "application/json",
                },
                json=payload,
                stream=True,
                timeout=self.request_timeout,
            )
        except requests.exceptions.ConnectionError:
            raise requests.exceptions.ConnectionError(
                f"Could not connect to Llamafile server. Please make sure "
                f"that a server is running at {self.base_url}."
            )

        response.encoding = "utf8"

        for raw_chunk in response.iter_lines(decode_unicode=True):
            content = self._get_chunk_content(raw_chunk)
            chunk = GenerationChunk(text=content)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)

    def _get_chunk_content(self, chunk: str) -> str:
        """When streaming is turned on, llamafile server returns lines like:

        'data: {"content":" They","multimodal":true,"slot_id":0,"stop":false}'

        Here, we convert this to a dict and return the value of the 'content'
        field
        """

        if chunk.startswith("data:"):
            # Slice off the exact prefix. str.lstrip("data: ") would strip a
            # set of characters rather than the prefix itself.
            cleaned = chunk[len("data:") :].lstrip()
            data = json.loads(cleaned)
            return data["content"]
        else:
            return chunk
@ -0,0 +1,46 @@
import os
from typing import Generator

import pytest
import requests
from requests.exceptions import ConnectionError, HTTPError

from langchain_community.llms.llamafile import Llamafile

LLAMAFILE_SERVER_BASE_URL = os.getenv(
    "LLAMAFILE_SERVER_BASE_URL", "http://localhost:8080"
)


def _ping_llamafile_server() -> bool:
    try:
        response = requests.get(LLAMAFILE_SERVER_BASE_URL)
        response.raise_for_status()
    except (ConnectionError, HTTPError):
        return False

    return True


@pytest.mark.skipif(
    not _ping_llamafile_server(),
    reason=f"unable to find llamafile server at {LLAMAFILE_SERVER_BASE_URL}, "
    f"please start one and re-run this test",
)
def test_llamafile_call() -> None:
    # Use the same URL that was pinged, so the env var override is honoured.
    llm = Llamafile(base_url=LLAMAFILE_SERVER_BASE_URL)
    output = llm.invoke("Say foo:")
    assert isinstance(output, str)


@pytest.mark.skipif(
    not _ping_llamafile_server(),
    reason=f"unable to find llamafile server at {LLAMAFILE_SERVER_BASE_URL}, "
    f"please start one and re-run this test",
)
def test_llamafile_streaming() -> None:
    # Use the same URL that was pinged, so the env var override is honoured.
    llm = Llamafile(base_url=LLAMAFILE_SERVER_BASE_URL, streaming=True)
    generator = llm.stream("Tell me about Roman dodecahedrons.")
    assert isinstance(generator, Generator)
    for token in generator:
        assert isinstance(token, str)
158
libs/community/tests/unit_tests/llms/test_llamafile.py
Normal file
@ -0,0 +1,158 @@
import json
from collections import deque
from typing import Any, Dict

import pytest
import requests
from pytest import MonkeyPatch

from langchain_community.llms.llamafile import Llamafile


def default_generation_params() -> Dict[str, Any]:
    return {
        "temperature": 0.8,
        "seed": -1,
        "top_k": 40,
        "top_p": 0.95,
        "min_p": 0.05,
        "n_predict": -1,
        "n_keep": 0,
        "tfs_z": 1.0,
        "typical_p": 1.0,
        "repeat_penalty": 1.1,
        "repeat_last_n": 64,
        "penalize_nl": True,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "mirostat": 0,
        "mirostat_tau": 5.0,
        "mirostat_eta": 0.1,
    }


def mock_response() -> requests.Response:
    contents = json.dumps({"content": "the quick brown fox"})
    response = requests.Response()
    response.status_code = 200
    response._content = str.encode(contents)
    return response


def mock_response_stream():  # type: ignore[no-untyped-def]
    mock_response = deque(
        [
            b'data: {"content":"the","multimodal":false,"slot_id":0,"stop":false}\n\n',  # noqa
            b'data: {"content":" quick","multimodal":false,"slot_id":0,"stop":false}\n\n',  # noqa
        ]
    )

    class MockRaw:
        def read(self, chunk_size):  # type: ignore[no-untyped-def]
            try:
                return mock_response.popleft()
            except IndexError:
                return None

    response = requests.Response()
    response.status_code = 200
    response.raw = MockRaw()
    return response


def test_call(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `invoke` method
    """
    llm = Llamafile(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/completion"
        assert headers == {
            "Content-Type": "application/json",
        }
        # only the prompt and the default generation params should be sent
        assert json == {"prompt": "Test prompt", **default_generation_params()}
        assert stream is False
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = llm.invoke("Test prompt")
    assert out == "the quick brown fox"


def test_call_with_kwargs(monkeypatch: MonkeyPatch) -> None:
    """
    Test kwargs passed to `invoke` override the default values and are passed
    to the endpoint correctly. Also test that any 'unknown' kwargs that are not
    present in the LLM class attrs are ignored.
    """
    llm = Llamafile(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/completion"
        assert headers == {
            "Content-Type": "application/json",
        }
        # 'unknown' kwarg should be ignored
        expected = {"prompt": "Test prompt", **default_generation_params()}
        expected["seed"] = 0
        assert json == expected
        assert stream is False
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = llm.invoke(
        "Test prompt",
        unknown="unknown option",  # should be ignored
        seed=0,  # should override the default
    )
    assert out == "the quick brown fox"


def test_call_raises_exception_on_missing_server(monkeypatch: MonkeyPatch) -> None:
    """
    Test that the LLM raises a ConnectionError when no llamafile server is
    listening at the base_url.
    """
    llm = Llamafile(
        # invalid url, nothing should actually be running here
        base_url="http://llamafile-host:8080",
    )
    with pytest.raises(requests.exceptions.ConnectionError):
        llm.invoke("Test prompt")


def test_streaming(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of `invoke` with streaming enabled.
    """
    llm = Llamafile(
        base_url="http://llamafile-hostname:8080",
        streaming=True,
    )

    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-hostname:8080/completion"
        assert headers == {
            "Content-Type": "application/json",
        }
        # 'unknown' kwarg should be ignored
        assert "unknown" not in json
        expected = {"prompt": "Test prompt", **default_generation_params()}
        expected["stream"] = True
        assert json == expected
        assert stream is True
        assert timeout is None

        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)
    out = llm.invoke("Test prompt")
    assert out == "the quick"