feat: grammar-based sampling in llama-cpp (#9712)
## Description

This PR enables [grammar-based sampling](https://github.com/ggerganov/llama.cpp/tree/master/grammars) in the llama-cpp LLM. In short, loading a file with a formal grammar definition constrains model outputs: for instance, one can force the model to generate only valid JSON or only Python lists.

In a follow-up PR we will add:

* docs describing why this is useful and how it works
* possibly a code sample for a concrete task, as in the llama.cpp repo

---------

Co-authored-by: Lance Martin <lance@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
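For orientation, a minimal usage sketch (the model path and prompt are illustrative assumptions, not part of this PR):

```python
from langchain.llms import LlamaCpp

# Point grammar_path at one of the bundled .gbnf files. At most one of
# grammar_path and grammar may be supplied.
llm = LlamaCpp(
    model_path="./models/llama-2-7b.gguf",  # assumption: any local GGUF model
    grammar_path="libs/langchain/langchain/llms/grammars/json.gbnf",
)

# Sampling is constrained so that the completion is always derivable from
# the grammar, i.e. the output is guaranteed to be grammar-conformant JSON.
print(llm("Describe a person in JSON format:"))
```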
libs/langchain/langchain/llms/grammars/json.gbnf (new file, 29 lines)
@@ -0,0 +1,29 @@
# Grammar for subset of JSON - doesn't support full string or number syntax

root ::= object
value ::= object | array | string | number | boolean | "null"

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}"

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]"

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

# Only plain integers currently
number ::= "-"? [0-9]+ ws
boolean ::= ("true" | "false") ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
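A worked check on what this subset admits (examples ours, not part of the PR): everything the grammar derives parses as JSON, but not every JSON document is derivable, since `number` only matches optional-sign integers.

```python
import json

# Derivable from the grammar above -- and therefore valid JSON:
print(json.loads('{"name": "Alice", "age": 30, "tags": ["x", "y"], "ok": true}'))

# Valid JSON, but NOT derivable from this grammar: floats are excluded
# because `number ::= "-"? [0-9]+ ws` admits only plain integers.
print(json.loads('{"pi": 3.14}'))
```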
libs/langchain/langchain/llms/grammars/list.gbnf (new file, 14 lines)
@@ -0,0 +1,14 @@
root ::= "[" items "]" EOF

items ::= item ("," ws* item)*

item ::= string

string ::=
  "\"" word (ws+ word)* "\"" ws*

word ::= [a-zA-Z]+

ws ::= " "

EOF ::= "\n"
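A grammar can also be passed inline as a string via the new `grammar` field rather than as a file path; a sketch using the list grammar above (the model path is again an assumption):

```python
from langchain.llms import LlamaCpp

# The raw .gbnf text is accepted directly; validate_environment converts it
# with LlamaGrammar.from_string.
LIST_GRAMMAR = r"""
root ::= "[" items "]" EOF
items ::= item ("," ws* item)*
item ::= string
string ::= "\"" word (ws+ word)* "\"" ws*
word ::= [a-zA-Z]+
ws ::= " "
EOF ::= "\n"
"""

llm = LlamaCpp(
    model_path="./models/llama-2-7b.gguf",  # assumption: any local GGUF model
    grammar=LIST_GRAMMAR,
)

# Output is forced into the shape ["red", "green", "blue"]
print(llm("List the primary colors:"))
```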
libs/langchain/langchain/llms/llamacpp.py
@@ -1,5 +1,8 @@
 from __future__ import annotations

 import logging
-from typing import Any, Dict, Iterator, List, Optional
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -8,6 +11,9 @@ from langchain.schema.output import GenerationChunk
 from langchain.utils import get_pydantic_field_names
 from langchain.utils.utils import build_extra_kwargs

+if TYPE_CHECKING:
+    from llama_cpp import LlamaGrammar
+
 logger = logging.getLogger(__name__)

@@ -113,12 +119,35 @@ class LlamaCpp(LLM):
     streaming: bool = True
     """Whether to stream the results, token by token."""

+    grammar_path: Optional[Union[str, Path]] = None
+    """
+    grammar_path: Path to the .gbnf file that defines formal grammars
+    for constraining model outputs. For instance, the grammar can be used
+    to force the model to generate valid JSON or to speak exclusively in emojis. At most
+    one of grammar_path and grammar should be passed in.
+    """
+    grammar: Optional[Union[str, LlamaGrammar]] = None
+    """
+    grammar: formal grammar for constraining model outputs. For instance, the grammar
+    can be used to force the model to generate valid JSON or to speak exclusively in
+    emojis. At most one of grammar_path and grammar should be passed in.
+    """
+
     verbose: bool = True
     """Print verbose output to stderr."""

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
+        try:
+            from llama_cpp import Llama, LlamaGrammar
+        except ImportError:
+            raise ImportError(
+                "Could not import llama-cpp-python library. "
+                "Please install the llama-cpp-python library to "
+                "use this model: pip install llama-cpp-python"
+            )
+
         model_path = values["model_path"]
         model_param_names = [
             "rope_freq_scale",
@@ -146,21 +175,26 @@ class LlamaCpp(LLM):
         model_params.update(values["model_kwargs"])

         try:
-            from llama_cpp import Llama
-
             values["client"] = Llama(model_path, **model_params)
-        except ImportError:
-            raise ImportError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
-            )
         except Exception as e:
             raise ValueError(
                 f"Could not load Llama model from path: {model_path}. "
                 f"Received error {e}"
             )

+        if values["grammar"] and values["grammar_path"]:
+            grammar = values["grammar"]
+            grammar_path = values["grammar_path"]
+            raise ValueError(
+                "Can only pass in one of grammar and grammar_path. Received "
+                f"{grammar=} and {grammar_path=}."
+            )
+        elif isinstance(values["grammar"], str):
+            values["grammar"] = LlamaGrammar.from_string(values["grammar"])
+        elif values["grammar_path"]:
+            values["grammar"] = LlamaGrammar.from_file(values["grammar_path"])
+        else:
+            pass
         return values

     @root_validator(pre=True)
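The mutual-exclusion check above fails fast when both arguments are given; a quick illustration (note the client is constructed first, so model_path must point at a loadable model before this check is reached):

```python
from langchain.llms import LlamaCpp

try:
    LlamaCpp(
        model_path="./models/llama-2-7b.gguf",  # assumption: any local GGUF model
        grammar='root ::= "yes" | "no"',
        grammar_path="libs/langchain/langchain/llms/grammars/json.gbnf",
    )
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print(err)  # Can only pass in one of grammar and grammar_path. Received ...
```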
@@ -176,7 +210,7 @@ class LlamaCpp(LLM):
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling llama_cpp."""
-        return {
+        params = {
             "suffix": self.suffix,
             "max_tokens": self.max_tokens,
             "temperature": self.temperature,
@@ -187,6 +221,9 @@ class LlamaCpp(LLM):
             "repeat_penalty": self.repeat_penalty,
             "top_k": self.top_k,
         }
+        if self.grammar:
+            params["grammar"] = self.grammar
+        return params

     @property
     def _identifying_params(self) -> Dict[str, Any]:
@@ -252,7 +289,10 @@ class LlamaCpp(LLM):
         # and return the combined strings from the first choice's text:
         combined_text_output = ""
         for chunk in self._stream(
-            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            prompt=prompt,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
         ):
             combined_text_output += chunk.text
         return combined_text_output
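Since `_call` just concatenates the chunks yielded by `_stream`, grammar constraints apply identically to streamed and non-streamed use; a sketch via the `.stream()` interface (assuming the `llm` configured in the earlier snippets and a langchain version exposing `.stream()` on LLMs):

```python
# Each chunk is sampled under the grammar, so even a partial output only ever
# extends a string that the grammar can still complete.
for token in llm.stream("Name three fruits as a JSON array:"):
    print(token, end="", flush=True)
```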