feat: grammar-based sampling in llama-cpp (#9712)

## Description 

This PR enables [grammar-based
sampling](https://github.com/ggerganov/llama.cpp/tree/master/grammars)
in the llama-cpp LLM.

In short, loading a file with a formal grammar definition constrains
the model's outputs. For instance, one can force the model to generate
valid JSON or only Python lists.

In a follow-up PR we will add:
* docs describing why this is useful and how it works
* possibly a code sample for a task, such as those in the llama.cpp repo

---------

Co-authored-by: Lance Martin <lance@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
eryk-dsai
2023-08-28 18:52:55 +02:00
committed by GitHub
parent cb642ef658
commit 7f5713b80a
4 changed files with 620 additions and 48 deletions

View File

@@ -0,0 +1,29 @@
# Grammar for subset of JSON - doesn't support full string or number syntax
# The top-level production must be an object (not a bare array or scalar).
root ::= object
value ::= object | array | string | number | boolean | "null"
# Zero or more comma-separated string ":" value pairs enclosed in braces
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}"
# Zero or more comma-separated values enclosed in square brackets
array ::=
"[" ws (
value
("," ws value)*
)? "]"
# Double-quoted string: each body element is either a character other than
# a double quote or backslash, or a backslash escape (one of "\/bfnrt, or
# "u" followed by exactly four hex digits)
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
# Only plain integers currently (optional leading minus, one or more digits;
# no fraction or exponent)
number ::= "-"? [0-9]+ ws
boolean ::= ("true" | "false") ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
# Matches zero or more spaces, tabs, or newlines (defined recursively)
ws ::= ([ \t\n] ws)?

View File

@@ -0,0 +1,14 @@
# Grammar for a single-line list of quoted strings, e.g. ["foo bar", "baz"]
# The whole output is "[" items "]" followed by a terminating newline.
root ::= "[" items "]" EOF
# One or more items separated by commas; spaces may follow each comma.
# An empty list ("[]") is not matched because at least one item is required.
items ::= item ("," ws* item)*
item ::= string
# Double-quoted string made of one or more alphabetic words separated by
# one or more spaces, optionally followed by trailing spaces after the
# closing quote
string ::=
"\"" word (ws+ word)* "\"" ws*
# One or more ASCII letters (no digits or punctuation)
word ::= [a-zA-Z]+
# A single space character (repetition is applied at the use sites)
ws ::= " "
# A newline marks the end of the list
EOF ::= "\n"

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
import logging
from typing import Any, Dict, Iterator, List, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@@ -8,6 +11,9 @@ from langchain.schema.output import GenerationChunk
from langchain.utils import get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs
if TYPE_CHECKING:
from llama_cpp import LlamaGrammar
logger = logging.getLogger(__name__)
@@ -113,12 +119,35 @@ class LlamaCpp(LLM):
streaming: bool = True
"""Whether to stream the results, token by token."""
grammar_path: Optional[Union[str, Path]] = None
"""
grammar_path: Path to the .gbnf file that defines formal grammars
for constraining model outputs. For instance, the grammar can be used
to force the model to generate valid JSON or to speak exclusively in emojis. At most
one of grammar_path and grammar should be passed in.
"""
grammar: Optional[Union[str, LlamaGrammar]] = None
"""
grammar: formal grammar for constraining model outputs. For instance, the grammar
can be used to force the model to generate valid JSON or to speak exclusively in
emojis. At most one of grammar_path and grammar should be passed in.
"""
verbose: bool = True
"""Print verbose output to stderr."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
try:
from llama_cpp import Llama, LlamaGrammar
except ImportError:
raise ImportError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
model_path = values["model_path"]
model_param_names = [
"rope_freq_scale",
@@ -146,21 +175,26 @@ class LlamaCpp(LLM):
model_params.update(values["model_kwargs"])
try:
from llama_cpp import Llama
values["client"] = Llama(model_path, **model_params)
except ImportError:
raise ImportError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
except Exception as e:
raise ValueError(
f"Could not load Llama model from path: {model_path}. "
f"Received error {e}"
)
if values["grammar"] and values["grammar_path"]:
grammar = values["grammar"]
grammar_path = values["grammar_path"]
raise ValueError(
"Can only pass in one of grammar and grammar_path. Received "
f"{grammar=} and {grammar_path=}."
)
elif isinstance(values["grammar"], str):
values["grammar"] = LlamaGrammar.from_string(values["grammar"])
elif values["grammar_path"]:
values["grammar"] = LlamaGrammar.from_file(values["grammar_path"])
else:
pass
return values
@root_validator(pre=True)
@@ -176,7 +210,7 @@ class LlamaCpp(LLM):
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling llama_cpp."""
return {
params = {
"suffix": self.suffix,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
@@ -187,6 +221,9 @@ class LlamaCpp(LLM):
"repeat_penalty": self.repeat_penalty,
"top_k": self.top_k,
}
if self.grammar:
params["grammar"] = self.grammar
return params
@property
def _identifying_params(self) -> Dict[str, Any]:
@@ -252,7 +289,10 @@ class LlamaCpp(LLM):
# and return the combined strings from the first choices's text:
combined_text_output = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
prompt=prompt,
stop=stop,
run_manager=run_manager,
**kwargs,
):
combined_text_output += chunk.text
return combined_text_output