This commit is contained in:
Bagatur
2023-12-26 17:25:06 -05:00
parent 4747cfafee
commit 299851bcac

View File

@@ -59,7 +59,7 @@ class _ScoredAnswer(TypedDict):
answer: str
class _MapRerankOutputType(BaseModel):
class _MapRerankOutput(BaseModel):
top_answer: str = Field(..., description="The highest-scored answer.")
all_answers: List[_ScoredAnswer] = Field(
...,
@@ -136,7 +136,7 @@ def create_map_rerank_documents_chain(
_document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
_output_parser = output_parser or (StrOutputParser() | _default_regex)
def format_inputs(inputs):
def format_inputs(inputs: dict) -> list:
docs = inputs.pop(DOCUMENTS_KEY)
return [
{DOCUMENTS_KEY: format_document(doc, _document_prompt), **inputs}
@@ -148,8 +148,17 @@ def create_map_rerank_documents_chain(
def top_answer(results):
return max(results["all_answers"], key=lambda x: float(x["score"]))["answer"]
return RunnableParallel(all_answers=map_chain) | RunnablePassthrough.assign(
top_answer=top_answer
return (
(
RunnableParallel(all_answers=map_chain).with_config(
run_name="answer_and_score"
)
| RunnablePassthrough.assign(top_answer=top_answer).with_config(
run_name="return_answer"
)
)
.with_config(run_name="map_rerank_documents_chain")
.with_types(output_type=_MapRerankOutput)
)