This commit is contained in:
CG80499
2023-09-26 17:14:47 +00:00
parent fe775f929e
commit 087a0db0ae
3 changed files with 13 additions and 14 deletions

View File

@@ -9,7 +9,7 @@ from langchain.chains import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import BaseModel
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.evaluation.comparison.llm_as_a_judge.eval_chain import LLMAsAJudgePairwiseEvalChain
from langchain.evaluation.comparison import PairwiseStringEvalChain
from langchain.callbacks.manager import get_openai_callback
class SummaryParser(SimpleJsonOutputParser):
@@ -93,7 +93,7 @@ cod_summarize_chain = LLMChain(llm=llm, prompt=cod_summarization_prompt, output_
base_summarize_chaim = BASE_PROMPT | llm
evaluator = LLMAsAJudgePairwiseEvalChain.from_llm(llm=llm)
evaluator = PairwiseStringEvalChain.from_llm(llm=llm)
def _reverse_verdict(verdict: str) -> str:
return "Win" if verdict == "Loss" else "Loss" if verdict == "Win" else "Tie"

View File

@@ -9,7 +9,7 @@ from langchain.chains import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import BaseModel
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.evaluation.comparison.llm_as_a_judge.eval_chain import LLMAsAJudgePairwiseEvalChain
from langchain.evaluation.comparison import PairwiseStringEvalChain
from langchain.callbacks.manager import get_openai_callback
class SummaryParser(SimpleJsonOutputParser):
@@ -59,12 +59,12 @@ base_summarize_chaim = BASE_PROMPT | llm
ft_summarize_chain = FT_PROMPT | ft_llm
evaluator = LLMAsAJudgePairwiseEvalChain.from_llm(llm=llm)
evaluator = PairwiseStringEvalChain.from_llm(llm=llm)
def _reverse_verdict(verdict: str) -> str:
return "Win" if verdict == "Loss" else "Loss" if verdict == "Win" else "Tie"
def _reverse_verdict(verdict: str | None) -> str | None:
return "B" if verdict == "A" else "A" if verdict == "B" else None
async def evaluate(sample: Sample) -> bool:
async def evaluate(sample: Sample) -> str | None:
base_summary = (await base_summarize_chaim.ainvoke({"article": sample.article})).content
ft_summary = (await ft_summarize_chain.ainvoke({"article": sample.article})).content
reverse = (len(base_summary) + len(ft_summary)) % 2 == 0
@@ -76,10 +76,9 @@ async def evaluate(sample: Sample) -> bool:
print("Base summary:", base_summary)
print("FT summary:", ft_summary)
print("Reverse:", reverse)
print(result)
if reverse:
return _reverse_verdict(result["verdict"])
return result["verdict"]
return _reverse_verdict(result["value"])
return result["value"]
async def main() -> None:
pbar = tqdm(total=len(samples[:40]))
@@ -97,10 +96,10 @@ async def main() -> None:
*[boxed_evaluate(sample) for sample in samples[:40]]
)
results_excluding_ties = [result for result in results if result != "Tie"]
results_excluding_ties = [result for result in results if result != None]
print(
"Win rate:",
sum([result == "Win" for result in results]) / len(results_excluding_ties),
sum([result == "A" for result in results]) / len(results_excluding_ties),
)
print("Number of ties:", len(results) - len(results_excluding_ties))

View File

@@ -9,7 +9,7 @@ from langchain.chains import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import BaseModel
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.evaluation.comparison.llm_as_a_judge.eval_chain import LLMAsAJudgePairwiseEvalChain
from langchain.evaluation.comparison import PairwiseStringEvalChain
from langchain.callbacks.manager import get_openai_callback
class SummaryParser(SimpleJsonOutputParser):
@@ -92,7 +92,7 @@ cod_summarize_chain = LLMChain(llm=llm, prompt=cod_summarization_prompt, output_
ft_summarize_chain = FT_PROMPT | ft_llm
evaluator = LLMAsAJudgePairwiseEvalChain.from_llm(llm=llm)
evaluator = PairwiseStringEvalChain.from_llm(llm=llm)
def _reverse_verdict(verdict: str) -> str:
return "Win" if verdict == "Loss" else "Loss" if verdict == "Win" else "Tie"