mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-31 07:41:40 +00:00 
			
		
		
		
	Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
		
			
				
	
	
		
			200 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			200 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| {
 | |
|  "cells": [
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "593f7553-7038-498e-96d4-8255e5ce34f0",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "# Creating a custom Chain\n",
 | |
|     "\n",
 | |
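|     "To implement your own custom chain, subclass `Chain` and implement the following methods: the `input_keys` and `output_keys` properties, which declare the keys the chain reads and returns, and `_call`, which contains the chain's logic. Optionally, implement `_acall` to support async calls and `_chain_type` to give the chain a type name, as in the example below."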
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 11,
 | |
|    "id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
 | |
|    "metadata": {
 | |
|     "tags": [],
 | |
|     "vscode": {
 | |
|      "languageId": "python"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "from __future__ import annotations\n",
 | |
|     "\n",
 | |
|     "from typing import Any, Dict, List, Optional\n",
 | |
|     "\n",
 | |
|     "from pydantic import Extra\n",
 | |
|     "\n",
 | |
|     "from langchain.base_language import BaseLanguageModel\n",
 | |
|     "from langchain.callbacks.manager import (\n",
 | |
|     "    AsyncCallbackManagerForChainRun,\n",
 | |
|     "    CallbackManagerForChainRun,\n",
 | |
|     ")\n",
 | |
|     "from langchain.chains.base import Chain\n",
 | |
|     "from langchain.prompts.base import BasePromptTemplate\n",
 | |
|     "\n",
 | |
|     "\n",
 | |
|     "class MyCustomChain(Chain):\n",
 | |
|     "    \"\"\"\n",
 | |
|     "    An example of a custom chain.\n",
 | |
|     "    \"\"\"\n",
 | |
|     "\n",
 | |
|     "    prompt: BasePromptTemplate\n",
 | |
|     "    \"\"\"Prompt object to use.\"\"\"\n",
 | |
|     "    llm: BaseLanguageModel\n",
 | |
|     "    output_key: str = \"text\"  #: :meta private:\n",
 | |
|     "\n",
 | |
|     "    class Config:\n",
 | |
|     "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
 | |
|     "\n",
 | |
|     "        extra = Extra.forbid\n",
 | |
|     "        arbitrary_types_allowed = True\n",
 | |
|     "\n",
 | |
|     "    @property\n",
 | |
|     "    def input_keys(self) -> List[str]:\n",
 | |
|     "        \"\"\"Will be whatever keys the prompt expects.\n",
 | |
|     "\n",
 | |
|     "        :meta private:\n",
 | |
|     "        \"\"\"\n",
 | |
|     "        return self.prompt.input_variables\n",
 | |
|     "\n",
 | |
|     "    @property\n",
 | |
|     "    def output_keys(self) -> List[str]:\n",
 | |
|     "        \"\"\"Will always return text key.\n",
 | |
|     "\n",
 | |
|     "        :meta private:\n",
 | |
|     "        \"\"\"\n",
 | |
|     "        return [self.output_key]\n",
 | |
|     "\n",
 | |
|     "    def _call(\n",
 | |
|     "        self,\n",
 | |
|     "        inputs: Dict[str, Any],\n",
 | |
|     "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
 | |
|     "    ) -> Dict[str, str]:\n",
 | |
|     "        # Your custom chain logic goes here\n",
 | |
|     "        # This is just an example that mimics LLMChain\n",
 | |
|     "        prompt_value = self.prompt.format_prompt(**inputs)\n",
 | |
|     "        \n",
 | |
|     "        # Whenever you call a language model, or another chain, you should pass\n",
 | |
|     "        # a callback manager to it. This allows the inner run to be tracked by\n",
 | |
|     "        # any callbacks that are registered on the outer run.\n",
 | |
|     "        # You can always obtain a callback manager for this by calling\n",
 | |
|     "        # `run_manager.get_child()` as shown below.\n",
 | |
|     "        response = self.llm.generate_prompt(\n",
 | |
|     "            [prompt_value],\n",
 | |
|     "            callbacks=run_manager.get_child() if run_manager else None\n",
 | |
|     "        )\n",
 | |
|     "\n",
 | |
|     "        # If you want to log something about this run, you can do so by calling\n",
 | |
|     "        # methods on the `run_manager`, as shown below. This will trigger any\n",
 | |
|     "        # callbacks that are registered for that event.\n",
 | |
|     "        if run_manager:\n",
 | |
|     "            run_manager.on_text(\"Log something about this run\")\n",
 | |
|     "        \n",
 | |
|     "        return {self.output_key: response.generations[0][0].text}\n",
 | |
|     "\n",
 | |
|     "    async def _acall(\n",
 | |
|     "        self,\n",
 | |
|     "        inputs: Dict[str, Any],\n",
 | |
|     "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
 | |
|     "    ) -> Dict[str, str]:\n",
 | |
|     "        # Your custom chain logic goes here\n",
 | |
|     "        # This is just an example that mimics LLMChain\n",
 | |
|     "        prompt_value = self.prompt.format_prompt(**inputs)\n",
 | |
|     "        \n",
 | |
|     "        # Whenever you call a language model, or another chain, you should pass\n",
 | |
|     "        # a callback manager to it. This allows the inner run to be tracked by\n",
 | |
|     "        # any callbacks that are registered on the outer run.\n",
 | |
|     "        # You can always obtain a callback manager for this by calling\n",
 | |
|     "        # `run_manager.get_child()` as shown below.\n",
 | |
|     "        response = await self.llm.agenerate_prompt(\n",
 | |
|     "            [prompt_value],\n",
 | |
|     "            callbacks=run_manager.get_child() if run_manager else None\n",
 | |
|     "        )\n",
 | |
|     "\n",
 | |
|     "        # If you want to log something about this run, you can do so by calling\n",
 | |
|     "        # methods on the `run_manager`, as shown below. This will trigger any\n",
 | |
|     "        # callbacks that are registered for that event.\n",
 | |
|     "        if run_manager:\n",
 | |
|     "            await run_manager.on_text(\"Log something about this run\")\n",
 | |
|     "        \n",
 | |
|     "        return {self.output_key: response.generations[0][0].text}\n",
 | |
|     "\n",
 | |
|     "    @property\n",
 | |
|     "    def _chain_type(self) -> str:\n",
 | |
|     "        return \"my_custom_chain\"\n"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 12,
 | |
|    "id": "18361f89",
 | |
|    "metadata": {
 | |
|     "vscode": {
 | |
|      "languageId": "python"
 | |
|     }
 | |
|    },
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "\n",
 | |
|       "\n",
 | |
|       "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
 | |
|       "Log something about this run\n",
 | |
|       "\u001b[1m> Finished chain.\u001b[0m\n"
 | |
|      ]
 | |
|     },
 | |
|     {
 | |
|      "data": {
 | |
|       "text/plain": [
 | |
|        "'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
 | |
|       ]
 | |
|      },
 | |
|      "execution_count": 12,
 | |
|      "metadata": {},
 | |
|      "output_type": "execute_result"
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
 | |
|     "from langchain.chat_models.openai import ChatOpenAI\n",
 | |
|     "from langchain.prompts.prompt import PromptTemplate\n",
 | |
|     "\n",
 | |
|     "\n",
 | |
|     "chain = MyCustomChain(\n",
 | |
|     "    prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n",
 | |
|     "    llm=ChatOpenAI()\n",
 | |
|     ")\n",
 | |
|     "\n",
 | |
|     "chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])"
 | |
|    ]
 | |
|   }
 | |
|  ],
 | |
|  "metadata": {
 | |
|   "kernelspec": {
 | |
|    "display_name": "Python 3 (ipykernel)",
 | |
|    "language": "python",
 | |
|    "name": "python3"
 | |
|   },
 | |
|   "language_info": {
 | |
|    "codemirror_mode": {
 | |
|     "name": "ipython",
 | |
|     "version": 3
 | |
|    },
 | |
|    "file_extension": ".py",
 | |
|    "mimetype": "text/x-python",
 | |
|    "name": "python",
 | |
|    "nbconvert_exporter": "python",
 | |
|    "pygments_lexer": "ipython3",
 | |
|    "version": "3.10.9"
 | |
|   }
 | |
|  },
 | |
|  "nbformat": 4,
 | |
|  "nbformat_minor": 5
 | |
| }
 |