mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
change run to use args and kwargs (#367)
Before, `run` was not able to be called with multiple arguments. This expands the functionality.
This commit is contained in:
parent
a7084ad6e4
commit
8d0869c6d3
1
.flake8
1
.flake8
@ -1,5 +1,6 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
exclude =
|
exclude =
|
||||||
|
venv
|
||||||
.venv
|
.venv
|
||||||
__pycache__
|
__pycache__
|
||||||
notebooks
|
notebooks
|
||||||
|
@ -119,16 +119,23 @@ class Chain(BaseModel, ABC):
|
|||||||
"""Call the chain on all inputs in the list."""
|
"""Call the chain on all inputs in the list."""
|
||||||
return [self(inputs) for inputs in input_list]
|
return [self(inputs) for inputs in input_list]
|
||||||
|
|
||||||
def run(self, text: str) -> str:
|
def run(self, *args: str, **kwargs: str) -> str:
|
||||||
"""Run text in, text out (if applicable)."""
|
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||||
if len(self.input_keys) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"`run` not supported when there is not exactly "
|
|
||||||
f"one input key, got {self.input_keys}."
|
|
||||||
)
|
|
||||||
if len(self.output_keys) != 1:
|
if len(self.output_keys) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`run` not supported when there is not exactly "
|
f"`run` not supported when there is not exactly "
|
||||||
f"one output key, got {self.output_keys}."
|
f"one output key. Got {self.output_keys}."
|
||||||
)
|
)
|
||||||
return self({self.input_keys[0]: text})[self.output_keys[0]]
|
|
||||||
|
if args and not kwargs:
|
||||||
|
if len(args) != 1:
|
||||||
|
raise ValueError("`run` supports only one positional argument.")
|
||||||
|
return self(args[0])[self.output_keys[0]]
|
||||||
|
|
||||||
|
if kwargs and not args:
|
||||||
|
return self(kwargs)[self.output_keys[0]]
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"`run` supported with either positional arguments or keyword arguments"
|
||||||
|
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||||
|
)
|
||||||
|
@ -12,6 +12,7 @@ class FakeChain(Chain, BaseModel):
|
|||||||
|
|
||||||
be_correct: bool = True
|
be_correct: bool = True
|
||||||
the_input_keys: List[str] = ["foo"]
|
the_input_keys: List[str] = ["foo"]
|
||||||
|
the_output_keys: List[str] = ["bar"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
@ -21,7 +22,7 @@ class FakeChain(Chain, BaseModel):
|
|||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Output key of bar."""
|
"""Output key of bar."""
|
||||||
return ["bar"]
|
return self.the_output_keys
|
||||||
|
|
||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
if self.be_correct:
|
if self.be_correct:
|
||||||
@ -63,3 +64,45 @@ def test_single_input_error() -> None:
|
|||||||
chain = FakeChain(the_input_keys=["foo", "bar"])
|
chain = FakeChain(the_input_keys=["foo", "bar"])
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chain("bar")
|
chain("bar")
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_single_arg() -> None:
|
||||||
|
"""Test run method with single arg."""
|
||||||
|
chain = FakeChain()
|
||||||
|
output = chain.run("bar")
|
||||||
|
assert output == "baz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_multiple_args_error() -> None:
|
||||||
|
"""Test run method with multiple args errors as expected."""
|
||||||
|
chain = FakeChain()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain.run("bar", "foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_kwargs() -> None:
|
||||||
|
"""Test run method with kwargs."""
|
||||||
|
chain = FakeChain(the_input_keys=["foo", "bar"])
|
||||||
|
output = chain.run(foo="bar", bar="foo")
|
||||||
|
assert output == "baz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_kwargs_error() -> None:
|
||||||
|
"""Test run method with kwargs errors as expected."""
|
||||||
|
chain = FakeChain(the_input_keys=["foo", "bar"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain.run(foo="bar", baz="foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_args_and_kwargs_error() -> None:
|
||||||
|
"""Test run method with args and kwargs."""
|
||||||
|
chain = FakeChain(the_input_keys=["foo", "bar"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain.run("bar", foo="bar")
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_output_keys_error() -> None:
|
||||||
|
"""Test run with multiple output keys errors as expected."""
|
||||||
|
chain = FakeChain(the_output_keys=["foo", "bar"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chain.run("bar")
|
||||||
|
Loading…
Reference in New Issue
Block a user