Compare commits

...

1 Commits

Author SHA1 Message Date
William Fu-Hinthorn
33c353b8f3 Update to not use pydantic 2023-08-02 15:08:25 -07:00
4 changed files with 14 additions and 33 deletions

View File

@@ -6,7 +6,7 @@ import warnings
from typing import Any, Dict, List, Optional
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunBase as Run
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from pydantic import BaseModel, Field, root_validator
@@ -96,28 +96,6 @@ class ToolRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list)
# Begin V2 API Schemas
class Run(BaseRunV2):
"""Run schema for the V2 API in the Tracer."""
execution_order: int
child_execution_order: int
child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
@root_validator(pre=True)
def assign_name(cls, values: dict) -> dict:
"""Assign name to the run."""
if values.get("name") is None:
if "name" in values["serialized"]:
values["name"] = values["serialized"]["name"]
elif "id" in values["serialized"]:
values["name"] = values["serialized"]["id"][-1]
return values
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from enum import Enum
import functools
import itertools
import logging
@@ -20,6 +21,7 @@ from typing import (
Union,
)
from urllib.parse import urlparse, urlunparse
import uuid
from langsmith import Client, RunEvaluator
from langsmith.schemas import Dataset, DataType, Example
@@ -233,12 +235,12 @@ def _get_project_name(
"""
if project_name is not None:
return project_name
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if isinstance(llm_or_chain_factory, BaseLanguageModel):
model_name = llm_or_chain_factory.__class__.__name__
else:
model_name = llm_or_chain_factory().__class__.__name__
return f"{current_time}-{model_name}"
hex = uuid.uuid4().hex[:8]
return f"{hex}-{model_name}"
## Shared Validation Utilities
@@ -345,9 +347,10 @@ def _setup_evaluation(
else:
run_type = "chain"
if data_type in (DataType.chat, DataType.llm):
val = data_type.value if isinstance(data_type, Enum) else data_type
raise ValueError(
"Cannot evaluate a chain on dataset with "
f"data_type={data_type.value}. "
f"data_type={data_type}. "
"Please specify a dataset with the default 'kv' data type."
)
chain = llm_or_chain_factory()
@@ -1076,9 +1079,6 @@ def _run_on_examples(
return results
## Public API
def _prepare_eval_run(
client: Client,
dataset_name: str,
@@ -1104,6 +1104,9 @@ def _prepare_eval_run(
return llm_or_chain_factory, project_name, dataset, examples
## Public API
async def arun_on_dataset(
client: Client,
dataset_name: str,

View File

@@ -261,9 +261,9 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
return evaluate_strings_inputs
def _prepare_output(self, output: Dict[str, Any]) -> Dict[str, Any]:
evaluation_result = EvaluationResult(
key=self.name, comment=output.get("reasoning"), **output
)
if "key" not in output:
output["key"] = self.name
evaluation_result = EvaluationResult.from_dict(output)
if RUN_KEY in output:
# TODO: Not currently surfaced. Update
evaluation_result.evaluator_info[RUN_KEY] = output[RUN_KEY]

View File

@@ -120,7 +120,7 @@ cassio = {version = "^0.0.7", optional = true}
rdflib = {version = "^6.3.2", optional = true}
sympy = {version = "^1.12", optional = true}
rapidfuzz = {version = "^3.1.1", optional = true}
langsmith = "~0.0.11"
langsmith = "~0.1.0"
rank-bm25 = {version = "^0.2.2", optional = true}
amadeus = {version = ">=8.1.0", optional = true}
geopandas = {version = "^0.13.1", optional = true}