Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
6fe06272c5 Curry 2024-07-10 18:48:21 -07:00

View File

@@ -4,12 +4,15 @@ import contextlib
import datetime
import functools
import importlib
import inspect
import warnings
from functools import wraps
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar, Union, cast
from packaging.version import parse
from requests import HTTPError, Response
from typing_extensions import Concatenate, ParamSpec, TypedDict
from langchain_core.pydantic_v1 import SecretStr
@@ -180,3 +183,31 @@ def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
if isinstance(value, SecretStr):
return value
return SecretStr(value)
P = ParamSpec("P")
PM = TypeVar("PM")
R = TypeVar("R")
def curry(func: Callable[Concatenate[PM, P], R], **fixed_kwargs: PM) -> Callable[P, R]:
"""Bind parameters to a function, removing those parameters from the signature.
Useful for exposing a narrower interface than what the the original function
provides.
"""
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
new_kwargs = {**fixed_kwargs, **cast(fixed_kwargs, kwargs)}
return func(*args, **new_kwargs)
sig = inspect.signature(func)
# Check that fixed_kwargs are all valid parameters of the function
invalid_kwargs = set(fixed_kwargs) - set(sig.parameters)
if invalid_kwargs:
raise ValueError(f"Invalid parameters: {invalid_kwargs}")
new_params = [p for name, p in sig.parameters.items() if name not in fixed_kwargs]
wrapper.__signature__ = sig.replace(parameters=new_params)
return wrapper