mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
support for arbitrary kwargs for llamacpp (#8727)
llamacpp params (per their own code) are unstable, so instead of adding/deleting them constantly adding a model_kwargs parameter that allows for arbitrary additional kwargs cc @jsjolund and @zacps re #8599 and #8704
This commit is contained in:
parent
f0b0c72d98
commit
115a77142a
@ -6,6 +6,8 @@ from pydantic import Field, root_validator
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_pydantic_field_names
|
||||
from langchain.utils.utils import build_extra_kwargs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -106,6 +108,9 @@ class LlamaCpp(LLM):
|
||||
rope_freq_base: float = 10000.0
|
||||
"""Base frequency for rope sampling."""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional parameters to pass to llama_cpp.Llama."""
|
||||
|
||||
streaming: bool = True
|
||||
"""Whether to stream the results, token by token."""
|
||||
|
||||
@ -139,6 +144,8 @@ class LlamaCpp(LLM):
|
||||
if values["n_gpu_layers"] is not None:
|
||||
model_params["n_gpu_layers"] = values["n_gpu_layers"]
|
||||
|
||||
model_params.update(values["model_kwargs"])
|
||||
|
||||
try:
|
||||
from llama_cpp import Llama
|
||||
|
||||
@ -157,6 +164,16 @@ class LlamaCpp(LLM):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_model_kwargs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling llama_cpp."""
|
||||
|
@ -30,6 +30,7 @@ from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
from langchain.utils.utils import build_extra_kwargs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -215,25 +216,9 @@ class BaseOpenAI(BaseLLM):
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
|
@ -2,8 +2,9 @@
|
||||
import contextlib
|
||||
import datetime
|
||||
import importlib
|
||||
import warnings
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple
|
||||
|
||||
from packaging.version import parse
|
||||
from requests import HTTPError, Response
|
||||
@ -122,7 +123,7 @@ def check_package_version(
|
||||
)
|
||||
|
||||
|
||||
def get_pydantic_field_names(pydantic_cls: Any) -> Set:
|
||||
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
|
||||
"""Get field names, including aliases, for a pydantic class.
|
||||
|
||||
Args:
|
||||
@ -133,3 +134,30 @@ def get_pydantic_field_names(pydantic_cls: Any) -> Set:
|
||||
if field.has_alias:
|
||||
all_required_field_names.add(field.alias)
|
||||
return all_required_field_names
|
||||
|
||||
|
||||
def build_extra_kwargs(
|
||||
extra_kwargs: Dict[str, Any],
|
||||
values: Dict[str, Any],
|
||||
all_required_field_names: Set[str],
|
||||
) -> Dict[str, Any]:
|
||||
""""""
|
||||
for field_name in list(values):
|
||||
if field_name in extra_kwargs:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra_kwargs[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
return extra_kwargs
|
||||
|
@ -4,6 +4,8 @@ import os
|
||||
from typing import Generator
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.llms import LlamaCpp
|
||||
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
@ -68,3 +70,19 @@ def test_llamacpp_streaming_callback() -> None:
|
||||
)
|
||||
llm("Q: Can you count to 10? A:'1, ")
|
||||
assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE
|
||||
|
||||
|
||||
def test_llamacpp_model_kwargs() -> None:
|
||||
llm = LlamaCpp(model_path=get_model(), model_kwargs={"n_gqa": None})
|
||||
assert llm.model_kwargs == {"n_gqa": None}
|
||||
|
||||
|
||||
def test_llamacpp_invalid_model_kwargs() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
LlamaCpp(model_path=get_model(), model_kwargs={"n_ctx": 1024})
|
||||
|
||||
|
||||
def test_llamacpp_incorrect_field() -> None:
|
||||
with pytest.warns(match="not default parameter"):
|
||||
llm = LlamaCpp(model_path=get_model(), n_gqa=None)
|
||||
llm.model_kwargs == {"n_gqa": None}
|
||||
|
@ -22,6 +22,12 @@ def test_openai_model_param() -> None:
|
||||
assert llm.model_name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_openai_model_kwargs() -> None:
|
||||
llm = OpenAI(model_kwargs={"foo": "bar"})
|
||||
assert llm.model_kwargs == {"foo": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_openai_invalid_model_kwargs() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
|
Loading…
Reference in New Issue
Block a user