Add progress bar + runner fixes (#10348)
- Add progress bar to eval runs
- Use thread pool for concurrency
- Update some error messages
- Friendlier project name
- Print out quantiles of the final stats

Closes LS-902
This commit is contained in:
parent 0672533b3e · commit 46e9abdc75
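For orientation, this is roughly how an eval run looks after this change — a minimal sketch based on the docstring examples further down, where "my_dataset" and the model choice are placeholders and a configured LangSmith client is assumed:

from langsmith import Client

from langchain.chat_models import ChatOpenAI
from langchain.smith import RunEvalConfig, run_on_dataset

client = Client()
results = run_on_dataset(
    client,
    "my_dataset",
    lambda: ChatOpenAI(temperature=0),
    evaluation=RunEvalConfig(evaluators=["qa"]),
    verbose=True,  # prints the new progress bar and, at the end, feedback-score quantiles
)
# The auto-generated project name is now human-friendly, e.g. "test-clear-harbor-42".
print(results["project_name"])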
libs/langchain/langchain/callbacks/tracers/evaluation.py
@@ -2,29 +2,20 @@
 from __future__ import annotations

 import logging
-from concurrent.futures import Future, ThreadPoolExecutor, wait
+from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Sequence, Set, Union
 from uuid import UUID

 import langsmith
 from langsmith import schemas as langsmith_schemas

-from langchain.callbacks.manager import tracing_v2_enabled
+from langchain.callbacks import manager
+from langchain.callbacks.tracers import langchain as langchain_tracer
 from langchain.callbacks.tracers.base import BaseTracer
-from langchain.callbacks.tracers.langchain import _get_client
 from langchain.callbacks.tracers.schemas import Run

 logger = logging.getLogger(__name__)

 _TRACERS: List[EvaluatorCallbackHandler] = []


 def wait_for_all_evaluators() -> None:
     """Wait for all tracers to finish."""
     global _TRACERS
     for tracer in _TRACERS:
         tracer.wait_for_futures()


 class EvaluatorCallbackHandler(BaseTracer):
     """A tracer that runs a run evaluator whenever a run is persisted.
@@ -79,17 +70,13 @@ class EvaluatorCallbackHandler(BaseTracer):
         self.example_id = (
             UUID(example_id) if isinstance(example_id, str) else example_id
         )
-        self.client = client or _get_client()
+        self.client = client or langchain_tracer.get_client()
         self.evaluators = evaluators
-        self.executor = ThreadPoolExecutor(
-            max_workers=max(max_workers or len(evaluators), 1)
-        )
-        self.futures: Set[Future] = set()
+        self.max_workers = max_workers or len(evaluators)
         self.skip_unfinished = skip_unfinished
         self.project_name = project_name
         self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
         global _TRACERS
         _TRACERS.append(self)

     def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
         """Evaluate the run in the project.
@@ -105,7 +92,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         try:
             if self.project_name is None:
                 feedback = self.client.evaluate_run(run, evaluator)
-            with tracing_v2_enabled(
+            with manager.tracing_v2_enabled(
                 project_name=self.project_name, tags=["eval"], client=self.client
             ):
                 feedback = self.client.evaluate_run(run, evaluator)
@@ -133,14 +120,15 @@ class EvaluatorCallbackHandler(BaseTracer):
             return
         run_ = run.copy()
         run_.reference_example_id = self.example_id
-        for evaluator in self.evaluators:
-            self.futures.add(
-                self.executor.submit(self._evaluate_in_project, run_, evaluator)
-            )
+        if self.max_workers > 0:
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                list(
+                    executor.map(
+                        self._evaluate_in_project,
+                        [run_ for _ in range(len(self.evaluators))],
+                        self.evaluators,
+                    )
+                )
+        else:
+            for evaluator in self.evaluators:
+                self._evaluate_in_project(run_, evaluator)

     def wait_for_futures(self) -> None:
         """Wait for all futures to complete."""
-        futures = list(self.futures)
-        wait(futures)
-        for future in futures:
-            self.futures.remove(future)
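The new `_persist_run` body is standard `concurrent.futures` usage: `map()` pairs the i-th element of each iterable, so the run is repeated once per evaluator, and draining the lazy generator inside the `with` block means every evaluation finishes before the executor shuts down. A self-contained sketch of the same shape (names here are illustrative, not from the PR):

from concurrent.futures import ThreadPoolExecutor


def evaluate(run_id: str, evaluator: str) -> str:
    # Stand-in for EvaluatorCallbackHandler._evaluate_in_project.
    return f"{evaluator}({run_id})"


evaluators = ["qa", "criteria"]
with ThreadPoolExecutor(max_workers=len(evaluators)) as executor:
    # list() forces the lazy map() generator so all work completes here.
    results = list(executor.map(evaluate, ["run-1" for _ in evaluators], evaluators))
print(results)  # ['qa(run-1)', 'criteria(run-1)']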
libs/langchain/langchain/callbacks/tracers/langchain.py
@@ -42,7 +42,7 @@ def wait_for_all_tracers() -> None:
         tracer.wait_for_futures()


-def _get_client() -> Client:
+def get_client() -> Client:
     """Get the client."""
     global _CLIENT
     if _CLIENT is None:
@@ -83,7 +83,7 @@ class LangChainTracer(BaseTracer):
             _EXECUTORS.append(self.executor)
         else:
             self.executor = None
-        self.client = client or _get_client()
+        self.client = client or get_client()
         self._futures: Set[Future] = set()
         self.tags = tags or []
         global _TRACERS
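The client helper is now public (`get_client` instead of `_get_client`) but keeps the module-level singleton. Reconstructed in full — the diff truncates the body, so the assignment line is an assumption — it amounts to:

from typing import Optional

from langsmith import Client

_CLIENT: Optional[Client] = None


def get_client() -> Client:
    """Get the client."""
    global _CLIENT
    if _CLIENT is None:
        _CLIENT = Client()  # assumed: construct and cache one shared client
    return _CLIENT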
libs/langchain/langchain/smith/evaluation/name_generation.py (new file, 729 lines)
@@ -0,0 +1,729 @@
import random

adjectives = [
    "abandoned",
    "aching",
    "advanced",
    "ample",
    "artistic",
    "back",
    "best",
    "bold",
    "brief",
    "clear",
    "cold",
    "complicated",
    "cooked",
    "crazy",
    "crushing",
    "damp",
    "dear",
    "definite",
    "dependable",
    "diligent",
    "drab",
    "earnest",
    "elderly",
    "enchanted",
    "essential",
    "excellent",
    "extraneous",
    "fixed",
    "flowery",
    "formal",
    "fresh",
    "frosty",
    "giving",
    "glossy",
    "healthy",
    "helpful",
    "impressionable",
    "kind",
    "large",
    "left",
    "long",
    "loyal",
    "mealy",
    "memorable",
    "monthly",
    "new",
    "notable",
    "only",
    "ordinary",
    "passionate",
    "perfect",
    "pertinent",
    "proper",
    "puzzled",
    "reflecting",
    "respectful",
    "roasted",
    "scholarly",
    "shiny",
    "slight",
    "sparkling",
    "spotless",
    "stupendous",
    "sunny",
    "tart",
    "terrific",
    "timely",
    "unique",
    "upbeat",
    "vacant",
    "virtual",
    "warm",
    "weary",
    "whispered",
    "worthwhile",
    "yellow",
]

nouns = [
    "account",
    "acknowledgment",
    "address",
    "advertising",
    "airplane",
    "animal",
    "appointment",
    "arrival",
    "artist",
    "attachment",
    "attitude",
    "availability",
    "backpack",
    "bag",
    "balance",
    "bass",
    "bean",
    "beauty",
    "bibliography",
    "bill",
    "bite",
    "blossom",
    "boat",
    "book",
    "box",
    "boy",
    "bread",
    "bridge",
    "broccoli",
    "building",
    "butter",
    "button",
    "cabbage",
    "cake",
    "camera",
    "camp",
    "candle",
    "candy",
    "canvas",
    "car",
    "card",
    "carrot",
    "cart",
    "case",
    "cat",
    "chain",
    "chair",
    "chalk",
    "chance",
    "change",
    "channel",
    "character",
    "charge",
    "charm",
    "chart",
    "check",
    "cheek",
    "cheese",
    "chef",
    "cherry",
    "chicken",
    "child",
    "church",
    "circle",
    "class",
    "clay",
    "click",
    "clock",
    "cloth",
    "cloud",
    "clove",
    "club",
    "coach",
    "coal",
    "coast",
    "coat",
    "cod",
    "coffee",
    "collar",
    "color",
    "comb",
    "comfort",
    "comic",
    "committee",
    "community",
    "company",
    "comparison",
    "competition",
    "condition",
    "connection",
    "control",
    "cook",
    "copper",
    "copy",
    "corn",
    "cough",
    "country",
    "cover",
    "crate",
    "crayon",
    "cream",
    "creator",
    "crew",
    "crown",
    "current",
    "curtain",
    "curve",
    "cushion",
    "dad",
    "daughter",
    "day",
    "death",
    "debt",
    "decision",
    "deer",
    "degree",
    "design",
    "desire",
    "desk",
    "detail",
    "development",
    "digestion",
    "dime",
    "dinner",
    "direction",
    "dirt",
    "discovery",
    "discussion",
    "disease",
    "disgust",
    "distance",
    "distribution",
    "division",
    "doctor",
    "dog",
    "door",
    "drain",
    "drawer",
    "dress",
    "drink",
    "driving",
    "dust",
    "ear",
    "earth",
    "edge",
    "education",
    "effect",
    "egg",
    "end",
    "energy",
    "engine",
    "error",
    "event",
    "example",
    "exchange",
    "existence",
    "expansion",
    "experience",
    "expert",
    "eye",
    "face",
    "fact",
    "fall",
    "family",
    "farm",
    "father",
    "fear",
    "feeling",
    "field",
    "finger",
    "fire",
    "fish",
    "flag",
    "flight",
    "floor",
    "flower",
    "fold",
    "food",
    "football",
    "force",
    "form",
    "frame",
    "friend",
    "frog",
    "fruit",
    "fuel",
    "furniture",
    "game",
    "garden",
    "gate",
    "girl",
    "glass",
    "glove",
    "goat",
    "gold",
    "government",
    "grade",
    "grain",
    "grass",
    "green",
    "grip",
    "group",
    "growth",
    "guide",
    "guitar",
    "hair",
    "hall",
    "hand",
    "harbor",
    "harmony",
    "hat",
    "head",
    "health",
    "heart",
    "heat",
    "hill",
    "history",
    "hobbies",
    "hole",
    "hope",
    "horn",
    "horse",
    "hospital",
    "hour",
    "house",
    "humor",
    "idea",
    "impulse",
    "income",
    "increase",
    "industry",
    "ink",
    "insect",
    "instrument",
    "insurance",
    "interest",
    "invention",
    "iron",
    "island",
    "jelly",
    "jet",
    "jewel",
    "join",
    "judge",
    "juice",
    "jump",
    "kettle",
    "key",
    "kick",
    "kiss",
    "kitten",
    "knee",
    "knife",
    "knowledge",
    "land",
    "language",
    "laugh",
    "law",
    "lead",
    "learning",
    "leather",
    "leg",
    "lettuce",
    "level",
    "library",
    "lift",
    "light",
    "limit",
    "line",
    "linen",
    "lip",
    "liquid",
    "list",
    "look",
    "loss",
    "love",
    "lunch",
    "machine",
    "man",
    "manager",
    "map",
    "marble",
    "mark",
    "market",
    "mass",
    "match",
    "meal",
    "measure",
    "meat",
    "meeting",
    "memory",
    "metal",
    "middle",
    "milk",
    "mind",
    "mine",
    "minute",
    "mist",
    "mitten",
    "mom",
    "money",
    "monkey",
    "month",
    "moon",
    "morning",
    "mother",
    "motion",
    "mountain",
    "mouth",
    "muscle",
    "music",
    "nail",
    "name",
    "nation",
    "neck",
    "need",
    "news",
    "night",
    "noise",
    "note",
    "number",
    "nut",
    "observation",
    "offer",
    "oil",
    "operation",
    "opinion",
    "orange",
    "order",
    "organization",
    "ornament",
    "oven",
    "page",
    "pail",
    "pain",
    "paint",
    "pan",
    "pancake",
    "paper",
    "parcel",
    "parent",
    "part",
    "passenger",
    "paste",
    "payment",
    "peace",
    "pear",
    "pen",
    "pencil",
    "person",
    "pest",
    "pet",
    "picture",
    "pie",
    "pin",
    "pipe",
    "pizza",
    "place",
    "plane",
    "plant",
    "plastic",
    "plate",
    "play",
    "pleasure",
    "plot",
    "plough",
    "pocket",
    "point",
    "poison",
    "police",
    "pollution",
    "popcorn",
    "porter",
    "position",
    "pot",
    "potato",
    "powder",
    "power",
    "price",
    "print",
    "process",
    "produce",
    "product",
    "profit",
    "property",
    "prose",
    "protest",
    "pull",
    "pump",
    "punishment",
    "purpose",
    "push",
    "quarter",
    "question",
    "quiet",
    "quill",
    "quilt",
    "quince",
    "rabbit",
    "rail",
    "rain",
    "range",
    "rat",
    "rate",
    "ray",
    "reaction",
    "reading",
    "reason",
    "record",
    "regret",
    "relation",
    "religion",
    "representative",
    "request",
    "respect",
    "rest",
    "reward",
    "rhythm",
    "rice",
    "river",
    "road",
    "roll",
    "room",
    "root",
    "rose",
    "route",
    "rub",
    "rule",
    "run",
    "sack",
    "sail",
    "salt",
    "sand",
    "scale",
    "scarecrow",
    "scarf",
    "scene",
    "scent",
    "school",
    "science",
    "scissors",
    "screw",
    "sea",
    "seat",
    "secretary",
    "seed",
    "selection",
    "self",
    "sense",
    "servant",
    "shade",
    "shake",
    "shame",
    "shape",
    "sheep",
    "sheet",
    "shelf",
    "ship",
    "shirt",
    "shock",
    "shoe",
    "shop",
    "show",
    "side",
    "sign",
    "silk",
    "sink",
    "sister",
    "size",
    "sky",
    "slave",
    "sleep",
    "smash",
    "smell",
    "smile",
    "smoke",
    "snail",
    "snake",
    "sneeze",
    "snow",
    "soap",
    "society",
    "sock",
    "soda",
    "sofa",
    "son",
    "song",
    "sort",
    "sound",
    "soup",
    "space",
    "spark",
    "speed",
    "sponge",
    "spoon",
    "spray",
    "spring",
    "spy",
    "square",
    "stamp",
    "star",
    "start",
    "statement",
    "station",
    "steam",
    "steel",
    "stem",
    "step",
    "stew",
    "stick",
    "stitch",
    "stocking",
    "stomach",
    "stone",
    "stop",
    "store",
    "story",
    "stove",
    "stranger",
    "straw",
    "stream",
    "street",
    "stretch",
    "string",
    "structure",
    "substance",
    "sugar",
    "suggestion",
    "suit",
    "summer",
    "sun",
    "support",
    "surprise",
    "sweater",
    "swim",
    "system",
    "table",
    "tail",
    "talk",
    "tank",
    "taste",
    "tax",
    "tea",
    "teaching",
    "team",
    "tendency",
    "test",
    "texture",
    "theory",
    "thing",
    "thought",
    "thread",
    "throat",
    "thumb",
    "thunder",
    "ticket",
    "time",
    "tin",
    "title",
    "toad",
    "toe",
    "tooth",
    "toothpaste",
    "touch",
    "town",
    "toy",
    "trade",
    "train",
    "transport",
    "tray",
    "treatment",
    "tree",
    "trick",
    "trip",
    "trouble",
    "trousers",
    "truck",
    "tub",
    "turkey",
    "turn",
    "twist",
    "umbrella",
    "uncle",
    "underwear",
    "unit",
    "use",
    "vacation",
    "value",
    "van",
    "vase",
    "vegetable",
    "veil",
    "vein",
    "verse",
    "vessel",
    "view",
    "visitor",
    "voice",
    "volcano",
    "walk",
    "wall",
    "war",
    "wash",
    "waste",
    "watch",
    "water",
    "wave",
    "wax",
    "way",
    "wealth",
    "weather",
    "week",
    "weight",
    "wheel",
    "whip",
    "whistle",
    "window",
    "wine",
    "wing",
    "winter",
    "wire",
    "wish",
    "woman",
    "wood",
    "wool",
    "word",
    "work",
    "worm",
    "wound",
    "wrist",
    "writer",
    "yard",
    "yoke",
    "zebra",
    "zinc",
    "zipper",
    "zone",
]


def random_name(prefix: str = "test") -> str:
    """Generate a random name."""
    adjective = random.choice(adjectives)
    noun = random.choice(nouns)
    number = random.randint(1, 100)

    return f"{prefix}-{adjective}-{noun}-{number}"
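The generator composes "{prefix}-{adjective}-{noun}-{number}", so a quick check looks like this (output is random; the names shown are just examples):

from langchain.smith.evaluation.name_generation import random_name

print(random_name())               # e.g. "test-sparkling-sponge-42"
print(random_name(prefix="eval"))  # e.g. "eval-notable-harbor-7"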
libs/langchain/langchain/smith/evaluation/progress.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""A simple progress bar for the console."""
import threading
from typing import Any, Dict, Optional, Sequence
from uuid import UUID

from langchain.callbacks import base as base_callbacks
from langchain.schema.document import Document
from langchain.schema.output import LLMResult


class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
    """A simple progress bar for the console."""

    def __init__(self, total: int, ncols: int = 50, **kwargs: Any):
        """Initialize the progress bar.

        Args:
            total: int, the total number of items to be processed.
            ncols: int, the character width of the progress bar.
        """
        self.total = total
        self.ncols = ncols
        self.counter = 0
        self.lock = threading.Lock()
        self._print_bar()

    def increment(self) -> None:
        """Increment the counter and update the progress bar."""
        with self.lock:
            self.counter += 1
            self._print_bar()

    def _print_bar(self) -> None:
        """Print the progress bar to the console."""
        progress = self.counter / self.total
        arrow = "-" * int(round(progress * self.ncols) - 1) + ">"
        spaces = " " * (self.ncols - len(arrow))
        print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="")

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if parent_run_id is None:
            self.increment()

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if parent_run_id is None:
            self.increment()

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if parent_run_id is None:
            self.increment()

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        if parent_run_id is None:
            self.increment()
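The callback only counts top-level events (parent_run_id is None), so nested chain, LLM, and tool callbacks inside one example do not inflate the bar. A standalone sketch of how it renders (the run ids are arbitrary):

import uuid

from langchain.smith.evaluation.progress import ProgressBarCallback

bar = ProgressBarCallback(total=4)
for _ in range(4):
    # Simulate a top-level chain finishing; parent_run_id=None is what
    # marks it as top-level and increments the counter.
    bar.on_chain_end({}, run_id=uuid.uuid4(), parent_run_id=None)
# Final console state: [------------------------------------------------->] 4/4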
libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -2,21 +2,16 @@

 from __future__ import annotations

-import asyncio
 import functools
 import inspect
-import itertools
 import logging
-import uuid
 import warnings
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
-    Coroutine,
     Dict,
-    Iterator,
     List,
     Optional,
     Sequence,
@@ -24,16 +19,13 @@ from typing import (
     Union,
     cast,
 )
-from urllib.parse import urlparse, urlunparse

 from langsmith import Client, RunEvaluator
 from langsmith.schemas import Dataset, DataType, Example

-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.manager import Callbacks
-from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
-from langchain.callbacks.tracers.langchain import LangChainTracer
+from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
 from langchain.chains.base import Chain
 from langchain.evaluation.loading import load_evaluator
 from langchain.evaluation.schema import EvaluatorType, StringEvaluator
@@ -41,8 +33,11 @@ from langchain.schema import ChatResult, LLMResult
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage, messages_from_dict
 from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
-from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
-from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain
+from langchain.schema.runnable import config as runnable_config
+from langchain.schema.runnable import utils as runnable_utils
+from langchain.smith import evaluation as smith_eval
+from langchain.smith.evaluation import config as smith_eval_config
+from langchain.smith.evaluation import name_generation, progress

 if TYPE_CHECKING:
     import pandas as pd
@@ -69,6 +64,26 @@ class InputFormatError(Exception):
 class TestResult(dict):
     """A dictionary of the results of a single test run."""

+    def get_aggregate_feedback(
+        self, quantiles: Optional[Sequence[float]] = None
+    ) -> pd.DataFrame:
+        """Return quantiles for the feedback scores.
+
+        This method calculates and prints the quantiles for the feedback scores
+        across all feedback keys.
+
+        Returns:
+            A DataFrame containing the quantiles for each feedback key.
+        """
+        df = self.to_dataframe()
+        feedback_cols = [
+            col for col in df.columns if col not in ["input", "output", "reference"]
+        ]
+        _quantiles = df[feedback_cols].quantile(
+            quantiles or [0.25, 0.5, 0.75], numeric_only=True
+        )
+        return _quantiles.transpose()
+
     def to_dataframe(self) -> pd.DataFrame:
         """Convert the results to a dataframe."""
         try:
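A toy illustration of what get_aggregate_feedback computes — feedback keys become rows, requested quantiles become columns (scores here are synthetic):

import pandas as pd

# Two feedback keys scored across three examples, laid out one row per
# example, as to_dataframe() would produce.
df = pd.DataFrame({"qa": [1.0, 0.0, 1.0], "helpfulness": [0.5, 0.75, 1.0]})
print(df.quantile([0.25, 0.5, 0.75], numeric_only=True).transpose())
# (formatting approximate)
#                0.25  0.50   0.75
# qa            0.500  1.00  1.000
# helpfulness   0.625  0.75  0.875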
@@ -83,27 +98,19 @@ class TestResult(dict):
         records = []
         for example_id, result in self["results"].items():
             feedback = result["feedback"]
-            records.append(
-                {**{f.key: f.score for f in feedback}, "output": result["output"]}
-            )
+            r = {
+                **{f.key: f.score for f in feedback},
+                "input": result["input"],
+                "output": result["output"],
+            }
+            if "reference" in result:
+                r["reference"] = result["reference"]
+            records.append(r)
             indices.append(example_id)

         return pd.DataFrame(records, index=indices)


-def _get_eval_project_url(api_url: str, project_id: str) -> str:
-    """Get the project url from the api url."""
-    parsed = urlparse(api_url)
-    hostname = parsed.hostname or ""
-    if "api." in hostname:
-        hostname = hostname.replace("api.", "", 1)
-    if "localhost" in hostname:
-        # Remove the port
-        hostname = "localhost"
-    url = urlunparse(parsed._replace(netloc=hostname))
-    return f"{url}/projects/p/{project_id}?eval=true"
-
-
 def _wrap_in_chain_factory(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     dataset_name: str = "<my_dataset>",
@@ -172,15 +179,6 @@ def _wrap_in_chain_factory(
     return llm_or_chain_factory


-def _first_example(examples: Iterator[Example]) -> Tuple[Example, Iterator[Example]]:
-    """Get the first example while chaining it back and preserving the iterator."""
-    try:
-        example: Example = next(examples)
-    except StopIteration:
-        raise ValueError("No examples provided.")
-    return example, itertools.chain([example], examples)
-
-
 def _get_prompt(inputs: Dict[str, Any]) -> str:
     """Get prompt from inputs.

@@ -277,31 +275,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
     )


-def _get_project_name(
-    project_name: Optional[str],
-    llm_or_chain_factory: MCF,
-) -> str:
-    """
-    Get the project name.
-
-    Args:
-        project_name: The project name if manually specified.
-        llm_or_chain_factory: The Chain or language model constructor.
-
-    Returns:
-        The project name.
-    """
-    if project_name is not None:
-        return project_name
-    if isinstance(llm_or_chain_factory, BaseLanguageModel):
-        model_name = llm_or_chain_factory.__class__.__name__
-    else:
-        model_name = llm_or_chain_factory().__class__.__name__
-    hex = uuid.uuid4().hex
-    return f"{hex}-{model_name}"
-
-
-## Shared Validation Utilities
+## Shared data validation utilities
 def _validate_example_inputs_for_language_model(
     first_example: Example,
     input_mapper: Optional[Callable[[Dict], Any]],
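With _get_project_name gone, a missing project name no longer falls back to "{uuid4().hex}-{model class name}". The new _prepare_run_on_dataset (further down) defaults to name_generation.random_name() instead — this is the "friendlier project name" promised in the commit message.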
@@ -373,22 +347,20 @@ def _validate_example_inputs_for_chain(


 def _validate_example_inputs(
-    examples: Iterator[Example],
+    example: Example,
     llm_or_chain_factory: MCF,
     input_mapper: Optional[Callable[[Dict], Any]],
-) -> Iterator[Example]:
+) -> None:
     """Validate that the example inputs are valid for the model."""
-    first_example, examples = _first_example(examples)
     if isinstance(llm_or_chain_factory, BaseLanguageModel):
-        _validate_example_inputs_for_language_model(first_example, input_mapper)
+        _validate_example_inputs_for_language_model(example, input_mapper)
     else:
         chain = llm_or_chain_factory()
         if isinstance(chain, Chain):
             # Otherwise it's a runnable
-            _validate_example_inputs_for_chain(first_example, chain, input_mapper)
+            _validate_example_inputs_for_chain(example, chain, input_mapper)
         elif isinstance(chain, Runnable):
             logger.debug(f"Skipping input validation for {chain}")
-    return examples


 ## Shared Evaluator Setup Utilities
@@ -396,13 +368,12 @@ def _validate_example_inputs(

 def _setup_evaluation(
     llm_or_chain_factory: MCF,
-    examples: Iterator[Example],
-    evaluation: Optional[RunEvalConfig],
+    examples: List[Example],
+    evaluation: Optional[smith_eval.RunEvalConfig],
     data_type: DataType,
-) -> Tuple[Optional[List[RunEvaluator]], Iterator[Example]]:
+) -> Optional[List[RunEvaluator]]:
     """Configure the evaluators to run on the results of the chain."""
     if evaluation:
-        first_example, examples = _first_example(examples)
         if isinstance(llm_or_chain_factory, BaseLanguageModel):
             run_inputs, run_outputs = None, None
             run_type = "llm"
@@ -422,18 +393,18 @@ def _setup_evaluation(
             evaluation,
             run_type,
             data_type,
-            list(first_example.outputs) if first_example.outputs else None,
+            list(examples[0].outputs) if examples[0].outputs else None,
             run_inputs,
             run_outputs,
         )
     else:
         # TODO: Create a default helpfulness evaluator
         run_evaluators = None
-    return run_evaluators, examples
+    return run_evaluators


 def _determine_input_key(
-    config: RunEvalConfig,
+    config: smith_eval.RunEvalConfig,
     run_inputs: Optional[List[str]],
 ) -> Optional[str]:
     input_key = None
@@ -452,7 +423,7 @@ def _determine_input_key(


 def _determine_prediction_key(
-    config: RunEvalConfig,
+    config: smith_eval.RunEvalConfig,
     run_outputs: Optional[List[str]],
 ) -> Optional[str]:
     prediction_key = None
@@ -473,7 +444,7 @@ def _determine_prediction_key(


 def _determine_reference_key(
-    config: RunEvalConfig,
+    config: smith_eval.RunEvalConfig,
     example_outputs: Optional[List[str]],
 ) -> Optional[str]:
     if config.reference_key:
@@ -491,7 +462,7 @@ def _determine_reference_key(


 def _construct_run_evaluator(
-    eval_config: Union[EvaluatorType, str, EvalConfig],
+    eval_config: Union[EvaluatorType, str, smith_eval_config.EvalConfig],
     eval_llm: Optional[BaseLanguageModel],
     run_type: str,
     data_type: DataType,
@@ -513,11 +484,11 @@ def _construct_run_evaluator(
     if isinstance(evaluator_, StringEvaluator):
         if evaluator_.requires_reference and reference_key is None:
             raise ValueError(
-                f"Must specify reference_key in RunEvalConfig to use"
+                f"Must specify reference_key in smith_eval.RunEvalConfig to use"
                 f" evaluator of type {eval_type_tag} with"
                 f" dataset with multiple output keys: {example_outputs}."
             )
-        run_evaluator = StringRunEvaluatorChain.from_run_and_data_type(
+        run_evaluator = smith_eval.StringRunEvaluatorChain.from_run_and_data_type(
             evaluator_,
             run_type,
             data_type,
@@ -534,7 +505,7 @@ def _construct_run_evaluator(


 def _get_keys(
-    config: RunEvalConfig,
+    config: smith_eval.RunEvalConfig,
     run_inputs: Optional[List[str]],
     run_outputs: Optional[List[str]],
     example_outputs: Optional[List[str]],
@@ -546,7 +517,7 @@ def _get_keys(


 def _load_run_evaluators(
-    config: RunEvalConfig,
+    config: smith_eval.RunEvalConfig,
     run_type: str,
     data_type: DataType,
     example_outputs: Optional[List[str]],
@@ -593,7 +564,7 @@ def _load_run_evaluators(
         run_evaluators.append(custom_evaluator)
     elif isinstance(custom_evaluator, StringEvaluator):
         run_evaluators.append(
-            StringRunEvaluatorChain.from_run_and_data_type(
+            smith_eval.StringRunEvaluatorChain.from_run_and_data_type(
                 custom_evaluator,
                 run_type,
                 data_type,
@@ -694,10 +665,9 @@

 async def _arun_llm_or_chain(
     example: Example,
-    llm_or_chain_factory: MCF,
+    config: RunnableConfig,
     *,
-    tags: Optional[List[str]] = None,
-    callbacks: Optional[List[BaseCallbackHandler]] = None,
+    llm_or_chain_factory: MCF,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str, LLMResult, ChatResult]:
     """Asynchronously run the Chain or language model.
@@ -712,15 +682,6 @@ async def _arun_llm_or_chain(
     Returns:
         A list of outputs.
     """
-    if callbacks:
-        previous_example_ids = [
-            getattr(tracer, "example_id", None) for tracer in callbacks
-        ]
-        for tracer in callbacks:
-            if hasattr(tracer, "example_id"):
-                tracer.example_id = example.id
-    else:
-        previous_example_ids = None
     chain_or_llm = (
         "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
     )
@@ -730,8 +691,8 @@ async def _arun_llm_or_chain(
             output: Any = await _arun_llm(
                 llm_or_chain_factory,
                 example.inputs,
-                tags=tags,
-                callbacks=callbacks,
+                tags=config["tags"],
+                callbacks=config["callbacks"],
                 input_mapper=input_mapper,
             )
         else:
@@ -739,200 +700,21 @@ async def _arun_llm_or_chain(
             output = await _arun_chain(
                 chain,
                 example.inputs,
-                tags=tags,
-                callbacks=callbacks,
+                tags=config["tags"],
+                callbacks=config["callbacks"],
                 input_mapper=input_mapper,
             )
         result = output
     except Exception as e:
-        logger.warning(f"{chain_or_llm} failed for example {example.id}. Error: {e}")
-        result = {"Error": str(e)}
-    if callbacks and previous_example_ids:
-        for example_id, tracer in zip(previous_example_ids, callbacks):
-            if hasattr(tracer, "example_id"):
-                tracer.example_id = example_id
+        logger.warning(
+            f"{chain_or_llm} failed for example {example.id} "
+            f"with inputs {example.inputs}"
+            f"\n{repr(e)}"
+        )
+        result = {"Error": repr(e)}
     return result


-async def _gather_with_concurrency(
-    n: int,
-    initializer: Callable[[], Coroutine[Any, Any, Any]],
-    *async_funcs: Callable[
-        [Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
-    ],
-) -> List[Any]:
-    """Run coroutines with a concurrency limit.
-
-    Args:
-        n: The maximum number of concurrent tasks.
-        initializer: A coroutine that initializes shared resources for the tasks.
-        async_funcs: The async_funcs to be run concurrently.
-
-    Returns:
-        A list of results from the coroutines.
-    """
-    semaphore = asyncio.Semaphore(n)
-    job_state = {"num_processed": 0}
-
-    callback_queue: asyncio.Queue[Sequence[BaseCallbackHandler]] = asyncio.Queue()
-    for _ in range(n):
-        callback_queue.put_nowait(await initializer())
-
-    async def run_coroutine_with_semaphore(
-        async_func: Callable[
-            [Sequence[BaseCallbackHandler], Dict], Coroutine[Any, Any, Any]
-        ]
-    ) -> Any:
-        async with semaphore:
-            callbacks = await callback_queue.get()
-            try:
-                result = await async_func(callbacks, job_state)
-            finally:
-                callback_queue.put_nowait(callbacks)
-            return result
-
-    results = await asyncio.gather(
-        *(run_coroutine_with_semaphore(function) for function in async_funcs)
-    )
-    while callback_queue:
-        try:
-            callbacks = callback_queue.get_nowait()
-        except asyncio.QueueEmpty:
-            break
-        for callback in callbacks:
-            if isinstance(callback, (LangChainTracer, EvaluatorCallbackHandler)):
-                callback.wait_for_futures()
-    return results
-
-
-async def _callbacks_initializer(
-    project_name: Optional[str],
-    client: Client,
-    run_evaluators: Sequence[RunEvaluator],
-    evaluation_handler_collector: List[EvaluatorCallbackHandler],
-) -> List[BaseTracer]:
-    """
-    Initialize a tracer to share across tasks.
-
-    Args:
-        project_name: The project name for the tracer.
-        client: The client to use for the tracer.
-        run_evaluators: The evaluators to run.
-        evaluation_handler_collector: A list to collect the evaluators.
-            Used to wait for the evaluators to finish.
-
-    Returns:
-        The callbacks for this thread.
-    """
-    callbacks: List[BaseTracer] = []
-    if project_name:
-        callbacks.append(
-            LangChainTracer(
-                project_name=project_name, client=client, use_threading=False
-            )
-        )
-    if run_evaluators:
-        callback = EvaluatorCallbackHandler(
-            client=client,
-            evaluators=run_evaluators,
-            # We already have concurrency, don't want to overload the machine
-            max_workers=1,
-        )
-        callbacks.append(callback)
-        evaluation_handler_collector.append(callback)
-    return callbacks
-
-
-async def _arun_on_examples(
-    client: Client,
-    examples: Iterator[Example],
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
-    *,
-    evaluation: Optional[RunEvalConfig] = None,
-    concurrency_level: int = 5,
-    project_name: Optional[str] = None,
-    verbose: bool = False,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
-    data_type: DataType = DataType.kv,
-) -> Dict[str, Any]:
-    """
-    Asynchronously run the chain on examples and store traces
-    to the specified project name.
-
-    Args:
-        client: LangSmith client to use to log feedback and runs.
-        examples: Examples to run the model or chain over.
-        llm_or_chain_factory: Language model or Chain constructor to run
-            over the dataset. The Chain constructor is used to permit
-            independent calls on each example without carrying over state.
-        evaluation: Optional evaluation configuration to use when evaluating
-        concurrency_level: The number of async tasks to run concurrently.
-        project_name: Project name to use when tracing runs.
-            Defaults to {dataset_name}-{chain class name}-{datetime}.
-        verbose: Whether to print progress.
-        tags: Tags to add to each run in the project.
-        input_mapper: function to map to the inputs dictionary from an Example
-            to the format expected by the model to be evaluated. This is useful if
-            your model needs to deserialize more complex schema or if your dataset
-            has inputs with keys that differ from what is expected by your chain
-            or agent.
-        data_type: The dataset's data type. This is used to determine determine
-            how to deserialize the reference data and model compatibility.
-    Returns:
-        A dictionary mapping example ids to the model outputs.
-    """
-    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, wrapped_model)
-    run_evaluators, examples = _setup_evaluation(
-        wrapped_model, examples, evaluation, data_type
-    )
-    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
-    results: Dict[str, dict] = {}
-
-    async def process_example(
-        example: Example, callbacks: List[BaseCallbackHandler], job_state: dict
-    ) -> None:
-        """Process a single example."""
-        result = await _arun_llm_or_chain(
-            example,
-            wrapped_model,
-            tags=tags,
-            callbacks=callbacks,
-            input_mapper=input_mapper,
-        )
-        results[str(example.id)] = {"output": result}
-        job_state["num_processed"] += 1
-        if verbose:
-            print(
-                f"Processed examples: {job_state['num_processed']}",
-                end="\r",
-                flush=True,
-            )
-
-    evaluation_handlers: List[EvaluatorCallbackHandler] = []
-    await _gather_with_concurrency(
-        concurrency_level,
-        functools.partial(
-            _callbacks_initializer,
-            project_name=project_name,
-            client=client,
-            evaluation_handler_collector=evaluation_handlers,
-            run_evaluators=run_evaluators or [],
-        ),
-        *(functools.partial(process_example, e) for e in examples),
-    )
-    all_feedback = {}
-    for handler in evaluation_handlers:
-        handler.wait_for_futures()
-        all_feedback.update(handler.logged_feedback)
-    # join the results and feedback on the example id
-    for example_id, output_dict in results.items():
-        feedback = all_feedback.get(example_id, [])
-        output_dict["feedback"] = feedback
-    return results


 ## Sync Utilities
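The removed semaphore/queue machinery above is replaced by the much smaller runnable_utils.gather_with_concurrency (from langchain.schema.runnable.utils, as imported earlier in this diff). The core pattern it stands in for is just a semaphore-bounded gather; a minimal sketch, not the library's exact code:

import asyncio
from typing import Any, Coroutine, List


async def gather_with_concurrency(n: int, *coros: Coroutine) -> List[Any]:
    # At most n coroutines may hold the semaphore, i.e. run concurrently.
    semaphore = asyncio.Semaphore(n)

    async def run(coro: Coroutine) -> Any:
        async with semaphore:
            return await coro

    return await asyncio.gather(*(run(c) for c in coros))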
@@ -1011,10 +793,9 @@ def _run_chain(

 def _run_llm_or_chain(
     example: Example,
-    llm_or_chain_factory: MCF,
+    config: RunnableConfig,
     *,
-    tags: Optional[List[str]] = None,
-    callbacks: Optional[List[BaseCallbackHandler]] = None,
+    llm_or_chain_factory: MCF,
     input_mapper: Optional[Callable[[Dict], Any]] = None,
 ) -> Union[dict, str, LLMResult, ChatResult]:
     """
@@ -1030,15 +811,6 @@ def _run_llm_or_chain(
         Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
             The outputs of the model or chain.
     """
-    if callbacks:
-        previous_example_ids = [
-            getattr(tracer, "example_id", None) for tracer in callbacks
-        ]
-        for tracer in callbacks:
-            if hasattr(tracer, "example_id"):
-                tracer.example_id = example.id
-    else:
-        previous_example_ids = None
     chain_or_llm = (
         "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
     )
@@ -1048,8 +820,8 @@ def _run_llm_or_chain(
             output: Any = _run_llm(
                 llm_or_chain_factory,
                 example.inputs,
-                callbacks,
-                tags=tags,
+                config["callbacks"],
+                tags=config["tags"],
                 input_mapper=input_mapper,
             )
         else:
@@ -1057,98 +829,22 @@ def _run_llm_or_chain(
             output = _run_chain(
                 chain,
                 example.inputs,
-                callbacks,
-                tags=tags,
+                config["callbacks"],
+                tags=config["tags"],
                 input_mapper=input_mapper,
             )
         result = output
     except Exception as e:
+        error_type = type(e).__name__
         logger.warning(
-            f"{chain_or_llm} failed for example {example.id} with inputs:"
-            f" {example.inputs}.\nError: {e}",
+            f"{chain_or_llm} failed for example {example.id} "
+            f"with inputs {example.inputs}"
+            f"\nError Type: {error_type}, Message: {e}"
         )
-        result = {"Error": str(e)}
-    if callbacks and previous_example_ids:
-        for example_id, tracer in zip(previous_example_ids, callbacks):
-            if hasattr(tracer, "example_id"):
-                tracer.example_id = example_id
+        result = {"Error": repr(e)}
     return result


-def _run_on_examples(
-    client: Client,
-    examples: Iterator[Example],
-    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
-    *,
-    evaluation: Optional[RunEvalConfig] = None,
-    project_name: Optional[str] = None,
-    verbose: bool = False,
-    tags: Optional[List[str]] = None,
-    input_mapper: Optional[Callable[[Dict], Any]] = None,
-    data_type: DataType = DataType.kv,
-) -> Dict[str, Any]:
-    """
-    Run the Chain or language model on examples and store
-    traces to the specified project name.
-
-    Args:
-        client: LangSmith client to use to log feedback and runs.
-        examples: Examples to run the model or chain over.
-        llm_or_chain_factory: Language model or Chain constructor to run
-            over the dataset. The Chain constructor is used to permit
-            independent calls on each example without carrying over state.
-        evaluation: Optional evaluation configuration to use when evaluating
-        project_name: Name of the project to store the traces in.
-            Defaults to {dataset_name}-{chain class name}-{datetime}.
-        verbose: Whether to print progress.
-        tags: Tags to add to each run in the project.
-        input_mapper: A function to map to the inputs dictionary from an Example
-            to the format expected by the model to be evaluated. This is useful if
-            your model needs to deserialize more complex schema or if your dataset
-            has inputs with keys that differ from what is expected by your chain
-            or agent.
-        data_type: The dataset's data type. This is used to determine determine
-            how to deserialize the reference data and model compatibility.
-
-    Returns:
-        A dictionary mapping example ids to the model outputs.
-    """
-    results: Dict[str, dict] = {}
-    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, wrapped_model)
-    tracer = LangChainTracer(
-        project_name=project_name, client=client, use_threading=False
-    )
-    run_evaluators, examples = _setup_evaluation(
-        wrapped_model, examples, evaluation, data_type
-    )
-    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
-    evaluation_handler = EvaluatorCallbackHandler(
-        evaluators=run_evaluators or [],
-        client=client,
-    )
-    callbacks: List[BaseCallbackHandler] = [tracer, evaluation_handler]
-    for i, example in enumerate(examples):
-        result = _run_llm_or_chain(
-            example,
-            wrapped_model,
-            tags=tags,
-            callbacks=callbacks,
-            input_mapper=input_mapper,
-        )
-        if verbose:
-            print(f"{i+1} processed", flush=True, end="\r")
-        results[str(example.id)] = {"output": result}
-    tracer.wait_for_futures()
-    evaluation_handler.wait_for_futures()
-    all_feedback = evaluation_handler.logged_feedback
-    # join the results and feedback on the example id
-    for example_id, output_dict in results.items():
-        feedback = all_feedback.get(example_id, [])
-        output_dict["feedback"] = feedback
-    return results


 ## Public API
@@ -1156,10 +852,9 @@ def _prepare_eval_run(
     client: Client,
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
-    project_name: Optional[str],
-) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
+    project_name: str,
+) -> Tuple[MCF, str, Dataset, List[Example]]:
     wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
-    project_name = _get_project_name(project_name, wrapped_model)
     try:
         project = client.create_project(project_name)
     except ValueError as e:
@@ -1168,21 +863,95 @@ def _prepare_eval_run(
         raise ValueError(
             f"Project {project_name} already exists. Please use a different name."
         )
-    project_url = _get_eval_project_url(client.api_url, project.id)
     print(
-        f"View the evaluation results for project '{project_name}' at:\n{project_url}"
+        f"View the evaluation results for project '{project_name}' at:\n{project.url}"
     )
     dataset = client.read_dataset(dataset_name=dataset_name)
-    examples = client.list_examples(dataset_id=str(dataset.id))
+    examples = list(client.list_examples(dataset_id=dataset.id))
     if not examples:
         raise ValueError(f"Dataset {dataset_name} has no example rows.")
     return wrapped_model, project_name, dataset, examples


+def _prepare_run_on_dataset(
+    client: Client,
+    dataset_name: str,
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    project_name: Optional[str],
+    evaluation: Optional[smith_eval.RunEvalConfig] = None,
+    tags: Optional[List[str]] = None,
+    input_mapper: Optional[Callable[[Dict], Any]] = None,
+    concurrency_level: int = 5,
+) -> Tuple[MCF, str, List[Example], List[RunnableConfig]]:
+    project_name = project_name or name_generation.random_name()
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
+        client, dataset_name, llm_or_chain_factory, project_name
+    )
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    run_evaluators = _setup_evaluation(
+        wrapped_model, examples, evaluation, dataset.data_type
+    )
+    _validate_example_inputs(examples[0], wrapped_model, input_mapper)
+    progress_bar = progress.ProgressBarCallback(len(examples))
+    configs = [
+        RunnableConfig(
+            callbacks=[
+                LangChainTracer(
+                    project_name=project_name,
+                    client=client,
+                    use_threading=False,
+                    example_id=example.id,
+                ),
+                EvaluatorCallbackHandler(
+                    evaluators=run_evaluators or [],
+                    client=client,
+                    max_workers=0,
+                    example_id=example.id,
+                ),
+                progress_bar,
+            ],
+            tags=tags or [],
+            max_concurrency=concurrency_level,
+        )
+        for example in examples
+    ]
+    return wrapped_model, project_name, examples, configs
+
+
+def _collect_test_results(
+    examples: List[Example],
+    batch_results: List[Union[dict, str, LLMResult, ChatResult]],
+    configs: List[RunnableConfig],
+    project_name: str,
+) -> TestResult:
+    wait_for_all_tracers()
+    all_feedback = {}
+    for c in configs:
+        for callback in cast(list, c["callbacks"]):
+            if isinstance(callback, EvaluatorCallbackHandler):
+                all_feedback.update(callback.logged_feedback)
+    results = {}
+    for example, output in zip(examples, batch_results):
+        feedback = all_feedback.get(str(example.id), [])
+        results[str(example.id)] = {
+            "output": output,
+            "input": example.inputs,
+            "feedback": feedback,
+        }
+        if example.outputs:
+            results[str(example.id)]["reference"] = example.outputs
+    return TestResult(
+        project_name=project_name,
+        results=results,
+    )
+
+
 async def arun_on_dataset(
     client: Client,
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
-    evaluation: Optional[RunEvalConfig] = None,
+    evaluation: Optional[smith_eval.RunEvalConfig] = None,
     concurrency_level: int = 5,
     project_name: Optional[str] = None,
     verbose: bool = False,
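Design note grounded in the code above: _prepare_run_on_dataset builds one RunnableConfig per example, each carrying a LangChainTracer and an EvaluatorCallbackHandler bound to that example's id, plus one ProgressBarCallback instance shared across every config so the bar reflects overall completion. max_workers=0 on the evaluator handler selects the synchronous branch of _persist_run shown earlier, so feedback runs inline in whichever worker executed the example instead of on a separate evaluator pool.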
@@ -1227,7 +996,7 @@ async def arun_on_dataset(
            from langsmith import Client
            from langchain.chat_models import ChatOpenAI
            from langchain.chains import LLMChain
-           from langchain.smith import RunEvalConfig, arun_on_dataset
+           from langchain.smith import smith_eval.RunEvalConfig, arun_on_dataset

            # Chains may have memory. Passing in a constructor function lets the
            # evaluation framework avoid cross-contamination between runs.
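One caveat with the blanket RunEvalConfig -> smith_eval.RunEvalConfig rename: it also rewrote this docstring's import line, and "from langchain.smith import smith_eval.RunEvalConfig, arun_on_dataset" is not valid Python — dotted names cannot appear in a from-import list. A working form of the example's imports would be one of:

from langchain.smith import RunEvalConfig, arun_on_dataset
# or, matching the module alias used inside runner_utils:
from langchain.smith import evaluation as smith_eval

The same applies to the run_on_dataset docstring further down.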
@@ -1240,12 +1009,12 @@ async def arun_on_dataset(
                return chain

            # Load off-the-shelf evaluators via config or the EvaluatorType (string or enum)
-           evaluation_config = RunEvalConfig(
+           evaluation_config = smith_eval.RunEvalConfig(
                evaluators=[
                    "qa",  # "Correctness" against a reference answer
                    "embedding_distance",
-                   RunEvalConfig.Criteria("helpfulness"),
-                   RunEvalConfig.Criteria({
+                   smith_eval.RunEvalConfig.Criteria("helpfulness"),
+                   smith_eval.RunEvalConfig.Criteria({
                        "fifth-grader-score": "Do you have to be smarter than a fifth grader to answer this question?"
                    }),
                ]
@@ -1286,7 +1055,7 @@ async def arun_on_dataset(
                    return {"score": prediction == reference}


-           evaluation_config = RunEvalConfig(
+           evaluation_config = smith_eval.RunEvalConfig(
                custom_evaluators = [MyStringEvaluator()],
            )
@@ -1299,51 +1068,43 @@ async def arun_on_dataset(
     """  # noqa: E501
     if kwargs:
         warnings.warn(
-            "The following arguments are deprecated and will "
-            "be removed in a future release: "
+            "The following arguments are deprecated and "
+            "will be removed in a future release: "
             f"{kwargs.keys()}.",
             DeprecationWarning,
         )
-    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
-        client, dataset_name, llm_or_chain_factory, project_name
-    )
-    results = await _arun_on_examples(
+    wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
         client,
-        examples,
-        wrapped_model,
-        concurrency_level=concurrency_level,
-        project_name=project_name,
-        verbose=verbose,
-        tags=tags,
-        evaluation=evaluation,
-        input_mapper=input_mapper,
-        data_type=dataset.data_type,
+        dataset_name,
+        llm_or_chain_factory,
+        project_name,
+        evaluation,
+        tags,
+        input_mapper,
+        concurrency_level,
     )
-    return TestResult(
-        project_name=project_name,
-        results=results,
-    )
-
-
-def _handle_coroutine(coro: Coroutine) -> Any:
-    """
-    Handles a coroutine from a sync context.
-
-    Args:
-        coro (asyncio.coroutine): The coroutine to be handled.
-
-    Returns:
-        any: The result of the executed coroutine.
-    """
-    # Check if there's a running event loop
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:  # No event loop
-        return asyncio.run(coro)
-    if loop.is_running():
-        return loop.run_until_complete(coro)
-    else:
-        return asyncio.run(coro)
+    batch_results = await runnable_utils.gather_with_concurrency(
+        configs[0].get("max_concurrency"),
+        *map(
+            functools.partial(
+                _arun_llm_or_chain,
+                llm_or_chain_factory=wrapped_model,
+                input_mapper=input_mapper,
+            ),
+            examples,
+            configs,
+        ),
+    )
+    results = _collect_test_results(examples, batch_results, configs, project_name)
+    if verbose:
+        try:
+            agg_feedback = results.get_aggregate_feedback()
+            print("\n Eval quantiles:")
+            print(agg_feedback)
+        except Exception as e:
+            logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
+    return results


 def run_on_dataset(
@@ -1351,7 +1112,7 @@ def run_on_dataset(
     dataset_name: str,
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     *,
-    evaluation: Optional[RunEvalConfig] = None,
+    evaluation: Optional[smith_eval.RunEvalConfig] = None,
     concurrency_level: int = 5,
     project_name: Optional[str] = None,
     verbose: bool = False,
@@ -1397,7 +1158,7 @@ def run_on_dataset(
            from langsmith import Client
            from langchain.chat_models import ChatOpenAI
            from langchain.chains import LLMChain
-           from langchain.smith import RunEvalConfig, run_on_dataset
+           from langchain.smith import smith_eval.RunEvalConfig, run_on_dataset

            # Chains may have memory. Passing in a constructor function lets the
            # evaluation framework avoid cross-contamination between runs.
@@ -1410,12 +1171,12 @@ def run_on_dataset(
                return chain

            # Load off-the-shelf evaluators via config or the EvaluatorType (string or enum)
-           evaluation_config = RunEvalConfig(
+           evaluation_config = smith_eval.RunEvalConfig(
                evaluators=[
                    "qa",  # "Correctness" against a reference answer
                    "embedding_distance",
-                   RunEvalConfig.Criteria("helpfulness"),
-                   RunEvalConfig.Criteria({
+                   smith_eval.RunEvalConfig.Criteria("helpfulness"),
+                   smith_eval.RunEvalConfig.Criteria({
                        "fifth-grader-score": "Do you have to be smarter than a fifth grader to answer this question?"
                    }),
                ]
@@ -1456,7 +1217,7 @@ def run_on_dataset(
                    return {"score": prediction == reference}


-           evaluation_config = RunEvalConfig(
+           evaluation_config = smith_eval.RunEvalConfig(
                custom_evaluators = [MyStringEvaluator()],
            )
@@ -1474,37 +1235,35 @@ def run_on_dataset(
             f"{kwargs.keys()}.",
             DeprecationWarning,
         )
-    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
-        client, dataset_name, llm_or_chain_factory, project_name
-    )
-    if concurrency_level in (0, 1):
-        results = _run_on_examples(
-            client,
-            examples,
-            wrapped_model,
-            project_name=project_name,
-            verbose=verbose,
-            tags=tags,
-            evaluation=evaluation,
-            input_mapper=input_mapper,
-            data_type=dataset.data_type,
-        )
-    else:
-        # TODO: Use runnables and the batch method
-        coro = _arun_on_examples(
-            client,
-            examples,
-            wrapped_model,
-            concurrency_level=concurrency_level,
-            project_name=project_name,
-            verbose=verbose,
-            tags=tags,
-            evaluation=evaluation,
-            input_mapper=input_mapper,
-            data_type=dataset.data_type,
-        )
-        results = _handle_coroutine(coro)
-    return TestResult(
-        project_name=project_name,
-        results=results,
-    )
+    wrapped_model, project_name, examples, configs = _prepare_run_on_dataset(
+        client,
+        dataset_name,
+        llm_or_chain_factory,
+        project_name,
+        evaluation,
+        tags,
+        input_mapper,
+        concurrency_level,
+    )
+    with runnable_config.get_executor_for_config(configs[0]) as executor:
+        batch_results = list(
+            executor.map(
+                functools.partial(
+                    _run_llm_or_chain,
+                    llm_or_chain_factory=wrapped_model,
+                    input_mapper=input_mapper,
+                ),
+                examples,
+                configs,
+            )
+        )
+
+    results = _collect_test_results(examples, batch_results, configs, project_name)
+    if verbose:
+        try:
+            agg_feedback = results.get_aggregate_feedback()
+            print("\n Eval quantiles:")
+            print(agg_feedback)
+        except Exception as e:
+            logger.debug(f"Failed to print aggregate feedback: {repr(e)}")
+    return results
libs/langchain/langchain/smith/evaluation/string_run_evaluator.py
@@ -148,13 +148,27 @@ class ChainStringRunMapper(StringRunMapper):
     def map(self, run: Run) -> Dict[str, str]:
         """Maps the Run to a dictionary."""
         if not run.outputs:
-            raise ValueError(f"Run {run.id} has no outputs to evaluate.")
-        if self.input_key is not None and self.input_key not in run.inputs:
-            raise ValueError(f"Run {run.id} does not have input key {self.input_key}.")
-        elif self.prediction_key is not None and self.prediction_key not in run.outputs:
             raise ValueError(
-                f"Run {run.id} does not have prediction key {self.prediction_key}."
+                f"Run with ID {run.id} lacks outputs required for evaluation."
+                " Ensure the Run has valid outputs."
             )
+        if self.input_key is not None and self.input_key not in run.inputs:
+            raise ValueError(
+                f"Run with ID {run.id} is missing the expected input key"
+                f" '{self.input_key}'.\nAvailable input keys in this Run"
+                f" are: {run.inputs.keys()}.\nAdjust the evaluator's"
+                f" input_key or ensure your input data includes key"
+                f" '{self.input_key}'."
+            )
+        elif self.prediction_key is not None and self.prediction_key not in run.outputs:
+            available_keys = ", ".join(run.outputs.keys())
+            raise ValueError(
+                f"Run with ID {run.id} doesn't have the expected prediction key"
+                f" '{self.prediction_key}'. Available prediction keys in this Run are:"
+                f" {available_keys}. Adjust the evaluator's prediction_key or"
+                " ensure the Run object's outputs the expected key."
+            )
+
         else:
             input_ = self._get_key(run.inputs, self.input_key, "input")
             prediction = self._get_key(run.outputs, self.prediction_key, "prediction")
@@ -5,7 +5,6 @@ import pytest
 from langsmith import Client as Client
 from langsmith.schemas import DataType

-from langchain.callbacks.tracers.evaluation import wait_for_all_evaluators
 from langchain.chains.llm import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.evaluation import EvaluatorType
@@ -22,7 +21,6 @@ def _check_all_feedback_passed(_project_name: str, client: Client) -> None:
     # chain or llm passes for the feedback provided.
     runs = list(client.list_runs(project_name=_project_name, execution_order=1))
     assert len(runs) == 4
-    wait_for_all_evaluators()
     feedback = list(client.list_feedback(run_ids=[run.id for run in runs]))
     assert len(feedback) == 8
     assert all([f.score == 1 for f in feedback])
@@ -181,11 +181,15 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         assert "the wrong input" in inputs
         return {"the right input": inputs["the wrong input"]}

-    result = _run_llm_or_chain(example, lambda: mock_chain, input_mapper=input_mapper)
+    result = _run_llm_or_chain(
+        example,
+        {"callbacks": [], "tags": []},
+        llm_or_chain_factory=lambda: mock_chain,
+        input_mapper=input_mapper,
+    )
     assert result == {"output": "2", "the right input": "1"}
     bad_result = _run_llm_or_chain(
-        example,
-        lambda: mock_chain,
+        example, {"callbacks": [], "tags": []}, llm_or_chain_factory=lambda: mock_chain
     )
     assert "Error" in bad_result

@@ -195,7 +199,12 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
         return "the right input"

     mock_llm = FakeLLM(queries={"the right input": "somenumber"})
-    llm_result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
+    llm_result = _run_llm_or_chain(
+        example,
+        {"callbacks": [], "tags": []},
+        llm_or_chain_factory=mock_llm,
+        input_mapper=llm_input_mapper,
+    )
     assert isinstance(llm_result, str)
     assert llm_result == "somenumber"

@@ -324,10 +333,14 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
     )

     expected = {
-        uuid_: {
-            "output": {"result": f"Result for example {uuid.UUID(uuid_)}"},
+        str(example.id): {
+            "output": {
+                "result": f"Result for example {uuid.UUID(str(example.id))}"
+            },
+            "input": {"input": example.inputs["input"]},
+            "reference": {"output": example.outputs["output"]},
             "feedback": [],
         }
-        for uuid_ in uuids
+        for example in examples
     }
     assert results["results"] == expected