mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 21:08:59 +00:00
Update Key Check (#8948)
In eval loop. It needn't be done unless you are creating the corresponding evaluators
This commit is contained in:
parent
539672a7fd
commit
90579021f8
@ -502,6 +502,18 @@ def _construct_run_evaluator(
|
|||||||
return run_evaluator
|
return run_evaluator
|
||||||
|
|
||||||
|
|
||||||
|
def _get_keys(
|
||||||
|
config: RunEvalConfig,
|
||||||
|
run_inputs: Optional[List[str]],
|
||||||
|
run_outputs: Optional[List[str]],
|
||||||
|
example_outputs: Optional[List[str]],
|
||||||
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||||
|
input_key = _determine_input_key(config, run_inputs)
|
||||||
|
prediction_key = _determine_prediction_key(config, run_outputs)
|
||||||
|
reference_key = _determine_reference_key(config, example_outputs)
|
||||||
|
return input_key, prediction_key, reference_key
|
||||||
|
|
||||||
|
|
||||||
def _load_run_evaluators(
|
def _load_run_evaluators(
|
||||||
config: RunEvalConfig,
|
config: RunEvalConfig,
|
||||||
run_type: str,
|
run_type: str,
|
||||||
@ -521,9 +533,13 @@ def _load_run_evaluators(
|
|||||||
"""
|
"""
|
||||||
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
|
eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
|
||||||
run_evaluators = []
|
run_evaluators = []
|
||||||
input_key = _determine_input_key(config, run_inputs)
|
input_key, prediction_key, reference_key = None, None, None
|
||||||
prediction_key = _determine_prediction_key(config, run_outputs)
|
if config.evaluators or any(
|
||||||
reference_key = _determine_reference_key(config, example_outputs)
|
[isinstance(e, EvaluatorType) for e in config.evaluators]
|
||||||
|
):
|
||||||
|
input_key, prediction_key, reference_key = _get_keys(
|
||||||
|
config, run_inputs, run_outputs, example_outputs
|
||||||
|
)
|
||||||
for eval_config in config.evaluators:
|
for eval_config in config.evaluators:
|
||||||
run_evaluator = _construct_run_evaluator(
|
run_evaluator = _construct_run_evaluator(
|
||||||
eval_config,
|
eval_config,
|
||||||
@ -1074,15 +1090,15 @@ def _run_on_examples(
|
|||||||
A dictionary mapping example ids to the model outputs.
|
A dictionary mapping example ids to the model outputs.
|
||||||
"""
|
"""
|
||||||
results: Dict[str, Any] = {}
|
results: Dict[str, Any] = {}
|
||||||
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
|
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
|
||||||
project_name = _get_project_name(project_name, llm_or_chain_factory)
|
project_name = _get_project_name(project_name, wrapped_model)
|
||||||
tracer = LangChainTracer(
|
tracer = LangChainTracer(
|
||||||
project_name=project_name, client=client, use_threading=False
|
project_name=project_name, client=client, use_threading=False
|
||||||
)
|
)
|
||||||
run_evaluators, examples = _setup_evaluation(
|
run_evaluators, examples = _setup_evaluation(
|
||||||
llm_or_chain_factory, examples, evaluation, data_type
|
wrapped_model, examples, evaluation, data_type
|
||||||
)
|
)
|
||||||
examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
|
examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
|
||||||
evalution_handler = EvaluatorCallbackHandler(
|
evalution_handler = EvaluatorCallbackHandler(
|
||||||
evaluators=run_evaluators or [],
|
evaluators=run_evaluators or [],
|
||||||
client=client,
|
client=client,
|
||||||
@ -1091,7 +1107,7 @@ def _run_on_examples(
|
|||||||
for i, example in enumerate(examples):
|
for i, example in enumerate(examples):
|
||||||
result = _run_llm_or_chain(
|
result = _run_llm_or_chain(
|
||||||
example,
|
example,
|
||||||
llm_or_chain_factory,
|
wrapped_model,
|
||||||
num_repetitions,
|
num_repetitions,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
@ -1114,8 +1130,8 @@ def _prepare_eval_run(
|
|||||||
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
|
||||||
project_name: Optional[str],
|
project_name: Optional[str],
|
||||||
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
|
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
|
||||||
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
|
||||||
project_name = _get_project_name(project_name, llm_or_chain_factory)
|
project_name = _get_project_name(project_name, wrapped_model)
|
||||||
try:
|
try:
|
||||||
project = client.create_project(project_name)
|
project = client.create_project(project_name)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@ -1130,7 +1146,7 @@ def _prepare_eval_run(
|
|||||||
)
|
)
|
||||||
dataset = client.read_dataset(dataset_name=dataset_name)
|
dataset = client.read_dataset(dataset_name=dataset_name)
|
||||||
examples = client.list_examples(dataset_id=str(dataset.id))
|
examples = client.list_examples(dataset_id=str(dataset.id))
|
||||||
return llm_or_chain_factory, project_name, dataset, examples
|
return wrapped_model, project_name, dataset, examples
|
||||||
|
|
||||||
|
|
||||||
async def arun_on_dataset(
|
async def arun_on_dataset(
|
||||||
@ -1256,13 +1272,13 @@ async def arun_on_dataset(
|
|||||||
evaluation=evaluation_config,
|
evaluation=evaluation_config,
|
||||||
)
|
)
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
|
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||||
client, dataset_name, llm_or_chain_factory, project_name
|
client, dataset_name, llm_or_chain_factory, project_name
|
||||||
)
|
)
|
||||||
results = await _arun_on_examples(
|
results = await _arun_on_examples(
|
||||||
client,
|
client,
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
wrapped_model,
|
||||||
concurrency_level=concurrency_level,
|
concurrency_level=concurrency_level,
|
||||||
num_repetitions=num_repetitions,
|
num_repetitions=num_repetitions,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
@ -1423,14 +1439,14 @@ def run_on_dataset(
|
|||||||
evaluation=evaluation_config,
|
evaluation=evaluation_config,
|
||||||
)
|
)
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
|
wrapped_model, project_name, dataset, examples = _prepare_eval_run(
|
||||||
client, dataset_name, llm_or_chain_factory, project_name
|
client, dataset_name, llm_or_chain_factory, project_name
|
||||||
)
|
)
|
||||||
if concurrency_level in (0, 1):
|
if concurrency_level in (0, 1):
|
||||||
results = _run_on_examples(
|
results = _run_on_examples(
|
||||||
client,
|
client,
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
wrapped_model,
|
||||||
num_repetitions=num_repetitions,
|
num_repetitions=num_repetitions,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
@ -1444,7 +1460,7 @@ def run_on_dataset(
|
|||||||
coro = _arun_on_examples(
|
coro = _arun_on_examples(
|
||||||
client,
|
client,
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
wrapped_model,
|
||||||
concurrency_level=concurrency_level,
|
concurrency_level=concurrency_level,
|
||||||
num_repetitions=num_repetitions,
|
num_repetitions=num_repetitions,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user