mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[devops] remove post commit ci (#5566)
* [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from langchain.pydantic_v1 import Field
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
class CustomBaseRetrievalQA(BaseRetrievalQA):
|
||||
"""Base class for question-answering chains."""
|
||||
|
||||
@@ -98,7 +99,6 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
|
||||
for k, v in inputs.items()
|
||||
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
|
||||
}
|
||||
answers = []
|
||||
if self.combine_documents_chain.memory is not None:
|
||||
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
|
||||
self.combine_documents_chain.memory.buffered_history
|
||||
@@ -117,10 +117,10 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
|
||||
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
|
||||
|
||||
# if rejection_trigger_keywords is not given, return the response from LLM directly
|
||||
rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
|
||||
rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
|
||||
answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None
|
||||
if answer is None:
|
||||
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
|
||||
if answer is None:
|
||||
answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
|
||||
if self.combine_documents_chain.memory is not None:
|
||||
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
|
||||
|
||||
@@ -161,10 +161,14 @@ class CustomBaseRetrievalQA(BaseRetrievalQA):
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
|
||||
)
|
||||
# if rejection_trigger_keywords is not given, return the response from LLM directly
|
||||
rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', [])
|
||||
answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords)==0 else None
|
||||
rejection_trigger_keywords = inputs.get("rejection_trigger_keywords", [])
|
||||
answer = (
|
||||
answer
|
||||
if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords) == 0
|
||||
else None
|
||||
)
|
||||
if answer is None:
|
||||
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
|
||||
answer = inputs.get("rejection_answer", "抱歉,根据提供的信息无法回答该问题。")
|
||||
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
|
||||
|
||||
if self.return_source_documents:
|
||||
|
Reference in New Issue
Block a user