Compare commits

...

2 Commits

Author SHA1 Message Date
vowelparrot
0dabbc2fe0 Throw an error instead of just a warning 2023-06-28 07:37:54 -07:00
vowelparrot
ec76a110a4 Accept Chain 2023-06-28 07:33:23 -07:00
2 changed files with 76 additions and 24 deletions

View File

@@ -42,6 +42,7 @@ from langchain.schema import (
logger = logging.getLogger(__name__)
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
MODEL_OR_CHAIN_FACTORY_EXTENDED = Union[Callable[[], Chain], BaseLanguageModel, Chain]
class InputFormatError(Exception):
@@ -323,7 +324,7 @@ async def _callbacks_initializer(
async def arun_on_examples(
examples: Iterator[Example],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY_EXTENDED,
*,
concurrency_level: int = 5,
num_repetitions: int = 1,
@@ -339,9 +340,12 @@ async def arun_on_examples(
Args:
examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
llm_or_chain_factory: Language model, Chain, or Chain constructor.
You can pass chains without memory directly. For chains and
agents with memory, it is recommended to pass in a factory
function to construct a new instance of the chain for each example.
This permits independent evaluations on each example without
carrying over state.
concurrency_level: The number of async tasks to run concurrently.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
@@ -357,7 +361,8 @@ async def arun_on_examples(
Returns:
A dictionary mapping example ids to the model outputs.
"""
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
llm_or_chain_factory_ = _wrap_chain(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory_, None)
client_ = client or LangChainPlusClient()
client_.create_project(project_name, mode="eval")
@@ -372,7 +377,7 @@ async def arun_on_examples(
"""Process a single example."""
result = await _arun_llm_or_chain(
example,
llm_or_chain_factory,
llm_or_chain_factory_,
num_repetitions,
tags=tags,
callbacks=callbacks,
@@ -498,9 +503,27 @@ def run_llm_or_chain(
return outputs
def _wrap_chain(
    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY_EXTENDED,
) -> MODEL_OR_CHAIN_FACTORY:
    """Expose a bare Chain instance as a zero-argument factory.

    Args:
        llm_or_chain_factory: A language model, a Chain instance, or a
            callable that constructs a Chain.

    Returns:
        The argument unchanged when it is not a Chain instance. Otherwise a
        zero-argument callable that returns that same Chain instance on
        every call.

    Raises:
        ValueError: If a Chain instance whose ``memory`` is not None is
            passed directly instead of through a factory.
    """
    # Anything that is not a Chain instance is handed back untouched.
    if not isinstance(llm_or_chain_factory, Chain):
        return llm_or_chain_factory

    # The returned factory hands out the same object each call, so a chain
    # carrying memory is rejected rather than silently shared.
    if llm_or_chain_factory.memory is not None:
        raise ValueError(
            f"Attempting to run a chain that uses memory: {llm_or_chain_factory}. "
            "This will lead to issues reproducing results. To fix, pass in"
            " a chain _factory_ to construct it instead:\n"
            "def create_chain():\n"
            "    return MyChain(..., memory=MyMemory())\n"
            "run_on_dataset(..., llm_or_chain_factory=create_chain)\n"
        )

    chain = llm_or_chain_factory
    return lambda: chain
def run_on_examples(
examples: Iterator[Example],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY_EXTENDED,
*,
num_repetitions: int = 1,
project_name: Optional[str] = None,
@@ -515,9 +538,12 @@ def run_on_examples(
Args:
examples: Examples to run the model or chain over.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
llm_or_chain_factory: Language model, Chain, or Chain constructor.
You can pass chains without memory directly. For chains and
agents with memory, it is recommended to pass in a factory
function to construct a new instance of the chain for each example.
This permits independent evaluations on each example without
carrying over state.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
intervals.
@@ -533,7 +559,8 @@ def run_on_examples(
A dictionary mapping example ids to the model outputs.
"""
results: Dict[str, Any] = {}
project_name = _get_project_name(project_name, llm_or_chain_factory, None)
llm_or_chain_factory_ = _wrap_chain(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory_, None)
client_ = client or LangChainPlusClient()
client_.create_project(project_name, mode="eval")
tracer = LangChainTracer(project_name=project_name)
@@ -544,7 +571,7 @@ def run_on_examples(
for i, example in enumerate(examples):
result = run_llm_or_chain(
example,
llm_or_chain_factory,
llm_or_chain_factory_,
num_repetitions,
tags=tags,
callbacks=callbacks,
@@ -586,7 +613,7 @@ def _get_project_name(
async def arun_on_dataset(
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY_EXTENDED,
*,
concurrency_level: int = 5,
num_repetitions: int = 1,
@@ -602,9 +629,12 @@ async def arun_on_dataset(
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
llm_or_chain_factory: Language model, Chain, or Chain constructor.
You can pass chains without memory directly. For chains and
agents with memory, it is recommended to pass in a factory
function to construct a new instance of the chain for each example.
This permits independent evaluations on each example without
carrying over state.
concurrency_level: The number of async tasks to run concurrently.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
@@ -621,12 +651,13 @@ async def arun_on_dataset(
A dictionary containing the run's project name and the resulting model outputs.
"""
client_ = client or LangChainPlusClient()
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
llm_or_chain_factory_ = _wrap_chain(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory_, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))
results = await arun_on_examples(
examples,
llm_or_chain_factory,
llm_or_chain_factory_,
concurrency_level=concurrency_level,
num_repetitions=num_repetitions,
project_name=project_name,
@@ -643,7 +674,7 @@ async def arun_on_dataset(
def run_on_dataset(
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY_EXTENDED,
*,
num_repetitions: int = 1,
project_name: Optional[str] = None,
@@ -658,9 +689,12 @@ def run_on_dataset(
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
llm_or_chain_factory: Language model, Chain, or Chain constructor.
You can pass chains without memory directly. For chains and
agents with memory, it is recommended to pass in a factory
function to construct a new instance of the chain for each example.
This permits independent evaluations on each example without
carrying over state.
concurrency_level: Number of async workers to run in parallel.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
@@ -677,12 +711,13 @@ def run_on_dataset(
A dictionary containing the run's project name and the resulting model outputs.
"""
client_ = client or LangChainPlusClient()
project_name = _get_project_name(project_name, llm_or_chain_factory, dataset_name)
llm_or_chain_factory_ = _wrap_chain(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory_, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id))
results = run_on_examples(
examples,
llm_or_chain_factory,
llm_or_chain_factory_,
num_repetitions=num_repetitions,
project_name=project_name,
verbose=verbose,

View File

@@ -14,9 +14,11 @@ from langchain.client.runner_utils import (
InputFormatError,
_get_messages,
_get_prompts,
_wrap_chain,
arun_on_dataset,
run_llm,
)
from tests.unit_tests.chains.test_base import FakeChain, FakeMemory
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -208,3 +210,18 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
for uuid_ in uuids
}
assert results["results"] == expected
def test_wrap_chain() -> None:
    """Check that ``_wrap_chain`` turns its argument into a chain factory.

    Covers three cases: a chain without memory is wrapped in a callable,
    an existing callable still yields a chain after wrapping, and a chain
    constructed with memory is rejected with ``ValueError``.
    """
    chain = FakeChain()
    result = _wrap_chain(chain)
    assert callable(result)
    assert isinstance(result(), FakeChain)

    # Wrapping something that is already a factory must still give a factory.
    result2 = _wrap_chain(result)
    assert callable(result2)
    assert isinstance(result2(), FakeChain)

    # Build the fixture outside the ``raises`` block so that a ValueError
    # raised during construction cannot satisfy the assertion by accident;
    # only the call under test is allowed to raise.
    chain_with_memory = FakeChain(memory=FakeMemory())
    with pytest.raises(ValueError):
        _wrap_chain(chain_with_memory)