From 8bebc9206fb77ee22a9b0592c1efb32f27bb45db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Kuci=C5=84ski?= <141181390+ds-jakub-kucinski@users.noreply.github.com> Date: Wed, 16 Aug 2023 22:30:15 +0200 Subject: [PATCH] Add improved sources splitting in BaseQAWithSourcesChain (#8716) ## Type: Improvement --- ## Description: Running QAWithSourcesChain sometimes raises ValueError as mentioned in issue #7184: ``` ValueError: too many values to unpack (expected 2) Traceback: response = qa({"question": pregunta}, return_only_outputs=True) File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 166, in __call__ raise e File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\base.py", line 160, in __call__ self._call(inputs, run_manager=run_manager) File "C:\Anaconda3\envs\iagen_3_10\lib\site-packages\langchain\chains\qa_with_sources\base.py", line 132, in _call answer, sources = re.split(r"SOURCES:\s", answer) ``` This happens when the LLM keeps generating after the first list of sources and appends a follow-up question, answer and sources, i.e. a completion in a form similar to the one below: ``` SOURCES: QUESTION: FINAL ANSWER: SOURCES: ``` This makes the following line ``` re.split(r"SOURCES:\s", answer) ``` return more than 2 elements, so unpacking the result into two names raises the ValueError. 
The simple fix is to also split on "QUESTION:\s" and take the first two elements: ``` answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2] ``` Sometimes the LLM might also generate other text, such as alternative answers in the form: ``` SOURCES: SOURCES: SOURCES: ``` In such cases it is best to split the previously obtained sources on newlines and keep only the first line: ``` sources = re.split(r"\n", sources)[0] ``` --- ## Issue: Resolves #7184 --- ## Maintainer: @baskaryan --- .../langchain/chains/qa_with_sources/base.py | 21 +++--- .../unit_tests/chains/test_qa_with_sources.py | 71 +++++++++++++++++++ 2 files changed, 83 insertions(+), 9 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index 46eb1419320..38c10627a3a 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect import re from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from pydantic_v1 import Extra, root_validator @@ -119,6 +119,15 @@ class BaseQAWithSourcesChain(Chain, ABC): values["combine_documents_chain"] = values.pop("combine_document_chain") return values + def _split_sources(self, answer: str) -> Tuple[str, str]: + """Split sources from answer.""" + if re.search(r"SOURCES:\s", answer): + answer, sources = re.split(r"SOURCES:\s|QUESTION:\s", answer)[:2] + sources = re.split(r"\n", sources)[0] + else: + sources = "" + return answer, sources + @abstractmethod def _get_docs( self, @@ -145,10 +154,7 @@ class BaseQAWithSourcesChain(Chain, ABC): answer = self.combine_documents_chain.run( input_documents=docs, callbacks=_run_manager.get_child(), **inputs ) - if re.search(r"SOURCES:\s", answer): - answer, sources = 
re.split(r"SOURCES:\s", answer) - else: - sources = "" + answer, sources = self._split_sources(answer) result: Dict[str, Any] = { self.answer_key: answer, self.sources_answer_key: sources, @@ -182,10 +188,7 @@ class BaseQAWithSourcesChain(Chain, ABC): answer = await self.combine_documents_chain.arun( input_documents=docs, callbacks=_run_manager.get_child(), **inputs ) - if re.search(r"SOURCES:\s", answer): - answer, sources = re.split(r"SOURCES:\s", answer) - else: - sources = "" + answer, sources = self._split_sources(answer) result: Dict[str, Any] = { self.answer_key: answer, self.sources_answer_key: sources, diff --git a/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py new file mode 100644 index 00000000000..e69d9b5cd11 --- /dev/null +++ b/libs/langchain/tests/unit_tests/chains/test_qa_with_sources.py @@ -0,0 +1,71 @@ +import pytest + +from langchain.chains.qa_with_sources.base import QAWithSourcesChain +from tests.unit_tests.llms.fake_llm import FakeLLM + + +@pytest.mark.parametrize( + "text,answer,sources", + [ + ( + "This Agreement is governed by English law.\nSOURCES: 28-pl", + "This Agreement is governed by English law.\n", + "28-pl", + ), + ( + "This Agreement is governed by English law.\n" + "SOURCES: 28-pl\n\n" + "QUESTION: Which state/country's law governs the interpretation of the " + "contract?\n" + "FINAL ANSWER: This Agreement is governed by English law.\n" + "SOURCES: 28-pl", + "This Agreement is governed by English law.\n", + "28-pl", + ), + ( + "The president did not mention Michael Jackson in the provided content.\n" + "SOURCES: \n\n" + "Note: Since the content provided does not contain any information about " + "Michael Jackson, there are no sources to cite for this specific question.", + "The president did not mention Michael Jackson in the provided content.\n", + "", + ), + # The following text was generated by gpt-3.5-turbo + ( + "To diagnose the problem, please 
answer the following questions and send " + "them in one message to IT:\nA1. Are you connected to the office network? " + "VPN will not work from the office network.\nA2. Are you sure about your " + "login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n" + "A4. When was the last time you used the company VPN?\n" + "SOURCES: 1\n\n" + "ALTERNATIVE OPTION: Another option is to run the VPN in CLI, but keep in " + "mind that DNS settings may not work and there may be a need for manual " + "modification of the local resolver or /etc/hosts and/or ~/.ssh/config " + "files to be able to connect to machines in the company. With the " + "appropriate packages installed, the only thing needed to establish " + "a connection is to run the command:\nsudo openvpn --config config.ovpn" + "\n\nWe will be asked for a username and password - provide the login " + "details, the same ones that have been used so far for VPN connection, " + "connecting to the company's WiFi, or printers (in the Warsaw office)." + "\n\nFinally, just use the VPN connection.\n" + "SOURCES: 2\n\n" + "ALTERNATIVE OPTION (for Windows): Download the" + "OpenVPN client application version 2.6 or newer from the official " + "website: https://openvpn.net/community-downloads/\n" + "SOURCES: 3", + "To diagnose the problem, please answer the following questions and send " + "them in one message to IT:\nA1. Are you connected to the office network? " + "VPN will not work from the office network.\nA2. Are you sure about your " + "login/password?\nA3. Are you using any other VPN (e.g. from a client)?\n" + "A4. When was the last time you used the company VPN?\n", + "1", + ), + ], +) +def test_spliting_answer_into_answer_and_sources( + text: str, answer: str, sources: str +) -> None: + qa_chain = QAWithSourcesChain.from_llm(FakeLLM()) + generated_answer, generated_sources = qa_chain._split_sources(text) + assert generated_answer == answer + assert generated_sources == sources