Compare commits

...

7 Commits

Author SHA1 Message Date
Jacob Lee
6ef6c9e7f1 Revert __init__.py 2024-07-12 10:52:28 -07:00
Jacob Lee
244cd5c141 Update __init__.py 2024-07-12 10:48:28 -07:00
Jacob Lee
412bc82c11 Update curry.py 2024-07-12 10:42:08 -07:00
jacoblee93
fc3353636a Fix lint 2024-07-11 12:55:29 -07:00
jacoblee93
7130aa826f Lint 2024-07-11 12:46:16 -07:00
Jacob Lee
367b2d8dbe Update libs/core/langchain_core/utils/curry.py
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-07-11 12:44:55 -07:00
jacoblee93
432ccd686d Adds curry utils function 2024-07-11 11:19:51 -07:00
2 changed files with 93 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable
def curry(func: Callable[..., Any], **curried_kwargs: Any) -> Callable[..., Any]:
"""Util that wraps a function and partially applies kwargs to it.
Returns a new function whose signature omits the curried variables.
Args:
func: The function to curry.
curried_kwargs: Arguments to apply to the function.
Returns:
A new function with curried arguments applied.
.. versionadded:: 0.2.18
"""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return await func(*args, **new_kwargs)
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return func(*args, **new_kwargs)
sig = inspect.signature(func)
# Create a new signature without the curried parameters
new_params = [p for name, p in sig.parameters.items() if name not in curried_kwargs]
if asyncio.iscoroutinefunction(func):
async_wrapper = wraps(func)(async_wrapper)
setattr(async_wrapper, "__signature__", sig.replace(parameters=new_params))
return async_wrapper
else:
sync_wrapper = wraps(func)(sync_wrapper)
setattr(sync_wrapper, "__signature__", sig.replace(parameters=new_params))
return sync_wrapper

View File

@@ -0,0 +1,51 @@
from typing import Any
from langchain_core.utils.curry import curry
def test_curry() -> None:
def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn, a="hey")
assert curried(b=" you") == "hey you"
def test_curry_with_kwargs_values() -> None:
def test_fn(a: str, b: str, **kwargs: Any) -> str:
return a + b + kwargs["c"]
curried = curry(test_fn, c=" you you")
assert curried(a="hey", b=" hey") == "hey hey you you"
def test_noop_curry() -> None:
def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn)
assert curried(a="bye", b=" you") == "bye you"
async def test_async_curry() -> None:
async def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn, a="hey")
assert await curried(b=" you") == "hey you"
async def test_async_curry_with_kwargs_values() -> None:
async def test_fn(a: str, b: str, **kwargs: Any) -> str:
return a + b + kwargs["c"]
curried = curry(test_fn, c=" you you")
assert await curried(a="hey", b=" hey") == "hey hey you you"
async def test_noop_async_curry() -> None:
async def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn)
assert await curried(a="bye", b=" you") == "bye you"