mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 00:00:34 +00:00
fmt
This commit is contained in:
@@ -59,7 +59,7 @@ class _ScoredAnswer(TypedDict):
|
||||
answer: str
|
||||
|
||||
|
||||
class _MapRerankOutputType(BaseModel):
|
||||
class _MapRerankOutput(BaseModel):
|
||||
top_answer: str = Field(..., description="The highest-scored answer.")
|
||||
all_answers: List[_ScoredAnswer] = Field(
|
||||
...,
|
||||
@@ -136,7 +136,7 @@ def create_map_rerank_documents_chain(
|
||||
_document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
||||
_output_parser = output_parser or (StrOutputParser() | _default_regex)
|
||||
|
||||
def format_inputs(inputs):
|
||||
def format_inputs(inputs: dict) -> list:
|
||||
docs = inputs.pop(DOCUMENTS_KEY)
|
||||
return [
|
||||
{DOCUMENTS_KEY: format_document(doc, _document_prompt), **inputs}
|
||||
@@ -148,8 +148,17 @@ def create_map_rerank_documents_chain(
|
||||
def top_answer(results):
|
||||
return max(results["all_answers"], key=lambda x: float(x["score"]))["answer"]
|
||||
|
||||
return RunnableParallel(all_answers=map_chain) | RunnablePassthrough.assign(
|
||||
top_answer=top_answer
|
||||
return (
|
||||
(
|
||||
RunnableParallel(all_answers=map_chain).with_config(
|
||||
run_name="answer_and_score"
|
||||
)
|
||||
| RunnablePassthrough.assign(top_answer=top_answer).with_config(
|
||||
run_name="return_answer"
|
||||
)
|
||||
)
|
||||
.with_config(run_name="map_rerank_documents_chain")
|
||||
.with_types(output_type=_MapRerankOutput)
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user