cohere: Add citations to agent, flexibility to tool parsing, fix SDK issue (#19965)

**Description:** Citations are the main addition in this PR. We now emit
them from the multihop agent! Additionally the agent is now more
flexible with observations (`Any` is now accepted), and the Cohere SDK
version is bumped to fix an issue with the most recent version of
pydantic v1 (1.10.15)
This commit is contained in:
harry-cohere 2024-04-04 15:02:30 +01:00 committed by GitHub
parent 605c3f23e1
commit e103492eb8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 482 additions and 64 deletions

View File

@ -1,11 +1,13 @@
from langchain_cohere.chat_models import ChatCohere from langchain_cohere.chat_models import ChatCohere
from langchain_cohere.cohere_agent import create_cohere_tools_agent from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.common import CohereCitation
from langchain_cohere.embeddings import CohereEmbeddings from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
from langchain_cohere.rerank import CohereRerank from langchain_cohere.rerank import CohereRerank
__all__ = [ __all__ = [
"CohereCitation",
"ChatCohere", "ChatCohere",
"CohereEmbeddings", "CohereEmbeddings",
"CohereRagRetriever", "CohereRagRetriever",

View File

@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Any, List, Mapping
@dataclass
class CohereCitation:
    """
    Cohere has fine-grained citations that specify the exact part of text.
    More info at https://docs.cohere.com/docs/documents-and-citations

    Attributes:
        start: Zero-based index of the first character of the cited span. For a
            generation of 'Hello, world!' with a citation on 'world', start is 7
            because the span begins at 'w'.
        end: Zero-based index one past the last character of the cited span, so
            that ``text == generation[start:end]``. For 'Hello, world!' with a
            citation on 'world', end is 12 (the citation ends after 'd' at
            index 11).
        text: The cited text itself. For 'Hello, world!' with a citation on
            'world', text is 'world'.
        documents: The contents of the documents that were cited. When used
            with agents these will be the contents of relevant agent outputs.
    """

    start: int
    end: int
    text: str
    documents: List[Mapping[str, Any]]

View File

@ -5,17 +5,27 @@
This agent uses a multi hop prompt by Cohere, which is experimental and subject This agent uses a multi hop prompt by Cohere, which is experimental and subject
to change. The latest prompt can be used by upgrading the langchain-cohere package. to change. The latest prompt can be used by upgrading the langchain-cohere package.
""" """
from typing import Sequence from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableParallel,
RunnablePassthrough,
)
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_cohere.react_multi_hop.parsing import ( from langchain_cohere.react_multi_hop.parsing import (
GROUNDED_ANSWER_KEY,
OUTPUT_KEY,
CohereToolsReactAgentOutputParser, CohereToolsReactAgentOutputParser,
parse_citations,
) )
from langchain_cohere.react_multi_hop.prompt import ( from langchain_cohere.react_multi_hop.prompt import (
convert_to_documents,
multi_hop_prompt, multi_hop_prompt,
) )
@ -36,8 +46,14 @@ def create_cohere_react_agent(
Returns: Returns:
A Runnable sequence representing an agent. It takes as input all the same input A Runnable sequence representing an agent. It takes as input all the same input
variables as the prompt passed in does and returns an AgentAction or variables as the prompt passed in does and returns a List[AgentAction] or a
AgentFinish. single AgentFinish.
The AgentFinish will have two fields:
* output: str - The output string generated by the model
* citations: List[CohereCitation] - A list of citations that refer to the
output and observations made by the agent. If there are no citations this
list will be empty.
Example: Example:
.. code-block:: python .. code-block:: python
@ -61,14 +77,61 @@ def create_cohere_react_agent(
"input": "In what year was the company that was founded as Sound of Music added to the S&P 500?", "input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
}) })
""" # noqa: E501 """ # noqa: E501
# Creates a prompt, invokes the model, and produces a
# "Union[List[AgentAction], AgentFinish]"
generate_agent_steps = (
multi_hop_prompt(tools=tools, prompt=prompt)
| llm.bind(stop=["\nObservation:"], raw_prompting=True)
| CohereToolsReactAgentOutputParser()
)
agent = ( agent = (
RunnablePassthrough.assign( RunnablePassthrough.assign(
# agent_scratchpad isn't used in this chain, but added here for # agent_scratchpad isn't used in this chain, but added here for
# interoperability with other chains that may require it. # interoperability with other chains that may require it.
agent_scratchpad=lambda _: [], agent_scratchpad=lambda _: [],
) )
| multi_hop_prompt(tools=tools, prompt=prompt) | RunnableParallel(
| llm.bind(stop=["\nObservation:"], raw_prompting=True) chain_input=RunnablePassthrough(), agent_steps=generate_agent_steps
| CohereToolsReactAgentOutputParser() )
| _AddCitations()
) )
return agent return agent
class _AddCitations(Runnable):
    """
    Adds a list of citations to the output of the Cohere multi hop chain when the
    last step is an AgentFinish. Citations are generated from the observations (made
    in previous agent steps) and the grounded answer (made in the last step).

    Expects a dict input with:
      * "agent_steps": the parser output, either a List[AgentAction] or an
        AgentFinish.
      * "chain_input": the original chain input, whose "intermediate_steps"
        (AgentAction, observation) pairs supply the documents to cite.
    """

    def invoke(
        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
    ) -> Union[List[AgentAction], AgentFinish]:
        agent_steps = input.get("agent_steps", [])
        if not agent_steps:
            # The input wasn't as expected.
            return []

        if not isinstance(agent_steps, AgentFinish):
            # We're not on the AgentFinish step.
            return agent_steps
        agent_finish = agent_steps

        # Build a list of documents from the intermediate_steps used in this chain.
        intermediate_steps = input.get("chain_input", {}).get("intermediate_steps", [])
        documents: List[Mapping] = []
        for _, observation in intermediate_steps:
            documents.extend(convert_to_documents(observation))

        # Build a list of citations, if any, from the documents + grounded answer.
        # Note: pop() deliberately removes the grounded answer so it doesn't leak
        # into the final return values; the markup-free output replaces OUTPUT_KEY
        # on the same AgentFinish (mutated in place, then returned).
        grounded_answer = agent_finish.return_values.pop(GROUNDED_ANSWER_KEY, "")
        output, citations = parse_citations(
            grounded_answer=grounded_answer, documents=documents
        )
        agent_finish.return_values[OUTPUT_KEY] = output
        agent_finish.return_values["citations"] = citations

        return agent_finish

View File

@ -1,12 +1,17 @@
import json import json
import logging import logging
import re import re
from typing import Dict, List, Tuple, Union from typing import Any, Dict, List, Mapping, Tuple, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers import BaseOutputParser
from langchain_cohere import CohereCitation
OUTPUT_KEY = "output"
GROUNDED_ANSWER_KEY = "grounded_answer"
class CohereToolsReactAgentOutputParser( class CohereToolsReactAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]] BaseOutputParser[Union[List[AgentAction], AgentFinish]]
@ -23,7 +28,13 @@ class CohereToolsReactAgentOutputParser(
"cited_docs": "Cited Documents:", "cited_docs": "Cited Documents:",
} }
parsed_answer = parse_answer_with_prefixes(text, prefix_map) parsed_answer = parse_answer_with_prefixes(text, prefix_map)
return AgentFinish({"output": parsed_answer["answer"]}, text) return AgentFinish(
return_values={
OUTPUT_KEY: parsed_answer["answer"],
GROUNDED_ANSWER_KEY: parsed_answer["grounded_answer"],
},
log=text,
)
elif any([x in text for x in ["Plan: ", "Reflection: ", "Action: "]]): elif any([x in text for x in ["Plan: ", "Reflection: ", "Action: "]]):
completion, plan, actions = parse_actions(text) completion, plan, actions = parse_actions(text)
agent_actions: List[AgentAction] = [] agent_actions: List[AgentAction] = []
@ -149,3 +160,144 @@ def parse_actions(generation: str) -> Tuple[str, str, List[Dict]]:
parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:") parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:")
return generation, plan, parsed_actions return generation, plan, parsed_actions
def parse_citations(
    grounded_answer: str, documents: List[Mapping]
) -> Tuple[str, List[CohereCitation]]:
    """
    Parses a grounded_generation (from parse_actions) and documents (from
    convert_to_documents) into a (generation, CohereCitation list) tuple.

    Args:
        grounded_answer: The answer text containing <co: i>...</co: i> markup.
        documents: The documents, in the order referenced by the markup indexes.

    Returns:
        A tuple of the answer with all markup removed and the citations parsed
        from that markup. Citation indexes refer to the markup-free answer,
        with end exclusive (text == answer[start:end]).
    """
    no_markup_answer, parsed_answer = _parse_answer_spans(grounded_answer)
    citations: List[CohereCitation] = []

    start = 0
    for answer in parsed_answer:
        text = answer.get("text", "")
        document_indexes = answer.get("cited_docs")
        if not document_indexes:
            # There were no citations for this piece of text.
            start += len(text)
            continue
        end = start + len(text)

        # Look up the cited documents by index, de-duplicating while keeping
        # first-seen order (dict.fromkeys preserves insertion order, unlike
        # set(), whose iteration order is not guaranteed).
        cited_documents: List[Mapping] = []
        for index in dict.fromkeys(document_indexes):
            if index >= len(documents):
                # The document index doesn't exist; skip rather than fail.
                continue
            cited_documents.append(documents[index])

        citations.append(
            CohereCitation(
                start=start,
                end=end,
                text=text,
                documents=cited_documents,
            )
        )
        start = end

    return no_markup_answer, citations
def _strip_spans(answer: str) -> str:
"""removes any <co> tags from a string, including trailing partial tags
input: "hi my <co>name</co> is <co: 1> patrick</co:3> and <co"
output: "hi my name is patrick and"
Args:
answer (str): string
Returns:
str: same string with co tags removed
"""
answer = re.sub(r"<co(.*?)>|</co(.*?)>", "", answer)
idx = answer.find("<co")
if idx > -1:
answer = answer[:idx]
idx = answer.find("</")
if idx > -1:
answer = answer[:idx]
return answer
def _parse_answer_spans(grounded_answer: str) -> Tuple[str, List[Dict[str, Any]]]:
    """
    Splits a grounded answer containing <co: i>...</co: i> citation markup into
    ordered spans of text.

    Args:
        grounded_answer: The raw answer, possibly containing citation markup.

    Returns:
        A tuple of:
          * the answer with all markup stripped, and
          * the spans, in order. Each span is a dict with a "text" key; spans
            that were cited also carry a "cited_docs" key listing the integer
            document indexes cited for that span.
    """
    no_markup_answer = _strip_spans(grounded_answer)
    current_idx = 0
    parsed_answer: List[Dict[str, Union[str, List[int]]]] = []
    last_entry_is_open_cite = False
    parsed_current_cite_document_idxs: List[int] = []
    while current_idx < len(grounded_answer):
        current_cite = re.search(r"<co: (.*?)>", grounded_answer[current_idx:])
        if not current_cite:
            break
        # Any uncited text before the opening tag becomes its own span.
        parsed_answer.append(
            {
                "text": grounded_answer[
                    current_idx : current_idx + current_cite.start()
                ]
            }
        )
        # Parse the cited document ids, dropping anything that isn't an int.
        parsed_current_cite_document_idxs = []
        for cited_idx in current_cite.group(1).split(","):
            if cited_idx.isdigit():
                parsed_current_cite_document_idxs.append(int(cited_idx.strip()))
        current_idx += current_cite.end()
        # Find the matching closing tag. The ids are escaped before being
        # interpolated into the pattern in case they contain regex
        # metacharacters (they should only be digits and commas).
        current_cite_close = re.search(
            r"</co: " + re.escape(current_cite.group(1)) + ">",
            grounded_answer[current_idx:],
        )
        if not current_cite_close:
            # The tag was never closed; the remainder is handled below.
            last_entry_is_open_cite = True
            break
        cited_text = grounded_answer[
            current_idx : current_idx + current_cite_close.start()
        ]
        if parsed_current_cite_document_idxs:
            parsed_answer.append(
                {"text": cited_text, "cited_docs": parsed_current_cite_document_idxs}
            )
        else:
            # There might have been issues parsing the ids, so we need to check
            # that they are actually ints and available; otherwise keep the
            # text as an uncited span.
            parsed_answer.append({"text": cited_text})
        current_idx += current_cite_close.end()
    # Don't forget about the text after the last tag (stripping any partial
    # markup left behind).
    if last_entry_is_open_cite:
        parsed_answer.append(
            {
                "text": _strip_spans(grounded_answer[current_idx:]),
                "cited_docs": parsed_current_cite_document_idxs,
            }
        )
    else:
        parsed_answer.append({"text": _strip_spans(grounded_answer[current_idx:])})
    return no_markup_answer, parsed_answer

View File

@ -108,36 +108,12 @@ def render_observations(
index: int, index: int,
) -> Tuple[BaseMessage, int]: ) -> Tuple[BaseMessage, int]:
"""Renders the 'output' part of an Agent's intermediate step into prompt content.""" """Renders the 'output' part of an Agent's intermediate step into prompt content."""
if ( documents = convert_to_documents(observations)
not isinstance(observations, list)
and not isinstance(observations, str)
and not isinstance(observations, Mapping)
):
raise ValueError("observation must be a list, a Mapping, or a string")
rendered_documents = [] rendered_documents: List[str] = []
document_prompt = """Document: {index} document_prompt = """Document: {index}
{fields}""" {fields}"""
for doc in documents:
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}] # type: ignore
if isinstance(observations, Mapping):
# single items are transformed into a list to simplify the rest of the code.
observations = [observations]
if isinstance(observations, list):
for doc in observations:
if isinstance(doc, str):
# strings are turned into a key/value pair.
doc = {"output": doc}
if not isinstance(doc, Mapping):
raise ValueError(
"all observation list items must be a Mapping or a string"
)
# Render document fields into Key: value strings. # Render document fields into Key: value strings.
fields: List[str] = [] fields: List[str] = []
for k, v in doc.items(): for k, v in doc.items():
@ -161,6 +137,30 @@ def render_observations(
return SystemMessage(content=prompt_content), index return SystemMessage(content=prompt_content), index
def convert_to_documents(
    observations: Any,
) -> List[Mapping]:
    """Converts observations into a 'document' dict"""
    # Normalize the input to a sequence of candidate documents first.
    if isinstance(observations, str):
        # strings are turned into a key/value pair and a key of 'output' is added.
        candidates: Any = [{"output": observations}]
    elif isinstance(observations, Mapping):
        # single mappings are transformed into a list to simplify the rest of the code.
        candidates = [observations]
    elif isinstance(observations, Sequence):
        candidates = observations
    else:
        # all other types are turned into a key/value pair within a list
        candidates = [{"output": observations}]
    # Items that aren't already a Mapping are wrapped under an 'output' key.
    return [
        item if isinstance(item, Mapping) else {"output": item} for item in candidates
    ]
def render_intermediate_steps( def render_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, Any]], intermediate_steps: List[Tuple[AgentAction, Any]],
) -> str: ) -> str:

View File

@ -305,13 +305,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
[[package]] [[package]]
name = "cohere" name = "cohere"
version = "5.1.7" version = "5.1.8"
description = "" description = ""
optional = false optional = false
python-versions = "<4.0,>=3.8" python-versions = "<4.0,>=3.8"
files = [ files = [
{file = "cohere-5.1.7-py3-none-any.whl", hash = "sha256:66e149425ba10d9d6ed2980ad869afae2ed79b1f4c375f215ff4953f389cf5f9"}, {file = "cohere-5.1.8-py3-none-any.whl", hash = "sha256:420ebd0fe8fb34c69adfd6081d75cd3954f498f27dff44e0afa539958e9179ed"},
{file = "cohere-5.1.7.tar.gz", hash = "sha256:5b5ba38e614313d96f4eb362046a3470305e57119e39538afa3220a27614ba15"}, {file = "cohere-5.1.8.tar.gz", hash = "sha256:2ce7e8541c834d5c01991ededf1d1535f76fef48515fb06dc00f284b62245b9c"},
] ]
[package.dependencies] [package.dependencies]
@ -1035,7 +1035,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "orjson-3.10.0-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:47af5d4b850a2d1328660661f0881b67fdbe712aea905dadd413bdea6f792c33"},
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c90681333619d78360d13840c7235fdaf01b2b129cb3a4f1647783b1971542b6"}, {file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c90681333619d78360d13840c7235fdaf01b2b129cb3a4f1647783b1971542b6"},
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:400c5b7c4222cb27b5059adf1fb12302eebcabf1978f33d0824aa5277ca899bd"}, {file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:400c5b7c4222cb27b5059adf1fb12302eebcabf1978f33d0824aa5277ca899bd"},
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dcb32e949eae80fb335e63b90e5808b4b0f64e31476b3777707416b41682db5"}, {file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dcb32e949eae80fb335e63b90e5808b4b0f64e31476b3777707416b41682db5"},
@ -1063,9 +1062,6 @@ files = [
{file = "orjson-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:237ba922aef472761acd697eef77fef4831ab769a42e83c04ac91e9f9e08fa0e"}, {file = "orjson-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:237ba922aef472761acd697eef77fef4831ab769a42e83c04ac91e9f9e08fa0e"},
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98c1bfc6a9bec52bc8f0ab9b86cc0874b0299fccef3562b793c1576cf3abb570"}, {file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98c1bfc6a9bec52bc8f0ab9b86cc0874b0299fccef3562b793c1576cf3abb570"},
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30d795a24be16c03dca0c35ca8f9c8eaaa51e3342f2c162d327bd0225118794a"}, {file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30d795a24be16c03dca0c35ca8f9c8eaaa51e3342f2c162d327bd0225118794a"},
{file = "orjson-3.10.0-cp312-none-win32.whl", hash = "sha256:6a3f53dc650bc860eb26ec293dfb489b2f6ae1cbfc409a127b01229980e372f7"},
{file = "orjson-3.10.0-cp312-none-win_amd64.whl", hash = "sha256:983db1f87c371dc6ffc52931eb75f9fe17dc621273e43ce67bee407d3e5476e9"},
{file = "orjson-3.10.0-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9a667769a96a72ca67237224a36faf57db0c82ab07d09c3aafc6f956196cfa1b"},
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade1e21dfde1d37feee8cf6464c20a2f41fa46c8bcd5251e761903e46102dc6b"}, {file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade1e21dfde1d37feee8cf6464c20a2f41fa46c8bcd5251e761903e46102dc6b"},
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23c12bb4ced1c3308eff7ba5c63ef8f0edb3e4c43c026440247dd6c1c61cea4b"}, {file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23c12bb4ced1c3308eff7ba5c63ef8f0edb3e4c43c026440247dd6c1c61cea4b"},
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2d014cf8d4dc9f03fc9f870de191a49a03b1bcda51f2a957943fb9fafe55aac"}, {file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2d014cf8d4dc9f03fc9f870de191a49a03b1bcda51f2a957943fb9fafe55aac"},
@ -1075,7 +1071,6 @@ files = [
{file = "orjson-3.10.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:13b5d3c795b09a466ec9fcf0bd3ad7b85467d91a60113885df7b8d639a9d374b"}, {file = "orjson-3.10.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:13b5d3c795b09a466ec9fcf0bd3ad7b85467d91a60113885df7b8d639a9d374b"},
{file = "orjson-3.10.0-cp38-none-win32.whl", hash = "sha256:5d42768db6f2ce0162544845facb7c081e9364a5eb6d2ef06cd17f6050b048d8"}, {file = "orjson-3.10.0-cp38-none-win32.whl", hash = "sha256:5d42768db6f2ce0162544845facb7c081e9364a5eb6d2ef06cd17f6050b048d8"},
{file = "orjson-3.10.0-cp38-none-win_amd64.whl", hash = "sha256:33e6655a2542195d6fd9f850b428926559dee382f7a862dae92ca97fea03a5ad"}, {file = "orjson-3.10.0-cp38-none-win_amd64.whl", hash = "sha256:33e6655a2542195d6fd9f850b428926559dee382f7a862dae92ca97fea03a5ad"},
{file = "orjson-3.10.0-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4050920e831a49d8782a1720d3ca2f1c49b150953667eed6e5d63a62e80f46a2"},
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1897aa25a944cec774ce4a0e1c8e98fb50523e97366c637b7d0cddabc42e6643"}, {file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1897aa25a944cec774ce4a0e1c8e98fb50523e97366c637b7d0cddabc42e6643"},
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bf565a69e0082ea348c5657401acec3cbbb31564d89afebaee884614fba36b4"}, {file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bf565a69e0082ea348c5657401acec3cbbb31564d89afebaee884614fba36b4"},
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b6ebc17cfbbf741f5c1a888d1854354536f63d84bee537c9a7c0335791bb9009"}, {file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b6ebc17cfbbf741f5c1a888d1854354536f63d84bee537c9a7c0335791bb9009"},
@ -1335,7 +1330,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1769,4 +1763,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "7546180410ed197e1c2aa9830e32e3a40ebcd930a86a9e3398cd8fe6123b6888" content-hash = "00abb29a38cdcc616e802bfa33a08db9e04faa5565ca2fcbcc0fcacc10c02ba7"

View File

@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
langchain-core = "^0.1.32" langchain-core = "^0.1.32"
cohere = "^5.1.4" cohere = ">=5.1.8,<5.2"
[tool.poetry.group.test] [tool.poetry.group.test]
optional = true optional = true

View File

@ -73,3 +73,4 @@ def test_invoke_multihop_agent() -> None:
assert "output" in actual assert "output" in actual
assert "best buy" in actual["output"].lower() assert "best buy" in actual["output"].lower()
assert "citations" in actual

View File

@ -0,0 +1,72 @@
from typing import Any, Dict
from unittest import mock
import pytest
from langchain_core.agents import AgentAction, AgentFinish
from langchain_cohere import CohereCitation
from langchain_cohere.react_multi_hop.agent import _AddCitations
CITATIONS = [CohereCitation(start=1, end=2, text="foo", documents=[{"bar": "baz"}])]
GENERATION = "mocked generation"
# Each case supplies the dict passed to _AddCitations.invoke() and the expected
# result: an empty list for malformed input, the agent actions passed through
# untouched, or an AgentFinish enriched with citations.
@pytest.mark.parametrize(
    "invoke_with,expected",
    [
        pytest.param({}, [], id="no agent_steps or chain_input"),
        pytest.param(
            {
                "chain_input": {"intermediate_steps": []},
                "agent_steps": [
                    AgentAction(
                        tool="tool_name", tool_input="tool_input", log="tool_log"
                    )
                ],
            },
            [AgentAction(tool="tool_name", tool_input="tool_input", log="tool_log")],
            id="not an AgentFinish",
        ),
        pytest.param(
            {
                "chain_input": {
                    "intermediate_steps": [
                        (
                            AgentAction(
                                tool="tool_name",
                                tool_input="tool_input",
                                log="tool_log",
                            ),
                            {"tool_output": "output"},
                        )
                    ]
                },
                "agent_steps": AgentFinish(
                    return_values={"output": "output1", "grounded_answer": GENERATION},
                    log="",
                ),
            },
            AgentFinish(
                return_values={"output": GENERATION, "citations": CITATIONS}, log=""
            ),
            id="AgentFinish",
        ),
    ],
)
# parse_citations is mocked so this test only exercises _AddCitations' routing
# and bookkeeping, not the citation parsing itself.
@mock.patch(
    "langchain_cohere.react_multi_hop.agent.parse_citations",
    autospec=True,
    return_value=(GENERATION, CITATIONS),
)
def test_add_citations(
    parse_citations_mock: Any, invoke_with: Dict[str, Any], expected: Any
) -> None:
    chain = _AddCitations()
    actual = chain.invoke(invoke_with)

    assert expected == actual

    if isinstance(expected, AgentFinish):
        # In the AgentFinish case the observation from intermediate_steps must
        # have been converted to documents and handed to parse_citations.
        parse_citations_mock.assert_called_once_with(
            grounded_answer=GENERATION, documents=[{"tool_output": "output"}]
        )

View File

@ -16,7 +16,8 @@ from tests.unit_tests.react_multi_hop import ExpectationType, read_expectation_f
"answer_sound_of_music", "answer_sound_of_music",
AgentFinish( AgentFinish(
return_values={ return_values={
"output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999." # noqa: E501 "output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.", # noqa: E501
"grounded_answer": "<co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
}, },
log="Relevant Documents: 0,2,3\nCited Documents: 0,2\nAnswer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.\nGrounded answer: <co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501 log="Relevant Documents: 0,2,3\nCited Documents: 0,2\nAnswer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.\nGrounded answer: <co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
), ),

View File

@ -0,0 +1,86 @@
from typing import List, Mapping
import pytest
from langchain_cohere import CohereCitation
from langchain_cohere.react_multi_hop.parsing import parse_citations
DOCUMENTS = [{"foo": "bar"}, {"baz": "foobar"}]
# Each case supplies: a grounded answer with <co: i>...</co: i> markup, the
# documents available for citation, the expected markup-free generation, and
# the expected parsed citations.
@pytest.mark.parametrize(
    "text,documents,expected_generation,expected_citations",
    [
        pytest.param(
            "no citations",
            DOCUMENTS,
            "no citations",
            [],
            id="no citations",
        ),
        pytest.param(
            "with <co: 0>one citation</co: 0>.",
            DOCUMENTS,
            "with one citation.",
            [
                CohereCitation(
                    start=5, end=17, text="one citation", documents=[DOCUMENTS[0]]
                )
            ],
            id="one citation (normal)",
        ),
        pytest.param(
            "with <co: 0,1>two documents</co: 0,1>.",
            DOCUMENTS,
            "with two documents.",
            [
                CohereCitation(
                    start=5,
                    end=18,
                    text="two documents",
                    documents=[DOCUMENTS[0], DOCUMENTS[1]],
                )
            ],
            id="two cited documents (normal)",
        ),
        pytest.param(
            "with <co: 0>two</co: 0> <co: 1>citations</co: 1>.",
            DOCUMENTS,
            "with two citations.",
            [
                CohereCitation(start=5, end=8, text="two", documents=[DOCUMENTS[0]]),
                CohereCitation(
                    start=9, end=18, text="citations", documents=[DOCUMENTS[1]]
                ),
            ],
            id="more than one citation (normal)",
        ),
        pytest.param(
            "with <co: 2>incorrect citation</co: 2>.",
            DOCUMENTS,
            "with incorrect citation.",
            [
                CohereCitation(
                    start=5,
                    end=23,
                    text="incorrect citation",
                    documents=[],  # note no documents.
                )
            ],
            id="cited document doesn't exist (abnormal)",
        ),
    ],
)
def test_parse_citations(
    text: str,
    documents: List[Mapping],
    expected_generation: str,
    expected_citations: List[CohereCitation],
) -> None:
    actual_generation, actual_citations = parse_citations(
        grounded_answer=text, documents=documents
    )
    assert expected_generation == actual_generation
    assert expected_citations == actual_citations
    # Every citation's [start, end) range must index into the original text.
    for citation in actual_citations:
        assert text[citation.start : citation.end]

View File

@ -61,6 +61,16 @@ document_template = """Document: {index}
), ),
id="list of dictionaries", id="list of dictionaries",
), ),
pytest.param(
2,
document_template.format(index=0, fields="Output: 2"),
id="int",
),
pytest.param(
[2],
document_template.format(index=0, fields="Output: 2"),
id="list of int",
),
], ],
) )
def test_render_observation_has_correct_content( def test_render_observation_has_correct_content(

View File

@ -1,6 +1,7 @@
from langchain_cohere import __all__ from langchain_cohere import __all__
EXPECTED_ALL = [ EXPECTED_ALL = [
"CohereCitation",
"ChatCohere", "ChatCohere",
"CohereEmbeddings", "CohereEmbeddings",
"CohereRagRetriever", "CohereRagRetriever",