mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-24 20:12:11 +00:00
cohere: Add citations to agent, flexibility to tool parsing, fix SDK issue (#19965)
**Description:** Citations are the main addition in this PR. We now emit them from the multihop agent! Additionally the agent is now more flexible with observations (`Any` is now accepted), and the Cohere SDK version is bumped to fix an issue with the most recent version of pydantic v1 (1.10.15)
This commit is contained in:
parent
605c3f23e1
commit
e103492eb8
@ -1,11 +1,13 @@
|
|||||||
from langchain_cohere.chat_models import ChatCohere
|
from langchain_cohere.chat_models import ChatCohere
|
||||||
from langchain_cohere.cohere_agent import create_cohere_tools_agent
|
from langchain_cohere.cohere_agent import create_cohere_tools_agent
|
||||||
|
from langchain_cohere.common import CohereCitation
|
||||||
from langchain_cohere.embeddings import CohereEmbeddings
|
from langchain_cohere.embeddings import CohereEmbeddings
|
||||||
from langchain_cohere.rag_retrievers import CohereRagRetriever
|
from langchain_cohere.rag_retrievers import CohereRagRetriever
|
||||||
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
|
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
|
||||||
from langchain_cohere.rerank import CohereRerank
|
from langchain_cohere.rerank import CohereRerank
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"CohereCitation",
|
||||||
"ChatCohere",
|
"ChatCohere",
|
||||||
"CohereEmbeddings",
|
"CohereEmbeddings",
|
||||||
"CohereRagRetriever",
|
"CohereRagRetriever",
|
||||||
|
36
libs/partners/cohere/langchain_cohere/common.py
Normal file
36
libs/partners/cohere/langchain_cohere/common.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, List, Mapping
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CohereCitation:
|
||||||
|
"""
|
||||||
|
Cohere has fine-grained citations that specify the exact part of text.
|
||||||
|
More info at https://docs.cohere.com/docs/documents-and-citations
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
The index of text that the citation starts at, counting from zero. For example, a
|
||||||
|
generation of 'Hello, world!' with a citation on 'world' would have a start value
|
||||||
|
of 7. This is because the citation starts at 'w', which is the seventh character.
|
||||||
|
"""
|
||||||
|
start: int
|
||||||
|
|
||||||
|
"""
|
||||||
|
The index of text that the citation ends after, counting from zero. For example, a
|
||||||
|
generation of 'Hello, world!' with a citation on 'world' would have an end value of
|
||||||
|
11. This is because the citation ends after 'd', which is the eleventh character.
|
||||||
|
"""
|
||||||
|
end: int
|
||||||
|
|
||||||
|
"""
|
||||||
|
The text of the citation. For example, a generation of 'Hello, world!' with a
|
||||||
|
citation of 'world' would have a text value of 'world'.
|
||||||
|
"""
|
||||||
|
text: str
|
||||||
|
|
||||||
|
"""
|
||||||
|
The contents of the documents that were cited. When used with agents these will be
|
||||||
|
the contents of relevant agent outputs.
|
||||||
|
"""
|
||||||
|
documents: List[Mapping[str, Any]]
|
@ -5,17 +5,27 @@
|
|||||||
This agent uses a multi hop prompt by Cohere, which is experimental and subject
|
This agent uses a multi hop prompt by Cohere, which is experimental and subject
|
||||||
to change. The latest prompt can be used by upgrading the langchain-cohere package.
|
to change. The latest prompt can be used by upgrading the langchain-cohere package.
|
||||||
"""
|
"""
|
||||||
from typing import Sequence
|
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
from langchain_core.runnables import (
|
||||||
|
Runnable,
|
||||||
|
RunnableConfig,
|
||||||
|
RunnableParallel,
|
||||||
|
RunnablePassthrough,
|
||||||
|
)
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from langchain_cohere.react_multi_hop.parsing import (
|
from langchain_cohere.react_multi_hop.parsing import (
|
||||||
|
GROUNDED_ANSWER_KEY,
|
||||||
|
OUTPUT_KEY,
|
||||||
CohereToolsReactAgentOutputParser,
|
CohereToolsReactAgentOutputParser,
|
||||||
|
parse_citations,
|
||||||
)
|
)
|
||||||
from langchain_cohere.react_multi_hop.prompt import (
|
from langchain_cohere.react_multi_hop.prompt import (
|
||||||
|
convert_to_documents,
|
||||||
multi_hop_prompt,
|
multi_hop_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,8 +46,14 @@ def create_cohere_react_agent(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Runnable sequence representing an agent. It takes as input all the same input
|
A Runnable sequence representing an agent. It takes as input all the same input
|
||||||
variables as the prompt passed in does and returns an AgentAction or
|
variables as the prompt passed in does and returns a List[AgentAction] or a
|
||||||
AgentFinish.
|
single AgentFinish.
|
||||||
|
|
||||||
|
The AgentFinish will have two fields:
|
||||||
|
* output: str - The output string generated by the model
|
||||||
|
* citations: List[CohereCitation] - A list of citations that refer to the
|
||||||
|
output and observations made by the agent. If there are no citations this
|
||||||
|
list will be empty.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
. code-block:: python
|
. code-block:: python
|
||||||
@ -61,14 +77,61 @@ def create_cohere_react_agent(
|
|||||||
"input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
|
"input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
|
||||||
})
|
})
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
# Creates a prompt, invokes the model, and produces a
|
||||||
|
# "Union[List[AgentAction], AgentFinish]"
|
||||||
|
generate_agent_steps = (
|
||||||
|
multi_hop_prompt(tools=tools, prompt=prompt)
|
||||||
|
| llm.bind(stop=["\nObservation:"], raw_prompting=True)
|
||||||
|
| CohereToolsReactAgentOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
agent = (
|
agent = (
|
||||||
RunnablePassthrough.assign(
|
RunnablePassthrough.assign(
|
||||||
# agent_scratchpad isn't used in this chain, but added here for
|
# agent_scratchpad isn't used in this chain, but added here for
|
||||||
# interoperability with other chains that may require it.
|
# interoperability with other chains that may require it.
|
||||||
agent_scratchpad=lambda _: [],
|
agent_scratchpad=lambda _: [],
|
||||||
)
|
)
|
||||||
| multi_hop_prompt(tools=tools, prompt=prompt)
|
| RunnableParallel(
|
||||||
| llm.bind(stop=["\nObservation:"], raw_prompting=True)
|
chain_input=RunnablePassthrough(), agent_steps=generate_agent_steps
|
||||||
| CohereToolsReactAgentOutputParser()
|
)
|
||||||
|
| _AddCitations()
|
||||||
)
|
)
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
class _AddCitations(Runnable):
|
||||||
|
"""
|
||||||
|
Adds a list of citations to the output of the Cohere multi hop chain when the
|
||||||
|
last step is an AgentFinish. Citations are generated from the observations (made
|
||||||
|
in previous agent steps) and the grounded answer (made in the last step).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
|
||||||
|
) -> Union[List[AgentAction], AgentFinish]:
|
||||||
|
agent_steps = input.get("agent_steps", [])
|
||||||
|
if not agent_steps:
|
||||||
|
# The input wasn't as expected.
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not isinstance(agent_steps, AgentFinish):
|
||||||
|
# We're not on the AgentFinish step.
|
||||||
|
return agent_steps
|
||||||
|
agent_finish = agent_steps
|
||||||
|
|
||||||
|
# Build a list of documents from the intermediate_steps used in this chain.
|
||||||
|
intermediate_steps = input.get("chain_input", {}).get("intermediate_steps", [])
|
||||||
|
documents: List[Mapping] = []
|
||||||
|
for _, observation in intermediate_steps:
|
||||||
|
documents.extend(convert_to_documents(observation))
|
||||||
|
|
||||||
|
# Build a list of citations, if any, from the documents + grounded answer.
|
||||||
|
grounded_answer = agent_finish.return_values.pop(GROUNDED_ANSWER_KEY, "")
|
||||||
|
output, citations = parse_citations(
|
||||||
|
grounded_answer=grounded_answer, documents=documents
|
||||||
|
)
|
||||||
|
agent_finish.return_values[OUTPUT_KEY] = output
|
||||||
|
agent_finish.return_values["citations"] = citations
|
||||||
|
|
||||||
|
return agent_finish
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Mapping, Tuple, Union
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
|
|
||||||
|
from langchain_cohere import CohereCitation
|
||||||
|
|
||||||
|
OUTPUT_KEY = "output"
|
||||||
|
GROUNDED_ANSWER_KEY = "grounded_answer"
|
||||||
|
|
||||||
|
|
||||||
class CohereToolsReactAgentOutputParser(
|
class CohereToolsReactAgentOutputParser(
|
||||||
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
|
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
|
||||||
@ -23,7 +28,13 @@ class CohereToolsReactAgentOutputParser(
|
|||||||
"cited_docs": "Cited Documents:",
|
"cited_docs": "Cited Documents:",
|
||||||
}
|
}
|
||||||
parsed_answer = parse_answer_with_prefixes(text, prefix_map)
|
parsed_answer = parse_answer_with_prefixes(text, prefix_map)
|
||||||
return AgentFinish({"output": parsed_answer["answer"]}, text)
|
return AgentFinish(
|
||||||
|
return_values={
|
||||||
|
OUTPUT_KEY: parsed_answer["answer"],
|
||||||
|
GROUNDED_ANSWER_KEY: parsed_answer["grounded_answer"],
|
||||||
|
},
|
||||||
|
log=text,
|
||||||
|
)
|
||||||
elif any([x in text for x in ["Plan: ", "Reflection: ", "Action: "]]):
|
elif any([x in text for x in ["Plan: ", "Reflection: ", "Action: "]]):
|
||||||
completion, plan, actions = parse_actions(text)
|
completion, plan, actions = parse_actions(text)
|
||||||
agent_actions: List[AgentAction] = []
|
agent_actions: List[AgentAction] = []
|
||||||
@ -149,3 +160,144 @@ def parse_actions(generation: str) -> Tuple[str, str, List[Dict]]:
|
|||||||
|
|
||||||
parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:")
|
parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:")
|
||||||
return generation, plan, parsed_actions
|
return generation, plan, parsed_actions
|
||||||
|
|
||||||
|
|
||||||
|
def parse_citations(
|
||||||
|
grounded_answer: str, documents: List[Mapping]
|
||||||
|
) -> Tuple[str, List[CohereCitation]]:
|
||||||
|
"""
|
||||||
|
Parses a grounded_generation (from parse_actions) and documents (from
|
||||||
|
convert_to_documents) into a (generation, CohereCitation list) tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
no_markup_answer, parsed_answer = _parse_answer_spans(grounded_answer)
|
||||||
|
citations: List[CohereCitation] = []
|
||||||
|
start = 0
|
||||||
|
|
||||||
|
for answer in parsed_answer:
|
||||||
|
text = answer.get("text", "")
|
||||||
|
document_indexes = answer.get("cited_docs")
|
||||||
|
if not document_indexes:
|
||||||
|
# There were no citations for this piece of text.
|
||||||
|
start += len(text)
|
||||||
|
continue
|
||||||
|
end = start + len(text)
|
||||||
|
|
||||||
|
# Look up the cited document by index
|
||||||
|
cited_documents: List[Mapping] = []
|
||||||
|
for index in set(document_indexes):
|
||||||
|
if index >= len(documents):
|
||||||
|
# The document index doesn't exist
|
||||||
|
continue
|
||||||
|
cited_documents.append(documents[index])
|
||||||
|
|
||||||
|
citations.append(
|
||||||
|
CohereCitation(
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
text=text,
|
||||||
|
documents=cited_documents,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return no_markup_answer, citations
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_spans(answer: str) -> str:
|
||||||
|
"""removes any <co> tags from a string, including trailing partial tags
|
||||||
|
|
||||||
|
input: "hi my <co>name</co> is <co: 1> patrick</co:3> and <co"
|
||||||
|
output: "hi my name is patrick and"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
answer (str): string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: same string with co tags removed
|
||||||
|
"""
|
||||||
|
answer = re.sub(r"<co(.*?)>|</co(.*?)>", "", answer)
|
||||||
|
idx = answer.find("<co")
|
||||||
|
if idx > -1:
|
||||||
|
answer = answer[:idx]
|
||||||
|
idx = answer.find("</")
|
||||||
|
if idx > -1:
|
||||||
|
answer = answer[:idx]
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_answer_spans(grounded_answer: str) -> Tuple[str, List[Dict[str, Any]]]:
|
||||||
|
actual_cites = []
|
||||||
|
for c in re.findall(r"<co:(.*?)>", grounded_answer):
|
||||||
|
actual_cites.append(c.strip().split(","))
|
||||||
|
no_markup_answer = _strip_spans(grounded_answer)
|
||||||
|
|
||||||
|
current_idx = 0
|
||||||
|
parsed_answer: List[Dict[str, Union[str, List[int]]]] = []
|
||||||
|
cited_docs_set = []
|
||||||
|
last_entry_is_open_cite = False
|
||||||
|
parsed_current_cite_document_idxs: List[int] = []
|
||||||
|
|
||||||
|
while current_idx < len(grounded_answer):
|
||||||
|
current_cite = re.search(r"<co: (.*?)>", grounded_answer[current_idx:])
|
||||||
|
if current_cite:
|
||||||
|
# previous part
|
||||||
|
parsed_answer.append(
|
||||||
|
{
|
||||||
|
"text": grounded_answer[
|
||||||
|
current_idx : current_idx + current_cite.start()
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
current_cite_document_idxs = current_cite.group(1).split(",")
|
||||||
|
parsed_current_cite_document_idxs = []
|
||||||
|
for cited_idx in current_cite_document_idxs:
|
||||||
|
if cited_idx.isdigit():
|
||||||
|
cited_idx = int(cited_idx.strip())
|
||||||
|
parsed_current_cite_document_idxs.append(cited_idx)
|
||||||
|
if cited_idx not in cited_docs_set:
|
||||||
|
cited_docs_set.append(cited_idx)
|
||||||
|
|
||||||
|
current_idx += current_cite.end()
|
||||||
|
|
||||||
|
current_cite_close = re.search(
|
||||||
|
r"</co: " + current_cite.group(1) + ">", grounded_answer[current_idx:]
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_cite_close:
|
||||||
|
# there might have been issues parsing the ids, so we need to check
|
||||||
|
# that they are actually ints and available
|
||||||
|
if len(parsed_current_cite_document_idxs) > 0:
|
||||||
|
pt = grounded_answer[
|
||||||
|
current_idx : current_idx + current_cite_close.start()
|
||||||
|
]
|
||||||
|
parsed_answer.append(
|
||||||
|
{"text": pt, "cited_docs": parsed_current_cite_document_idxs}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parsed_answer.append(
|
||||||
|
{
|
||||||
|
"text": grounded_answer[
|
||||||
|
current_idx : current_idx + current_cite_close.start()
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
current_idx += current_cite_close.end()
|
||||||
|
|
||||||
|
else:
|
||||||
|
last_entry_is_open_cite = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# don't forget about the last one
|
||||||
|
if last_entry_is_open_cite:
|
||||||
|
pt = _strip_spans(grounded_answer[current_idx:])
|
||||||
|
parsed_answer.append(
|
||||||
|
{"text": pt, "cited_docs": parsed_current_cite_document_idxs}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parsed_answer.append({"text": _strip_spans(grounded_answer[current_idx:])})
|
||||||
|
return no_markup_answer, parsed_answer
|
||||||
|
@ -108,36 +108,12 @@ def render_observations(
|
|||||||
index: int,
|
index: int,
|
||||||
) -> Tuple[BaseMessage, int]:
|
) -> Tuple[BaseMessage, int]:
|
||||||
"""Renders the 'output' part of an Agent's intermediate step into prompt content."""
|
"""Renders the 'output' part of an Agent's intermediate step into prompt content."""
|
||||||
if (
|
documents = convert_to_documents(observations)
|
||||||
not isinstance(observations, list)
|
|
||||||
and not isinstance(observations, str)
|
|
||||||
and not isinstance(observations, Mapping)
|
|
||||||
):
|
|
||||||
raise ValueError("observation must be a list, a Mapping, or a string")
|
|
||||||
|
|
||||||
rendered_documents = []
|
rendered_documents: List[str] = []
|
||||||
document_prompt = """Document: {index}
|
document_prompt = """Document: {index}
|
||||||
{fields}"""
|
{fields}"""
|
||||||
|
for doc in documents:
|
||||||
if isinstance(observations, str):
|
|
||||||
# strings are turned into a key/value pair and a key of 'output' is added.
|
|
||||||
observations = [{"output": observations}] # type: ignore
|
|
||||||
|
|
||||||
if isinstance(observations, Mapping):
|
|
||||||
# single items are transformed into a list to simplify the rest of the code.
|
|
||||||
observations = [observations]
|
|
||||||
|
|
||||||
if isinstance(observations, list):
|
|
||||||
for doc in observations:
|
|
||||||
if isinstance(doc, str):
|
|
||||||
# strings are turned into a key/value pair.
|
|
||||||
doc = {"output": doc}
|
|
||||||
|
|
||||||
if not isinstance(doc, Mapping):
|
|
||||||
raise ValueError(
|
|
||||||
"all observation list items must be a Mapping or a string"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Render document fields into Key: value strings.
|
# Render document fields into Key: value strings.
|
||||||
fields: List[str] = []
|
fields: List[str] = []
|
||||||
for k, v in doc.items():
|
for k, v in doc.items():
|
||||||
@ -161,6 +137,30 @@ def render_observations(
|
|||||||
return SystemMessage(content=prompt_content), index
|
return SystemMessage(content=prompt_content), index
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_documents(
|
||||||
|
observations: Any,
|
||||||
|
) -> List[Mapping]:
|
||||||
|
"""Converts observations into a 'document' dict"""
|
||||||
|
documents: List[Mapping] = []
|
||||||
|
if isinstance(observations, str):
|
||||||
|
# strings are turned into a key/value pair and a key of 'output' is added.
|
||||||
|
observations = [{"output": observations}]
|
||||||
|
elif isinstance(observations, Mapping):
|
||||||
|
# single mappings are transformed into a list to simplify the rest of the code.
|
||||||
|
observations = [observations]
|
||||||
|
elif not isinstance(observations, Sequence):
|
||||||
|
# all other types are turned into a key/value pair within a list
|
||||||
|
observations = [{"output": observations}]
|
||||||
|
|
||||||
|
for doc in observations:
|
||||||
|
if not isinstance(doc, Mapping):
|
||||||
|
# types that aren't Mapping are turned into a key/value pair.
|
||||||
|
doc = {"output": doc}
|
||||||
|
documents.append(doc)
|
||||||
|
|
||||||
|
return documents
|
||||||
|
|
||||||
|
|
||||||
def render_intermediate_steps(
|
def render_intermediate_steps(
|
||||||
intermediate_steps: List[Tuple[AgentAction, Any]],
|
intermediate_steps: List[Tuple[AgentAction, Any]],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
14
libs/partners/cohere/poetry.lock
generated
14
libs/partners/cohere/poetry.lock
generated
@ -305,13 +305,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cohere"
|
name = "cohere"
|
||||||
version = "5.1.7"
|
version = "5.1.8"
|
||||||
description = ""
|
description = ""
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "<4.0,>=3.8"
|
python-versions = "<4.0,>=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "cohere-5.1.7-py3-none-any.whl", hash = "sha256:66e149425ba10d9d6ed2980ad869afae2ed79b1f4c375f215ff4953f389cf5f9"},
|
{file = "cohere-5.1.8-py3-none-any.whl", hash = "sha256:420ebd0fe8fb34c69adfd6081d75cd3954f498f27dff44e0afa539958e9179ed"},
|
||||||
{file = "cohere-5.1.7.tar.gz", hash = "sha256:5b5ba38e614313d96f4eb362046a3470305e57119e39538afa3220a27614ba15"},
|
{file = "cohere-5.1.8.tar.gz", hash = "sha256:2ce7e8541c834d5c01991ededf1d1535f76fef48515fb06dc00f284b62245b9c"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -1035,7 +1035,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "orjson-3.10.0-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:47af5d4b850a2d1328660661f0881b67fdbe712aea905dadd413bdea6f792c33"},
|
|
||||||
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c90681333619d78360d13840c7235fdaf01b2b129cb3a4f1647783b1971542b6"},
|
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c90681333619d78360d13840c7235fdaf01b2b129cb3a4f1647783b1971542b6"},
|
||||||
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:400c5b7c4222cb27b5059adf1fb12302eebcabf1978f33d0824aa5277ca899bd"},
|
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:400c5b7c4222cb27b5059adf1fb12302eebcabf1978f33d0824aa5277ca899bd"},
|
||||||
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dcb32e949eae80fb335e63b90e5808b4b0f64e31476b3777707416b41682db5"},
|
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dcb32e949eae80fb335e63b90e5808b4b0f64e31476b3777707416b41682db5"},
|
||||||
@ -1063,9 +1062,6 @@ files = [
|
|||||||
{file = "orjson-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:237ba922aef472761acd697eef77fef4831ab769a42e83c04ac91e9f9e08fa0e"},
|
{file = "orjson-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:237ba922aef472761acd697eef77fef4831ab769a42e83c04ac91e9f9e08fa0e"},
|
||||||
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98c1bfc6a9bec52bc8f0ab9b86cc0874b0299fccef3562b793c1576cf3abb570"},
|
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98c1bfc6a9bec52bc8f0ab9b86cc0874b0299fccef3562b793c1576cf3abb570"},
|
||||||
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30d795a24be16c03dca0c35ca8f9c8eaaa51e3342f2c162d327bd0225118794a"},
|
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30d795a24be16c03dca0c35ca8f9c8eaaa51e3342f2c162d327bd0225118794a"},
|
||||||
{file = "orjson-3.10.0-cp312-none-win32.whl", hash = "sha256:6a3f53dc650bc860eb26ec293dfb489b2f6ae1cbfc409a127b01229980e372f7"},
|
|
||||||
{file = "orjson-3.10.0-cp312-none-win_amd64.whl", hash = "sha256:983db1f87c371dc6ffc52931eb75f9fe17dc621273e43ce67bee407d3e5476e9"},
|
|
||||||
{file = "orjson-3.10.0-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9a667769a96a72ca67237224a36faf57db0c82ab07d09c3aafc6f956196cfa1b"},
|
|
||||||
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade1e21dfde1d37feee8cf6464c20a2f41fa46c8bcd5251e761903e46102dc6b"},
|
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade1e21dfde1d37feee8cf6464c20a2f41fa46c8bcd5251e761903e46102dc6b"},
|
||||||
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23c12bb4ced1c3308eff7ba5c63ef8f0edb3e4c43c026440247dd6c1c61cea4b"},
|
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23c12bb4ced1c3308eff7ba5c63ef8f0edb3e4c43c026440247dd6c1c61cea4b"},
|
||||||
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2d014cf8d4dc9f03fc9f870de191a49a03b1bcda51f2a957943fb9fafe55aac"},
|
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2d014cf8d4dc9f03fc9f870de191a49a03b1bcda51f2a957943fb9fafe55aac"},
|
||||||
@ -1075,7 +1071,6 @@ files = [
|
|||||||
{file = "orjson-3.10.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:13b5d3c795b09a466ec9fcf0bd3ad7b85467d91a60113885df7b8d639a9d374b"},
|
{file = "orjson-3.10.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:13b5d3c795b09a466ec9fcf0bd3ad7b85467d91a60113885df7b8d639a9d374b"},
|
||||||
{file = "orjson-3.10.0-cp38-none-win32.whl", hash = "sha256:5d42768db6f2ce0162544845facb7c081e9364a5eb6d2ef06cd17f6050b048d8"},
|
{file = "orjson-3.10.0-cp38-none-win32.whl", hash = "sha256:5d42768db6f2ce0162544845facb7c081e9364a5eb6d2ef06cd17f6050b048d8"},
|
||||||
{file = "orjson-3.10.0-cp38-none-win_amd64.whl", hash = "sha256:33e6655a2542195d6fd9f850b428926559dee382f7a862dae92ca97fea03a5ad"},
|
{file = "orjson-3.10.0-cp38-none-win_amd64.whl", hash = "sha256:33e6655a2542195d6fd9f850b428926559dee382f7a862dae92ca97fea03a5ad"},
|
||||||
{file = "orjson-3.10.0-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4050920e831a49d8782a1720d3ca2f1c49b150953667eed6e5d63a62e80f46a2"},
|
|
||||||
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1897aa25a944cec774ce4a0e1c8e98fb50523e97366c637b7d0cddabc42e6643"},
|
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1897aa25a944cec774ce4a0e1c8e98fb50523e97366c637b7d0cddabc42e6643"},
|
||||||
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bf565a69e0082ea348c5657401acec3cbbb31564d89afebaee884614fba36b4"},
|
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bf565a69e0082ea348c5657401acec3cbbb31564d89afebaee884614fba36b4"},
|
||||||
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b6ebc17cfbbf741f5c1a888d1854354536f63d84bee537c9a7c0335791bb9009"},
|
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b6ebc17cfbbf741f5c1a888d1854354536f63d84bee537c9a7c0335791bb9009"},
|
||||||
@ -1335,7 +1330,6 @@ files = [
|
|||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
@ -1769,4 +1763,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "7546180410ed197e1c2aa9830e32e3a40ebcd930a86a9e3398cd8fe6123b6888"
|
content-hash = "00abb29a38cdcc616e802bfa33a08db9e04faa5565ca2fcbcc0fcacc10c02ba7"
|
||||||
|
@ -13,7 +13,7 @@ license = "MIT"
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
langchain-core = "^0.1.32"
|
langchain-core = "^0.1.32"
|
||||||
cohere = "^5.1.4"
|
cohere = ">=5.1.8,<5.2"
|
||||||
|
|
||||||
[tool.poetry.group.test]
|
[tool.poetry.group.test]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -73,3 +73,4 @@ def test_invoke_multihop_agent() -> None:
|
|||||||
|
|
||||||
assert "output" in actual
|
assert "output" in actual
|
||||||
assert "best buy" in actual["output"].lower()
|
assert "best buy" in actual["output"].lower()
|
||||||
|
assert "citations" in actual
|
||||||
|
@ -0,0 +1,72 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
|
|
||||||
|
from langchain_cohere import CohereCitation
|
||||||
|
from langchain_cohere.react_multi_hop.agent import _AddCitations
|
||||||
|
|
||||||
|
CITATIONS = [CohereCitation(start=1, end=2, text="foo", documents=[{"bar": "baz"}])]
|
||||||
|
GENERATION = "mocked generation"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invoke_with,expected",
|
||||||
|
[
|
||||||
|
pytest.param({}, [], id="no agent_steps or chain_input"),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"chain_input": {"intermediate_steps": []},
|
||||||
|
"agent_steps": [
|
||||||
|
AgentAction(
|
||||||
|
tool="tool_name", tool_input="tool_input", log="tool_log"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
[AgentAction(tool="tool_name", tool_input="tool_input", log="tool_log")],
|
||||||
|
id="not an AgentFinish",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"chain_input": {
|
||||||
|
"intermediate_steps": [
|
||||||
|
(
|
||||||
|
AgentAction(
|
||||||
|
tool="tool_name",
|
||||||
|
tool_input="tool_input",
|
||||||
|
log="tool_log",
|
||||||
|
),
|
||||||
|
{"tool_output": "output"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"agent_steps": AgentFinish(
|
||||||
|
return_values={"output": "output1", "grounded_answer": GENERATION},
|
||||||
|
log="",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
AgentFinish(
|
||||||
|
return_values={"output": GENERATION, "citations": CITATIONS}, log=""
|
||||||
|
),
|
||||||
|
id="AgentFinish",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"langchain_cohere.react_multi_hop.agent.parse_citations",
|
||||||
|
autospec=True,
|
||||||
|
return_value=(GENERATION, CITATIONS),
|
||||||
|
)
|
||||||
|
def test_add_citations(
|
||||||
|
parse_citations_mock: Any, invoke_with: Dict[str, Any], expected: Any
|
||||||
|
) -> None:
|
||||||
|
chain = _AddCitations()
|
||||||
|
actual = chain.invoke(invoke_with)
|
||||||
|
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
if isinstance(expected, AgentFinish):
|
||||||
|
parse_citations_mock.assert_called_once_with(
|
||||||
|
grounded_answer=GENERATION, documents=[{"tool_output": "output"}]
|
||||||
|
)
|
@ -16,7 +16,8 @@ from tests.unit_tests.react_multi_hop import ExpectationType, read_expectation_f
|
|||||||
"answer_sound_of_music",
|
"answer_sound_of_music",
|
||||||
AgentFinish(
|
AgentFinish(
|
||||||
return_values={
|
return_values={
|
||||||
"output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999." # noqa: E501
|
"output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.", # noqa: E501
|
||||||
|
"grounded_answer": "<co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
|
||||||
},
|
},
|
||||||
log="Relevant Documents: 0,2,3\nCited Documents: 0,2\nAnswer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.\nGrounded answer: <co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
|
log="Relevant Documents: 0,2,3\nCited Documents: 0,2\nAnswer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.\nGrounded answer: <co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
|
||||||
),
|
),
|
||||||
|
@ -0,0 +1,86 @@
|
|||||||
|
from typing import List, Mapping
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_cohere import CohereCitation
|
||||||
|
from langchain_cohere.react_multi_hop.parsing import parse_citations
|
||||||
|
|
||||||
|
DOCUMENTS = [{"foo": "bar"}, {"baz": "foobar"}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"text,documents,expected_generation,expected_citations",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
"no citations",
|
||||||
|
DOCUMENTS,
|
||||||
|
"no citations",
|
||||||
|
[],
|
||||||
|
id="no citations",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"with <co: 0>one citation</co: 0>.",
|
||||||
|
DOCUMENTS,
|
||||||
|
"with one citation.",
|
||||||
|
[
|
||||||
|
CohereCitation(
|
||||||
|
start=5, end=17, text="one citation", documents=[DOCUMENTS[0]]
|
||||||
|
)
|
||||||
|
],
|
||||||
|
id="one citation (normal)",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"with <co: 0,1>two documents</co: 0,1>.",
|
||||||
|
DOCUMENTS,
|
||||||
|
"with two documents.",
|
||||||
|
[
|
||||||
|
CohereCitation(
|
||||||
|
start=5,
|
||||||
|
end=18,
|
||||||
|
text="two documents",
|
||||||
|
documents=[DOCUMENTS[0], DOCUMENTS[1]],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
id="two cited documents (normal)",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"with <co: 0>two</co: 0> <co: 1>citations</co: 1>.",
|
||||||
|
DOCUMENTS,
|
||||||
|
"with two citations.",
|
||||||
|
[
|
||||||
|
CohereCitation(start=5, end=8, text="two", documents=[DOCUMENTS[0]]),
|
||||||
|
CohereCitation(
|
||||||
|
start=9, end=18, text="citations", documents=[DOCUMENTS[1]]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
id="more than one citation (normal)",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"with <co: 2>incorrect citation</co: 2>.",
|
||||||
|
DOCUMENTS,
|
||||||
|
"with incorrect citation.",
|
||||||
|
[
|
||||||
|
CohereCitation(
|
||||||
|
start=5,
|
||||||
|
end=23,
|
||||||
|
text="incorrect citation",
|
||||||
|
documents=[], # note no documents.
|
||||||
|
)
|
||||||
|
],
|
||||||
|
id="cited document doesn't exist (abnormal)",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_citations(
|
||||||
|
text: str,
|
||||||
|
documents: List[Mapping],
|
||||||
|
expected_generation: str,
|
||||||
|
expected_citations: List[CohereCitation],
|
||||||
|
) -> None:
|
||||||
|
actual_generation, actual_citations = parse_citations(
|
||||||
|
grounded_answer=text, documents=documents
|
||||||
|
)
|
||||||
|
assert expected_generation == actual_generation
|
||||||
|
assert expected_citations == actual_citations
|
||||||
|
for citation in actual_citations:
|
||||||
|
assert text[citation.start : citation.end]
|
@ -61,6 +61,16 @@ document_template = """Document: {index}
|
|||||||
),
|
),
|
||||||
id="list of dictionaries",
|
id="list of dictionaries",
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
2,
|
||||||
|
document_template.format(index=0, fields="Output: 2"),
|
||||||
|
id="int",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[2],
|
||||||
|
document_template.format(index=0, fields="Output: 2"),
|
||||||
|
id="list of int",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_render_observation_has_correct_content(
|
def test_render_observation_has_correct_content(
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from langchain_cohere import __all__
|
from langchain_cohere import __all__
|
||||||
|
|
||||||
EXPECTED_ALL = [
|
EXPECTED_ALL = [
|
||||||
|
"CohereCitation",
|
||||||
"ChatCohere",
|
"ChatCohere",
|
||||||
"CohereEmbeddings",
|
"CohereEmbeddings",
|
||||||
"CohereRagRetriever",
|
"CohereRagRetriever",
|
||||||
|
Loading…
Reference in New Issue
Block a user