Update LlamaCpp parameters (#2411)
Add `n_batch` and `last_n_tokens_size` parameters to the LlamaCpp class. These parameters (especially `n_batch`) significantly affect performance. There's also a `verbose` flag that prints system timings on the `Llama` class, but I wasn't sure where to add it, as it conflicts with (or should perhaps be pulled from?) the LLM base class.
commit e519a81a05 (parent b026a62bc4)
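For illustration, a minimal sketch of how the new parameters might be used once this change lands. The model path is a placeholder, and `n_batch=256` is just an example value (it must stay between 1 and `n_ctx`):

```python
from langchain.llms import LlamaCpp

# Hypothetical local model path -- substitute your own GGML model file.
llm = LlamaCpp(
    model_path="./models/ggml-model-q4_0.bin",
    n_batch=256,  # tokens processed in parallel; must be in [1, n_ctx]
    last_n_tokens_size=64,  # look-back window for repeat_penalty
)
print(llm("Q: What did the fox say? A:"))
```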
@@ -49,6 +49,10 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """Number of threads to use. If None, the number
     of threads is automatically determined."""
 
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -66,6 +70,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         vocab_only = values["vocab_only"]
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
+        n_batch = values["n_batch"]
 
         try:
             from llama_cpp import Llama
@@ -80,6 +85,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
                 vocab_only=vocab_only,
                 use_mlock=use_mlock,
                 n_threads=n_threads,
+                n_batch=n_batch,
                 embedding=True,
             )
         except ImportError:
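That covers the embeddings wrapper, which gains the same `n_batch` knob. A minimal usage sketch (model path is a placeholder):

```python
from langchain.embeddings import LlamaCppEmbeddings

embeddings = LlamaCppEmbeddings(
    model_path="./models/ggml-model-q4_0.bin",  # placeholder path
    n_batch=256,  # larger batches can speed up embedding of long inputs
)
vector = embeddings.embed_query("A quick test sentence")
```

The LlamaCpp LLM class gets the same treatment, plus `last_n_tokens_size`: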
@@ -53,6 +53,10 @@ class LlamaCpp(LLM, BaseModel):
     """Number of threads to use.
     If None, the number of threads is automatically determined."""
 
+    n_batch: Optional[int] = Field(8, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
     suffix: Optional[str] = Field(None)
     """A suffix to append to the generated text. If None, no suffix is appended."""
 
@@ -80,6 +84,9 @@ class LlamaCpp(LLM, BaseModel):
     top_k: Optional[int] = 40
     """The top-k value to use for sampling."""
 
+    last_n_tokens_size: Optional[int] = 64
+    """The number of tokens to look back when applying the repeat_penalty."""
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
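For context, `last_n_tokens_size` is forwarded straight to `llama_cpp.Llama`, where it bounds how far back the sampler looks when applying `repeat_penalty`. A sketch of the equivalent direct call (model path and prompt are placeholders):

```python
from llama_cpp import Llama

llm = Llama(
    model_path="./models/ggml-model-q4_0.bin",  # placeholder path
    last_n_tokens_size=64,  # penalize repeats within the last 64 tokens
)
out = llm("Q: Name the planets in the solar system. A:", repeat_penalty=1.1)
print(out["choices"][0]["text"])
```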
@@ -92,6 +99,8 @@ class LlamaCpp(LLM, BaseModel):
         vocab_only = values["vocab_only"]
         use_mlock = values["use_mlock"]
         n_threads = values["n_threads"]
+        n_batch = values["n_batch"]
+        last_n_tokens_size = values["last_n_tokens_size"]
 
         try:
             from llama_cpp import Llama
@@ -106,6 +115,8 @@ class LlamaCpp(LLM, BaseModel):
                 vocab_only=vocab_only,
                 use_mlock=use_mlock,
                 n_threads=n_threads,
+                n_batch=n_batch,
+                last_n_tokens_size=last_n_tokens_size,
             )
         except ImportError:
             raise ModuleNotFoundError(
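Regarding the open question in the commit message: per that message, `llama_cpp.Llama` accepts a `verbose` flag that prints system timings, which sits awkwardly next to the `verbose` attribute on LangChain's LLM base class. A hedged sketch of enabling it on the underlying library directly (not part of this diff; model path is a placeholder):

```python
from llama_cpp import Llama

# verbose=True makes llama.cpp print timing stats after each call.
llm = Llama(
    model_path="./models/ggml-model-q4_0.bin",  # placeholder path
    verbose=True,
)
```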
|
Loading…
Reference in New Issue
Block a user