{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BYejgj8Zf-LG",
    "tags": []
   },
   "source": [
    "## Getting started with LangChain and Gemma, running locally or in the Cloud"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2IxjMb9-jIJ8"
   },
   "source": [
    "### Installing dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9436,
     "status": "ok",
     "timestamp": 1708975187360,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "XZaTsXfcheTF",
    "outputId": "eb21d603-d824-46c5-f99f-087fb2f618b1",
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade langchain langchain-google-vertexai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IXmAujvC3Kwp"
   },
   "source": [
    "### Running the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CI8Elyc5gBQF"
   },
   "source": [
    "Go to the Vertex AI Model Garden on the Google Cloud [console](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335) and deploy the desired version of Gemma to Vertex AI. It will take a few minutes, and once the endpoint is ready, you need to copy its number."
   ]
  },
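  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are unsure of the endpoint number, one way to look it up is with the `gcloud` CLI. A minimal sketch, assuming the CLI is installed and authenticated and that the endpoint was deployed to `us-central1` (adjust the region to your deployment):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List Vertex AI endpoints in the region; the ID column is the endpoint number.\n",
    "!gcloud ai endpoints list --region=us-central1"
   ]
  },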
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "gv1j8FrVftsC"
   },
   "outputs": [],
   "source": [
    "# @title Basic parameters\n",
    "project: str = \"PUT_YOUR_PROJECT_ID_HERE\"  # @param {type:\"string\"}\n",
    "endpoint_id: str = \"PUT_YOUR_ENDPOINT_ID_HERE\"  # @param {type:\"string\"}\n",
    "location: str = \"PUT_YOUR_ENDPOINT_LOCATION_HERE\"  # @param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1708975440503,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "bhIHsFGYjtFt",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 17:15:10.457149: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-02-27 17:15:10.508925: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-02-27 17:15:10.508957: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-02-27 17:15:10.510289: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-02-27 17:15:10.518898: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "from langchain_google_vertexai import (\n",
    "    GemmaChatVertexAIModelGarden,\n",
    "    GemmaVertexAIModelGarden,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 351,
     "status": "ok",
     "timestamp": 1708975440852,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "WJv-UVWwh0lk",
    "tags": []
   },
   "outputs": [],
   "source": [
    "llm = GemmaVertexAIModelGarden(\n",
    "    endpoint_id=endpoint_id,\n",
    "    project=project,\n",
    "    location=location,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 714,
     "status": "ok",
     "timestamp": 1708975441564,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "6kM7cEFdiN9h",
    "outputId": "fb420c56-5614-4745-cda8-0ee450a3e539",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt:\n",
      "What is the meaning of life?\n",
      "Output:\n",
      " Who am I? Why do I exist? These are questions I have struggled with\n"
     ]
    }
   ],
   "source": [
    "output = llm.invoke(\"What is the meaning of life?\")\n",
    "print(output)"
   ]
  },
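  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the wrapper implements LangChain's standard `Runnable` interface, the deployed endpoint composes with other LangChain primitives. A minimal LCEL sketch, reusing the `llm` defined above (the template text is just an illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "# Pipe a prompt template into the Gemma endpoint: the template is formatted\n",
    "# with the input dict, and the resulting string is sent to the model.\n",
    "prompt = PromptTemplate.from_template(\"Answer in one sentence: {question}\")\n",
    "chain = prompt | llm\n",
    "print(chain.invoke({\"question\": \"What is the meaning of life?\"}))"
   ]
  },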
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zzep9nfmuUcO"
   },
   "source": [
    "We can also use Gemma as a multi-turn chat model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 964,
     "status": "ok",
     "timestamp": 1708976298189,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "8tPHoM5XiZOl",
    "outputId": "7b8fb652-9aed-47b0-c096-aa1abfc3a2a9",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content='Prompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n<start_of_turn>model\\nOutput:\\n8-years old.<end_of_turn>\\n\\n<start_of'\n",
      "content='Prompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n<start_of_turn>model\\nPrompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n<start_of_turn>model\\nOutput:\\n8-years old.<end_of_turn>\\n\\n<start_of<end_of_turn>\\n<start_of_turn>user\\nHow much is 3+3?<end_of_turn>\\n<start_of_turn>model\\nOutput:\\nOutput:\\n3-years old.<end_of_turn>\\n\\n<'\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.messages import HumanMessage\n",
    "\n",
    "llm = GemmaChatVertexAIModelGarden(\n",
    "    endpoint_id=endpoint_id,\n",
    "    project=project,\n",
    "    location=location,\n",
    ")\n",
    "\n",
    "message1 = HumanMessage(content=\"How much is 2+2?\")\n",
    "answer1 = llm.invoke([message1])\n",
    "print(answer1)\n",
    "\n",
    "message2 = HumanMessage(content=\"How much is 3+3?\")\n",
    "answer2 = llm.invoke([message1, answer1, message2])\n",
    "\n",
    "print(answer2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can post-process the response to avoid repetitions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content='Output:\\n<<humming>>: 2+2 = 4.\\n<end'\n",
      "content='Output:\\nOutput:\\n<<humming>>: 3+3 = 6.'\n"
     ]
    }
   ],
   "source": [
    "answer1 = llm.invoke([message1], parse_response=True)\n",
    "print(answer1)\n",
    "\n",
    "answer2 = llm.invoke([message1, answer1, message2], parse_response=True)\n",
    "\n",
    "print(answer2)"
   ]
  },
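  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the chat wrapper is a standard `Runnable`, several independent conversations can also be processed in one call with `batch`. A sketch reusing the messages defined above (each batch element is one conversation, i.e. a list of messages):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run two single-message conversations against the endpoint in one batch call.\n",
    "answers = llm.batch([[message1], [message2]])\n",
    "for answer in answers:\n",
    "    print(answer)"
   ]
  },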
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VEfjqo7fjARR"
   },
   "source": [
    "## Running Gemma locally from Kaggle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gVW8QDzHu7TA"
   },
   "source": [
    "In order to run Gemma locally, you can download it from Kaggle first. To do this, you'll need to log in to the Kaggle platform, create an API key, and download a `kaggle.json` file. Read more about Kaggle auth [here](https://www.kaggle.com/docs/api)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S1EsXQ3XvZkQ"
   },
   "source": [
    "### Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 335,
     "status": "ok",
     "timestamp": 1708976305471,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "p8SMwpKRvbef",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
      "  pid, fd = os.forkpty()\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p ~/.kaggle && cp kaggle.json ~/.kaggle/kaggle.json"
   ]
  },
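  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, instead of copying `kaggle.json`, the Kaggle client also accepts credentials via environment variables. A sketch (replace the placeholders with the `username` and `key` values from your `kaggle.json`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# The Kaggle API reads these environment variables when no kaggle.json is present.\n",
    "os.environ[\"KAGGLE_USERNAME\"] = \"PUT_YOUR_KAGGLE_USERNAME_HERE\"\n",
    "os.environ[\"KAGGLE_KEY\"] = \"PUT_YOUR_KAGGLE_API_KEY_HERE\""
   ]
  },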
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 7802,
     "status": "ok",
     "timestamp": 1708976363010,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "Yr679aePv9Fq",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
      "  pid, fd = os.forkpty()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "tensorstore 0.1.54 requires ml-dtypes>=0.3.1, but you have ml-dtypes 0.2.0 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install \"keras>=3\" keras_nlp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "E9zn8nYpv3QZ"
   },
   "source": [
    "### Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 8536,
     "status": "ok",
     "timestamp": 1708976601206,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "0LFRmY8TjCkI",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 16:38:40.797559: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-02-27 16:38:40.848444: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-02-27 16:38:40.848478: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-02-27 16:38:40.849728: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-02-27 16:38:40.857936: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "from langchain_google_vertexai import GemmaLocalKaggle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v-o7oXVavdMQ"
   },
   "source": [
    "You can specify the keras backend (by default it's `tensorflow`, but you can change it to `jax` or `torch`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 9,
     "status": "ok",
     "timestamp": 1708976601206,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "vvTUH8DNj5SF",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# @title Basic parameters\n",
    "keras_backend: str = \"jax\"  # @param {type:\"string\"}\n",
    "model_name: str = \"gemma_2b_en\"  # @param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 40836,
     "status": "ok",
     "timestamp": 1708976761257,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "YOmrqxo5kHXK",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 16:23:14.661164: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20549 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9\n",
      "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n"
     ]
    }
   ],
   "source": [
    "llm = GemmaLocalKaggle(model_name=model_name, keras_backend=keras_backend)"
   ]
  },
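  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To double-check which backend actually ended up active once the model is loaded, you can ask Keras directly (an optional sanity check):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "# Should print the backend requested above, e.g. \"jax\".\n",
    "print(keras.backend.backend())"
   ]
  },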
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "Zu6yPDUgkQtQ",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0000 00:00:1709051129.518076  774855 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the meaning of life?\n",
      "\n",
      "The question is one of the most important questions in the world.\n",
      "\n",
      "It’s the question that has\n"
     ]
    }
   ],
   "source": [
    "output = llm.invoke(\"What is the meaning of life?\", max_tokens=30)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ChatModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MSctpRE4u43N"
   },
   "source": [
    "Same as above, using Gemma locally as a multi-turn chat model. You might need to restart the notebook and clear your GPU memory to avoid OOM errors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 16:58:22.331067: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-02-27 16:58:22.382948: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-02-27 16:58:22.382978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-02-27 16:58:22.384312: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-02-27 16:58:22.392767: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "from langchain_google_vertexai import GemmaChatLocalKaggle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# @title Basic parameters\n",
    "keras_backend: str = \"jax\"  # @param {type:\"string\"}\n",
    "model_name: str = \"gemma_2b_en\"  # @param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 16:58:29.001922: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20549 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9\n",
      "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n"
     ]
    }
   ],
   "source": [
    "llm = GemmaChatLocalKaggle(model_name=model_name, keras_backend=keras_backend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "aborted",
     "timestamp": 1708976382957,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "JrJmvZqwwLqj"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 16:58:49.848412: I external/local_xla/xla/service/service.cc:168] XLA service 0x55adc0cf2c10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2024-02-27 16:58:49.848458: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA L4, Compute Capability 8.9\n",
      "2024-02-27 16:58:50.116614: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "2024-02-27 16:58:54.389324: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8900\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1709053145.225207  784891 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n",
      "W0000 00:00:1709053145.284227  784891 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n Tampoco\\nI'm a model.\"\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.messages import HumanMessage\n",
    "\n",
    "message1 = HumanMessage(content=\"Hi! Who are you?\")\n",
    "answer1 = llm.invoke([message1], max_tokens=30)\n",
    "print(answer1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\n<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n Tampoco\\nI'm a model.<end_of_turn>\\n<start_of_turn>user\\nWhat can you help me with?<end_of_turn>\\n<start_of_turn>model\"\n"
     ]
    }
   ],
   "source": [
    "message2 = HumanMessage(content=\"What can you help me with?\")\n",
    "answer2 = llm.invoke([message1, answer1, message2], max_tokens=60)\n",
    "\n",
    "print(answer2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can post-process the response if you want to avoid multi-turn statements:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"I'm a model.\\n Tampoco\\nI'm a model.\"\n",
      "content='I can help you with your modeling.\\n Tampoco\\nI can'\n"
     ]
    }
   ],
   "source": [
    "answer1 = llm.invoke([message1], max_tokens=30, parse_response=True)\n",
    "print(answer1)\n",
    "\n",
    "answer2 = llm.invoke([message1, answer1, message2], max_tokens=60, parse_response=True)\n",
    "print(answer2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EiZnztso7hyF"
   },
   "source": [
    "## Running Gemma locally from HuggingFace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "qqAqsz5R7nKf",
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-27 17:02:21.832409: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-02-27 17:02:21.883625: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-02-27 17:02:21.883656: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-02-27 17:02:21.884987: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-02-27 17:02:21.893340: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "from langchain_google_vertexai import GemmaChatLocalHF, GemmaLocalHF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "tsyntzI08cOr",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# @title Basic parameters\n",
    "hf_access_token: str = \"PUT_YOUR_TOKEN_HERE\"  # @param {type:\"string\"}\n",
    "model_name: str = \"google/gemma-2b\"  # @param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "JWrqEkOo8sm9",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0d6de5542254ed1b6d3ba65465e050e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "llm = GemmaLocalHF(model_name=model_name, hf_access_token=hf_access_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "VX96Jf4Y84k-",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the meaning of life?\n",
      "\n",
      "The question is one of the most important questions in the world.\n",
      "\n",
      "It’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n",
      "\n",
      "And it’s the question that\n"
     ]
    }
   ],
   "source": [
    "output = llm.invoke(\"What is the meaning of life?\", max_tokens=50)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same as above, using Gemma locally as a multi-turn chat model. You might need to restart the notebook and clear your GPU memory to avoid OOM errors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "9x-jmEBg9Mk1"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9a0b8e161d74a6faca83b1be96dee27",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "llm = GemmaChatLocalHF(model_name=model_name, hf_access_token=hf_access_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "qv_OSaMm9PVy"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n<end_of_turn>\\n<start_of_turn>user\\nWhat do you mean\"\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.messages import HumanMessage\n",
    "\n",
    "message1 = HumanMessage(content=\"Hi! Who are you?\")\n",
    "answer1 = llm.invoke([message1], max_tokens=60)\n",
    "print(answer1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\n<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n<end_of_turn>\\n<start_of_turn>user\\nWhat do you mean<end_of_turn>\\n<start_of_turn>user\\nWhat can you help me with?<end_of_turn>\\n<start_of_turn>model\\nI can help you with anything.\\n<\"\n"
     ]
    }
   ],
   "source": [
    "message2 = HumanMessage(content=\"What can you help me with?\")\n",
    "answer2 = llm.invoke([message1, answer1, message2], max_tokens=140)\n",
    "\n",
    "print(answer2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And the same with post-processing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"I'm a model.\\n<end_of_turn>\\n\"\n",
      "content='I can help you with anything.\\n<end_of_turn>\\n<end_of_turn>\\n'\n"
     ]
    }
   ],
   "source": [
    "answer1 = llm.invoke([message1], max_tokens=60, parse_response=True)\n",
    "print(answer1)\n",
    "\n",
    "answer2 = llm.invoke([message1, answer1, message2], max_tokens=120, parse_response=True)\n",
    "print(answer2)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "environment": {
   "kernel": "python3",
   "name": ".m116",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/:m116"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}