mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 11:39:03 +00:00
Fixed multi input prompt for MapReduceChain (#4979)
# Fixed multi input prompt for MapReduceChain

Added `kwargs` support for the inner chains of `MapReduceChain` via the `from_params` method.

Currently, the `from_params` way of initialising a `MapReduceChain` doesn't work if the prompt has multiple inputs. This is because it uses `StuffDocumentsChain` and `MapReduceDocumentsChain` underneath, and both of them require `document_variable_name` to be specified when the `prompt` of their `llm_chain` has more than one input. With this PR, I have added support for passing their respective `kwargs` via the `from_params` method.

## Fixes
https://github.com/hwchase17/langchain/issues/4752

## Who can review?
@dev2049 @hwchase17 @agola11

---------

Co-authored-by: imeckr <chandanroutray2012@gmail.com>
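Below is a minimal usage sketch of the new `from_params` arguments, for context while reviewing. The prompt, its variable names, and the input text are illustrative only and not part of this PR; the sketch assumes `llm` is the first argument of `from_params`, as in the existing signature.

```python
from langchain.chains import MapReduceChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter

# A multi-input prompt: "context" receives the documents, "question" is an
# extra input supplied at run time (illustrative names, not from this PR).
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Answer the question using the context below.\n"
        "Context:\n{context}\n"
        "Question: {question}\n"
        "Answer:"
    ),
)

chain = MapReduceChain.from_params(
    llm=OpenAI(),
    prompt=prompt,
    text_splitter=CharacterTextSplitter(),
    # New in this PR: tell the inner StuffDocumentsChain and
    # MapReduceDocumentsChain which prompt variable receives the documents.
    reduce_chain_kwargs={"document_variable_name": "context"},
    combine_chain_kwargs={"document_variable_name": "context"},
)

# Extra prompt inputs such as "question" are forwarded to the inner chains.
print(chain.run(input_text="...a long document...", question="What is this text about?"))
```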
This commit is contained in:
parent
a97e4252e3
commit
bc875a9df1
@@ -21,7 +21,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
 "id": "e9db25f3",
 "metadata": {},
 "outputs": [],
@@ -318,6 +318,141 @@
 "chain({\"input_documents\": docs}, return_only_outputs=True)"
 ]
 },
+{
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "b882e209",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## The custom `MapReduceChain`\n",
|
||||
"\n",
|
||||
"**Multi input prompt**\n",
|
||||
"\n",
|
||||
"You can also use prompt with multi input. In this example, we will use a MapReduce chain to answer specifc question about our code."
|
||||
]
|
||||
},
|
||||
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "f7ad9ee2",
+"metadata": {},
+"outputs": [],
+"source": [
+"from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain\n",
+"from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
+"\n",
+"map_template_string = \"\"\"Given the following python code, generate a description that explains what the code does and also mention the time complexity.\n",
+"Code:\n",
+"{code}\n",
+"\n",
+"Return the description in the following format:\n",
+"name of the function: description of the function\n",
+"\"\"\"\n",
+"\n",
+"\n",
+"reduce_template_string = \"\"\"Given the following python function names and their descriptions, answer the following question\n",
+"{code_description}\n",
+"Question: {question}\n",
+"Answer:\n",
+"\"\"\"\n",
+"\n",
+"MAP_PROMPT = PromptTemplate(input_variables=[\"code\"], template=map_template_string)\n",
+"REDUCE_PROMPT = PromptTemplate(input_variables=[\"code_description\", \"question\"], template=reduce_template_string)\n",
+"\n",
+"llm = OpenAI()\n",
+"\n",
+"map_llm_chain = LLMChain(llm=llm, prompt=MAP_PROMPT)\n",
+"reduce_llm_chain = LLMChain(llm=llm, prompt=REDUCE_PROMPT)\n",
+"\n",
+"generative_result_reduce_chain = StuffDocumentsChain(\n",
+"    llm_chain=reduce_llm_chain,\n",
+"    document_variable_name=\"code_description\",\n",
+")\n",
+"\n",
+"combine_documents = MapReduceDocumentsChain(\n",
+"    llm_chain=map_llm_chain,\n",
+"    combine_document_chain=generative_result_reduce_chain,\n",
+"    document_variable_name=\"code\",\n",
+")\n",
+"\n",
+"map_reduce = MapReduceChain(\n",
+"    combine_documents_chain=combine_documents,\n",
+"    text_splitter=CharacterTextSplitter(separator=\"\\n##\\n\", chunk_size=100, chunk_overlap=0),\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"id": "0d4caccb",
+"metadata": {},
+"outputs": [],
+"source": [
+"code = \"\"\"\n",
+"def bubblesort(list):\n",
+"    for iter_num in range(len(list)-1,0,-1):\n",
+"        for idx in range(iter_num):\n",
+"            if list[idx]>list[idx+1]:\n",
+"                temp = list[idx]\n",
+"                list[idx] = list[idx+1]\n",
+"                list[idx+1] = temp\n",
+"    return list\n",
+"##\n",
+"def insertion_sort(InputList):\n",
+"    for i in range(1, len(InputList)):\n",
+"        j = i-1\n",
+"        nxt_element = InputList[i]\n",
+"        while (InputList[j] > nxt_element) and (j >= 0):\n",
+"            InputList[j+1] = InputList[j]\n",
+"            j=j-1\n",
+"        InputList[j+1] = nxt_element\n",
+"    return InputList\n",
+"##\n",
+"def shellSort(input_list):\n",
+"    gap = len(input_list) // 2\n",
+"    while gap > 0:\n",
+"        for i in range(gap, len(input_list)):\n",
+"            temp = input_list[i]\n",
+"            j = i\n",
+"            while j >= gap and input_list[j - gap] > temp:\n",
+"                input_list[j] = input_list[j - gap]\n",
+"                j = j-gap\n",
+"            input_list[j] = temp\n",
+"        gap = gap//2\n",
+"    return input_list\n",
+"\n",
+"\"\"\""
+]
+},
+{
+"cell_type": "code",
+"execution_count": 13,
+"id": "d5a9a35b",
+"metadata": {},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"Created a chunk of size 247, which is longer than the specified 100\n",
+"Created a chunk of size 267, which is longer than the specified 100\n"
+]
+},
+{
+"data": {
+"text/plain": [
+"'shellSort has a better time complexity than both bubblesort and insertion_sort, as it has a time complexity of O(n^2), while the other two have a time complexity of O(n^2).'"
+]
+},
+"execution_count": 13,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"map_reduce.run(input_text=code, question=\"Which function has a better time complexity?\")"
+]
+},
 {
 "cell_type": "markdown",
 "id": "f61350f9",
@@ -470,7 +605,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.8.16"
 },
 "vscode": {
 "interpreter": {
@@ -5,7 +5,7 @@ then combines the results with another one.
 """
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import Extra
 
@@ -38,15 +38,22 @@ class MapReduceChain(Chain):
         prompt: BasePromptTemplate,
         text_splitter: TextSplitter,
         callbacks: Callbacks = None,
+        combine_chain_kwargs: Optional[Mapping[str, Any]] = None,
+        reduce_chain_kwargs: Optional[Mapping[str, Any]] = None,
         **kwargs: Any,
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
-        reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks)
+        reduce_chain = StuffDocumentsChain(
+            llm_chain=llm_chain,
+            callbacks=callbacks,
+            **(reduce_chain_kwargs if reduce_chain_kwargs else {}),
+        )
         combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_chain,
             combine_document_chain=reduce_chain,
             callbacks=callbacks,
+            **(combine_chain_kwargs if combine_chain_kwargs else {}),
         )
         return cls(
             combine_documents_chain=combine_documents_chain,
@@ -84,9 +91,14 @@ class MapReduceChain(Chain):
     ) -> Dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         # Split the larger text into smaller chunks.
-        texts = self.text_splitter.split_text(inputs[self.input_key])
+        doc_text = inputs.pop(self.input_key)
+        texts = self.text_splitter.split_text(doc_text)
         docs = [Document(page_content=text) for text in texts]
+        _inputs: Dict[str, Any] = {
+            **inputs,
+            self.combine_documents_chain.input_key: docs,
+        }
         outputs = self.combine_documents_chain.run(
-            input_documents=docs, callbacks=_run_manager.get_child()
+            _inputs, callbacks=_run_manager.get_child()
         )
         return {self.output_key: outputs}