mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-18 13:31:36 +00:00
add vertex prod features (#10910)
- chat vertex async - vertex stream - vertex full generation info - vertex use server-side stopping - model garden async - update docs for all the above in follow up will add [] chat vertex full generation info [] chat vertex retries [] scheduled tests
This commit is contained in:
parent
dccc20b402
commit
cab55e9bc1
@ -5,7 +5,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Google Cloud Platform Vertex AI PaLM \n",
|
||||
"# GCP Vertex AI \n",
|
||||
"\n",
|
||||
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
|
||||
"\n",
|
||||
@ -31,7 +31,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install google-cloud-aiplatform"
|
||||
"#!pip install langchain google-cloud-aiplatform"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -41,12 +41,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatVertexAI\n",
|
||||
"from langchain.prompts.chat import (\n",
|
||||
" ChatPromptTemplate,\n",
|
||||
" SystemMessagePromptTemplate,\n",
|
||||
" HumanMessagePromptTemplate,\n",
|
||||
")\n",
|
||||
"from langchain.schema import HumanMessage, SystemMessage"
|
||||
"from langchain.prompts import ChatPromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -60,82 +55,78 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"system = \"You are a helpful assistant who translate English to French\"\n",
|
||||
"human = \"Translate this sentence from English to French. I love programming.\"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [(\"system\", system), (\"human\", human)]\n",
|
||||
")\n",
|
||||
"messages = prompt.format_messages()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Sure, here is the translation of the sentence \"I love programming\" from English to French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
|
||||
"AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" SystemMessage(\n",
|
||||
" content=\"You are a helpful assistant that translates English to French.\"\n",
|
||||
" ),\n",
|
||||
" HumanMessage(\n",
|
||||
" content=\"Translate this sentence from English to French. I love programming.\"\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"chat(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n",
|
||||
"\n",
|
||||
"For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
|
||||
"If we want to construct a simple chain that takes user specified parameters:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = (\n",
|
||||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
|
||||
")\n",
|
||||
"system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
|
||||
"human_template = \"{text}\"\n",
|
||||
"human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
|
||||
"system = \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
|
||||
"human = \"{text}\"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [(\"system\", system), (\"human\", human)]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Sure, here is the translation of \"I love programming\" in French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
|
||||
"AIMessage(content=' 私はプログラミングが大好きです。', additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chat_prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [system_message_prompt, human_message_prompt]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# get a chat completion from the formatted messages\n",
|
||||
"chat(\n",
|
||||
" chat_prompt.format_prompt(\n",
|
||||
" input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n",
|
||||
" ).to_messages()\n",
|
||||
"chain = prompt | chat\n",
|
||||
"chain.invoke(\n",
|
||||
" {\"input_language\": \"English\", \"output_language\": \"Japanese\", \"text\": \"I love programming\"}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -153,60 +144,129 @@
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"## Code generation chat models\n",
|
||||
"You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
|
||||
"- codechat-bison: for code assistance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-06-17T21:30:43.974841Z",
|
||||
"iopub.status.busy": "2023-06-17T21:30:43.974431Z",
|
||||
"iopub.status.idle": "2023-06-17T21:30:44.248119Z",
|
||||
"shell.execute_reply": "2023-06-17T21:30:44.247362Z",
|
||||
"shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chat = ChatVertexAI(model_name=\"codechat-bison\")"
|
||||
"chat = ChatVertexAI(\n",
|
||||
" model_name=\"codechat-bison\",\n",
|
||||
" max_output_tokens=1000,\n",
|
||||
" temperature=0.5\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-06-17T21:30:45.146093Z",
|
||||
"iopub.status.busy": "2023-06-17T21:30:45.145752Z",
|
||||
"iopub.status.idle": "2023-06-17T21:30:47.449126Z",
|
||||
"shell.execute_reply": "2023-06-17T21:30:47.448609Z",
|
||||
"shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ```python\n",
|
||||
"def is_prime(x): \n",
|
||||
" if (x <= 1): \n",
|
||||
" return False\n",
|
||||
" for i in range(2, x): \n",
|
||||
" if (x % i == 0): \n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"```\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# For simple string in string out usage, we can use the `predict` method:\n",
|
||||
"print(chat.predict(\"Write a Python function to identify all prime numbers\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Asynchronous calls\n",
|
||||
"\n",
|
||||
"We can make asynchronous calls via the `agenerate` and `ainvoke` methods."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"# import nest_asyncio\n",
|
||||
"# nest_asyncio.apply()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
|
||||
"LLMResult(generations=[[ChatGeneration(text=\" J'aime la programmation.\", generation_info=None, message=AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('223599ef-38f8-4c79-ac6d-a5013060eb9d'))])"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" HumanMessage(\n",
|
||||
" content=\"How do I create a python function to identify all prime numbers?\"\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"chat(messages)"
|
||||
"chat = ChatVertexAI(\n",
|
||||
" model_name=\"chat-bison\",\n",
|
||||
" max_output_tokens=1000,\n",
|
||||
" temperature=0.7,\n",
|
||||
" top_p=0.95,\n",
|
||||
" top_k=40,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"asyncio.run(chat.agenerate([messages]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=' अहं प्रोग्रामिंग प्रेमामि', additional_kwargs={}, example=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"asyncio.run(chain.ainvoke({\"input_language\": \"English\", \"output_language\": \"Sanskrit\", \"text\": \"I love programming\"}))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Streaming calls\n",
|
||||
"\n",
|
||||
"We can also stream outputs via the `stream` method:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -214,14 +274,51 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"import sys"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 1. China (1,444,216,107)\n",
|
||||
"2. India (1,393,409,038)\n",
|
||||
"3. United States (332,403,650)\n",
|
||||
"4. Indonesia (273,523,615)\n",
|
||||
"5. Pakistan (220,892,340)\n",
|
||||
"6. Brazil (212,559,409)\n",
|
||||
"7. Nigeria (206,139,589)\n",
|
||||
"8. Bangladesh (164,689,383)\n",
|
||||
"9. Russia (145,934,462)\n",
|
||||
"10. Mexico (128,932,488)\n",
|
||||
"11. Japan (126,476,461)\n",
|
||||
"12. Ethiopia (115,063,982)\n",
|
||||
"13. Philippines (109,581,078)\n",
|
||||
"14. Egypt (102,334,404)\n",
|
||||
"15. Vietnam (97,338,589)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"prompt = ChatPromptTemplate.from_messages([(\"human\", \"List out the 15 most populous countries in the world\")])\n",
|
||||
"messages = prompt.format_messages()\n",
|
||||
"for chunk in chat.stream(messages):\n",
|
||||
" sys.stdout.write(chunk.content)\n",
|
||||
" sys.stdout.flush()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "poetry-venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "poetry-venv"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
@ -26,7 +26,7 @@ ChatLiteLLM|✅|✅|✅|✅
|
||||
ChatMLflowAIGateway|✅|❌|❌|❌
|
||||
ChatOllama|✅|❌|✅|❌
|
||||
ChatOpenAI|✅|✅|✅|✅
|
||||
ChatVertexAI|✅|❌|✅|❌
|
||||
ChatVertexAI|✅|✅|✅|❌
|
||||
ErnieBotChat|✅|❌|❌|❌
|
||||
JinaChat|✅|✅|✅|✅
|
||||
MiniMaxChat|✅|✅|❌|❌
|
||||
|
@ -4,7 +4,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Google Vertex AI PaLM \n",
|
||||
"# GCP Vertex AI\n",
|
||||
"\n",
|
||||
"**Note:** This is separate from the `Google PaLM` integration, it exposes [Vertex AI PaLM API](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview) on `Google Cloud`. \n"
|
||||
]
|
||||
@ -41,32 +41,56 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install google-cloud-aiplatform"
|
||||
"#!pip install langchain google-cloud-aiplatform"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import VertexAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Python is a widely used, interpreted, object-oriented, and high-level programming language with dynamic semantics, used for general-purpose programming. It is known for its readability, simplicity, and versatility. Here are some of the pros and cons of Python:\n",
|
||||
"\n",
|
||||
"**Pros:**\n",
|
||||
"\n",
|
||||
"- **Easy to learn:** Python is known for its simple and intuitive syntax, making it easy for beginners to learn. It has a relatively shallow learning curve compared to other programming languages.\n",
|
||||
"\n",
|
||||
"- **Versatile:** Python is a general-purpose programming language, meaning it can be used for a wide variety of tasks, including web development, data science, machine\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = VertexAI()\n",
|
||||
"print(llm(\"What are some of the pros and cons of Python as a programming language?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Question-answering example"
|
||||
"## Using in a chain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
|
||||
"from langchain.prompts import PromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -78,17 +102,7 @@
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = VertexAI()"
|
||||
"prompt = PromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -97,29 +111,26 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||
"chain = prompt | llm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Justin Bieber was born on March 1, 1994. The Super Bowl in 1994 was won by the San Francisco 49ers.\\nThe final answer: San Francisco 49ers.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Justin Bieber was born on March 1, 1994. Bill Clinton was the president of the United States from January 20, 1993, to January 20, 2001.\n",
|
||||
"The final answer is Bill Clinton\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
|
||||
"\n",
|
||||
"llm_chain.run(question)"
|
||||
"question = \"Who was the president in the year Justin Beiber was born?\"\n",
|
||||
"print(chain.invoke({\"question\": question}))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -140,78 +151,200 @@
|
||||
"- `code-gecko`: for code completion"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-06-17T21:16:53.149438Z",
|
||||
"iopub.status.busy": "2023-06-17T21:16:53.149065Z",
|
||||
"iopub.status.idle": "2023-06-17T21:16:53.421824Z",
|
||||
"shell.execute_reply": "2023-06-17T21:16:53.421136Z",
|
||||
"shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = VertexAI(model_name=\"code-bison\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-06-17T21:17:11.179077Z",
|
||||
"iopub.status.busy": "2023-06-17T21:17:11.178686Z",
|
||||
"iopub.status.idle": "2023-06-17T21:17:11.182499Z",
|
||||
"shell.execute_reply": "2023-06-17T21:17:11.181895Z",
|
||||
"shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-06-17T21:18:47.024785Z",
|
||||
"iopub.status.busy": "2023-06-17T21:18:47.024230Z",
|
||||
"iopub.status.idle": "2023-06-17T21:18:49.352249Z",
|
||||
"shell.execute_reply": "2023-06-17T21:18:49.351695Z",
|
||||
"shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = VertexAI(model_name=\"code-bison\", max_output_tokens=1000, temperature=0.3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"Write a python function that checks if a string is a valid email address\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"```python\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"def is_valid_email(email):\n",
|
||||
" pattern = re.compile(r\"[^@]+@[^@]+\\.[^@]+\")\n",
|
||||
" return pattern.match(email)\n",
|
||||
"```\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"question = \"Write a python function that identifies if the number is a prime number?\"\n",
|
||||
"\n",
|
||||
"llm_chain.run(question)"
|
||||
"print(llm(question))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using models deployed on Vertex Model Garden"
|
||||
"## Full generation info\n",
|
||||
"\n",
|
||||
"We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just text completions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]]"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = llm.generate([question])\n",
|
||||
"result.generations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Asynchronous calls\n",
|
||||
"\n",
|
||||
"With `agenerate` we can make asynchronous calls"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If running in a Jupyter notebook you'll need to install nest_asyncio\n",
|
||||
"\n",
|
||||
"# !pip install nest_asyncio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import asyncio\n",
|
||||
"# import nest_asyncio\n",
|
||||
"# nest_asyncio.apply()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LLMResult(generations=[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]], llm_output=None, run=[RunInfo(run_id=UUID('caf74e91-aefb-48ac-8031-0c505fcbbcc6'))])"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"asyncio.run(llm.agenerate([question]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Streaming calls\n",
|
||||
"\n",
|
||||
"With `stream` we can stream results from the model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"```python\n",
|
||||
"import re\n",
|
||||
"\n",
|
||||
"def is_valid_email(email):\n",
|
||||
" \"\"\"\n",
|
||||
" Checks if a string is a valid email address.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" email: The string to check.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" True if the string is a valid email address, False otherwise.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Check for a valid email address format.\n",
|
||||
" if not re.match(r\"^[A-Za-z0-9\\.\\+_-]+@[A-Za-z0-9\\._-]+\\.[a-zA-Z]*$\", email):\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" # Check if the domain name exists.\n",
|
||||
" try:\n",
|
||||
" domain = email.split(\"@\")[1]\n",
|
||||
" socket.gethostbyname(domain)\n",
|
||||
" except socket.gaierror:\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" return True\n",
|
||||
"```"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in llm.stream(question):\n",
|
||||
" sys.stdout.write(chunk)\n",
|
||||
" sys.stdout.flush()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Vertex Model Garden"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -248,7 +381,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm(\"What is the meaning of life?\")"
|
||||
"print(llm(\"What is the meaning of life?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -264,8 +397,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate.from_template(\"What is the meaning of {thing}?\")"
|
||||
]
|
||||
},
|
||||
@ -275,9 +406,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_oss_chain = prompt | llm\n",
|
||||
"\n",
|
||||
"llm_oss_chain.invoke({\"thing\": \"life\"})"
|
||||
"chian = prompt | llm\n",
|
||||
"print(chain.invoke({\"thing\": \"life\"}))"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -83,8 +83,8 @@ TitanTakeoff|✅|❌|✅|❌|❌|❌
|
||||
Tongyi|✅|❌|❌|❌|❌|❌
|
||||
VLLM|✅|❌|❌|❌|✅|❌
|
||||
VLLMOpenAI|✅|✅|✅|✅|✅|✅
|
||||
VertexAI|✅|✅|❌|❌|❌|❌
|
||||
VertexAIModelGarden|✅|✅|❌|❌|❌|❌
|
||||
VertexAI|✅|✅|✅|❌|✅|✅
|
||||
VertexAIModelGarden|✅|✅|❌|❌|✅|✅
|
||||
Writer|✅|❌|❌|❌|❌|❌
|
||||
Xinference|✅|❌|❌|❌|❌|❌
|
||||
|
||||
|
@ -2,6 +2,35 @@
|
||||
|
||||
All functionality related to Google Platform
|
||||
|
||||
## LLMs
|
||||
|
||||
### Vertex AI
|
||||
|
||||
Access PaLM LLMs like `text-bison` and `code-bison` via Google Cloud.
|
||||
|
||||
```python
|
||||
from langchain.llms import VertexAI
|
||||
```
|
||||
|
||||
### Model Garden
|
||||
|
||||
Access PaLM and hundreds of OSS models via Vertex AI Model Garden.
|
||||
|
||||
```python
|
||||
from langchain.llms import VertexAIModelGarden
|
||||
```
|
||||
|
||||
## Chat models
|
||||
|
||||
### Vertex AI
|
||||
|
||||
Access PaLM chat models like `chat-bison` and `codechat-bison` via Google Cloud.
|
||||
|
||||
```python
|
||||
from langchain.chat_models import ChatVertexAI
|
||||
```
|
||||
|
||||
|
||||
## Document Loader
|
||||
### Google BigQuery
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
"""Wrapper around Google VertexAI chat-based models."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
|
||||
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
@ -30,6 +34,8 @@ if TYPE_CHECKING:
|
||||
InputOutputTextPair,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ChatHistory:
|
||||
@ -116,7 +122,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
"""`Vertex AI` Chat large language models API."""
|
||||
|
||||
model_name: str = "chat-bison"
|
||||
streaming: bool = False
|
||||
"Underlying model name."
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -177,6 +183,42 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
text = self._enforce_stop_words(response.text, stop)
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Asynchronously generate next turn in the conversation.
|
||||
|
||||
Args:
|
||||
messages: The history of the conversation as a list of messages. Code chat
|
||||
does not support context.
|
||||
stop: The list of stop words (optional).
|
||||
run_manager: The CallbackManager for LLM run, it's not used at the moment.
|
||||
|
||||
Returns:
|
||||
The ChatResult that contains outputs generated by the model.
|
||||
|
||||
Raises:
|
||||
ValueError: if the last message in the list is not from human.
|
||||
"""
|
||||
if "stream" in kwargs:
|
||||
kwargs.pop("stream")
|
||||
logger.warning("ChatVertexAI does not currently support async streaming.")
|
||||
question = _get_question(messages)
|
||||
history = _parse_chat_history(messages[:-1])
|
||||
params = {**self._default_params, **kwargs}
|
||||
examples = kwargs.get("examples", None)
|
||||
if examples:
|
||||
params["examples"] = _parse_examples(examples)
|
||||
|
||||
chat = self._start_chat(history, params)
|
||||
response = await chat.send_message_async(question.content)
|
||||
text = self._enforce_stop_words(response.text, stop)
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
|
@ -1,28 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM, create_base_retry_decorator
|
||||
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema import (
|
||||
Generation,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utilities.vertexai import (
|
||||
init_vertexai,
|
||||
raise_vertex_import_error,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.cloud.aiplatform.gapic import PredictionServiceClient
|
||||
from vertexai.language_models._language_models import _LanguageModel
|
||||
from google.cloud.aiplatform.gapic import (
|
||||
PredictionServiceAsyncClient,
|
||||
PredictionServiceClient,
|
||||
)
|
||||
from vertexai.language_models._language_models import (
|
||||
TextGenerationResponse,
|
||||
_LanguageModel,
|
||||
)
|
||||
|
||||
|
||||
def _response_to_generation(
|
||||
response: TextGenerationResponse,
|
||||
) -> GenerationChunk:
|
||||
"""Convert a stream response to a generation chunk."""
|
||||
try:
|
||||
generation_info = {
|
||||
"is_blocked": response.is_blocked,
|
||||
"safety_attributes": response.safety_attributes,
|
||||
}
|
||||
except Exception:
|
||||
generation_info = None
|
||||
return GenerationChunk(text=response.text, generation_info=generation_info)
|
||||
|
||||
|
||||
def is_codey_model(model_name: str) -> bool:
|
||||
@ -36,7 +66,13 @@ def is_codey_model(model_name: str) -> bool:
|
||||
return "code" in model_name
|
||||
|
||||
|
||||
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
|
||||
def _create_retry_decorator(
|
||||
llm: VertexAI,
|
||||
*,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
import google.api_core
|
||||
|
||||
errors = [
|
||||
@ -46,14 +82,19 @@ def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
|
||||
google.api_core.exceptions.DeadlineExceeded,
|
||||
]
|
||||
decorator = create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries # type: ignore
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
|
||||
def completion_with_retry(
|
||||
llm: VertexAI,
|
||||
*args: Any,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm)
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
@ -62,6 +103,38 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
|
||||
return _completion_with_retry(*args, **kwargs)
|
||||
|
||||
|
||||
def stream_completion_with_retry(
|
||||
llm: VertexAI,
|
||||
*args: Any,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
return llm.client.predict_streaming(*args, **kwargs)
|
||||
|
||||
return _completion_with_retry(*args, **kwargs)
|
||||
|
||||
|
||||
async def acompletion_with_retry(
|
||||
llm: VertexAI,
|
||||
*args: Any,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
return await llm.client.predict_async(*args, **kwargs)
|
||||
|
||||
return await _acompletion_with_retry(*args, **kwargs)
|
||||
|
||||
|
||||
class _VertexAIBase(BaseModel):
|
||||
project: Optional[str] = None
|
||||
"The default GCP project to use when making Vertex API calls."
|
||||
@ -110,6 +183,11 @@ class _VertexAICommon(_VertexAIBase):
|
||||
"The default custom credentials (google.auth.credentials.Credentials) to use "
|
||||
"when making API calls. If not provided, credentials will be ascertained from "
|
||||
"the environment."
|
||||
streaming: bool = False
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vertexai"
|
||||
|
||||
@property
|
||||
def is_codey_model(self) -> bool:
|
||||
@ -135,17 +213,6 @@ class _VertexAICommon(_VertexAIBase):
|
||||
"top_p": self.top_p,
|
||||
}
|
||||
|
||||
def _predict(
|
||||
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> str:
|
||||
params = {**self._default_params, **kwargs}
|
||||
res = completion_with_retry(self, prompt, **params) # type: ignore
|
||||
return self._enforce_stop_words(res.text, stop)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vertexai"
|
||||
|
||||
@classmethod
|
||||
def _try_init_vertexai(cls, values: Dict) -> None:
|
||||
allowed_params = ["project", "location", "credentials"]
|
||||
@ -154,13 +221,14 @@ class _VertexAICommon(_VertexAIBase):
|
||||
return None
|
||||
|
||||
|
||||
class VertexAI(_VertexAICommon, LLM):
|
||||
class VertexAI(_VertexAICommon, BaseLLM):
|
||||
"""Google Vertex AI large language models."""
|
||||
|
||||
model_name: str = "text-bison"
|
||||
"The name of the Vertex AI large language model."
|
||||
tuned_model_name: Optional[str] = None
|
||||
"The name of a tuned model. If provided, model_name is ignored."
|
||||
streaming: bool = False
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -191,51 +259,78 @@ class VertexAI(_VertexAICommon, LLM):
|
||||
raise_vertex_import_error()
|
||||
return values
|
||||
|
||||
def _call(
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
stop_sequences = stop or self.stop
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
|
||||
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
if should_stream:
|
||||
generation = GenerationChunk(text="")
|
||||
for chunk in self._stream(
|
||||
prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
generation += chunk
|
||||
generations.append([generation])
|
||||
else:
|
||||
res = completion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
)
|
||||
generations.append([_response_to_generation(res)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
stop_sequences = stop or self.stop
|
||||
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
res = await acompletion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
)
|
||||
generations.append([_response_to_generation(res)])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call Vertex model to get predictions based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of stop words (optional).
|
||||
run_manager: A Callbackmanager for LLM run, optional.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
return self._predict(prompt, stop, **kwargs)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call Vertex model to get predictions based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of stop words (optional).
|
||||
run_manager: A callback manager for async interaction with LLMs.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
return await asyncio.wrap_future(
|
||||
self._get_task_executor().submit(self._call, prompt, stop)
|
||||
)
|
||||
) -> Iterator[GenerationChunk]:
|
||||
stop_sequences = stop or self.stop
|
||||
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||
for stream_resp in stream_completion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
):
|
||||
chunk = _response_to_generation(stream_resp)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
chunk=chunk,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIModelGarden(_VertexAIBase, LLM):
|
||||
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
|
||||
"""Large language models served from Vertex AI Model Garden."""
|
||||
|
||||
client: "PredictionServiceClient" = None #: :meta private:
|
||||
async_client: "PredictionServiceAsyncClient" = None #: :meta private:
|
||||
endpoint_id: str
|
||||
"A name of an endpoint where the model has been deployed."
|
||||
allowed_model_args: Optional[List[str]] = None
|
||||
@ -247,7 +342,11 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in environment."""
|
||||
try:
|
||||
from google.cloud.aiplatform.gapic import PredictionServiceClient
|
||||
from google.api_core.client_options import ClientOptions
|
||||
from google.cloud.aiplatform.gapic import (
|
||||
PredictionServiceAsyncClient,
|
||||
PredictionServiceClient,
|
||||
)
|
||||
except ImportError:
|
||||
raise_vertex_import_error()
|
||||
|
||||
@ -256,38 +355,19 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
|
||||
"A GCP project should be provided to run inference on Model Garden!"
|
||||
)
|
||||
|
||||
client_options = {
|
||||
"api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
|
||||
}
|
||||
client_options = ClientOptions(
|
||||
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
|
||||
)
|
||||
values["client"] = PredictionServiceClient(client_options=client_options)
|
||||
values["async_client"] = PredictionServiceAsyncClient(
|
||||
client_options=client_options
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vertexai_model_garden"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call Vertex model to get predictions based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of stop words (optional).
|
||||
run_manager: A Callbackmanager for LLM run, optional.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
result = self._generate(
|
||||
prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return result.generations[0][0].text
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@ -331,23 +411,47 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
|
||||
)
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _acall(
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompt: str,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call Vertex model to get predictions based on the prompt.
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
try:
|
||||
from google.protobuf import json_format
|
||||
from google.protobuf.struct_pb2 import Value
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"protobuf package not found, please install it with"
|
||||
" `pip install protobuf`"
|
||||
)
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of stop words (optional).
|
||||
run_manager: A callback manager for async interaction with LLMs.
|
||||
instances = []
|
||||
for prompt in prompts:
|
||||
if self.allowed_model_args:
|
||||
instance = {
|
||||
k: v for k, v in kwargs.items() if k in self.allowed_model_args
|
||||
}
|
||||
else:
|
||||
instance = {}
|
||||
instance[self.prompt_arg] = prompt
|
||||
instances.append(instance)
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
return await asyncio.wrap_future(
|
||||
self._get_task_executor().submit(self._call, prompt, stop)
|
||||
predict_instances = [
|
||||
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
|
||||
]
|
||||
|
||||
endpoint = self.async_client.endpoint_path(
|
||||
project=self.project, location=self.location, endpoint=self.endpoint_id
|
||||
)
|
||||
response = await self.async_client.predict(
|
||||
endpoint=endpoint, instances=predict_instances
|
||||
)
|
||||
generations: List[List[Generation]] = []
|
||||
for result in response.predictions:
|
||||
generations.append(
|
||||
[Generation(text=prediction[self.result_arg]) for prediction in result]
|
||||
)
|
||||
return LLMResult(generations=generations)
|
||||
|
@ -13,6 +13,7 @@ import pytest
|
||||
|
||||
from langchain.chat_models import ChatVertexAI
|
||||
from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@ -26,10 +27,22 @@ def test_vertexai_single_call(model_name: str) -> None:
|
||||
response = model([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
assert model._llm_type == "vertexai"
|
||||
assert model._llm_type == "chat-vertexai"
|
||||
assert model.model_name == model.client._model_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertexai_agenerate() -> None:
|
||||
model = ChatVertexAI(temperature=0)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await model.agenerate([[message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
|
||||
|
||||
sync_response = model.generate([[message]])
|
||||
assert response.generations[0][0] == sync_response.generations[0][0]
|
||||
|
||||
|
||||
def test_vertexai_single_call_with_context() -> None:
|
||||
model = ChatVertexAI()
|
||||
raw_context = (
|
||||
|
@ -14,7 +14,6 @@ def test_embedding_documents() -> None:
|
||||
output = model.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
assert model._llm_type == "vertexai"
|
||||
assert model.model_name == model.client._model_id
|
||||
|
||||
|
||||
@ -40,5 +39,4 @@ def test_paginated_texts() -> None:
|
||||
output = model.embed_documents(documents)
|
||||
assert len(output) == 8
|
||||
assert len(output[0]) == 768
|
||||
assert model._llm_type == "vertexai"
|
||||
assert model.model_name == model.client._model_id
|
||||
|
@ -9,18 +9,49 @@ Your end-user credentials would be used to make the calls (make sure you've run
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.llms import VertexAI, VertexAIModelGarden
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
def test_vertex_call() -> None:
|
||||
llm = VertexAI()
|
||||
llm = VertexAI(temperature=0)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
assert llm._llm_type == "vertexai"
|
||||
assert llm.model_name == llm.client._model_id
|
||||
|
||||
|
||||
def test_vertex_generate() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
output = llm.generate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_agenerate() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
output = await llm.agenerate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
def test_vertext_stream() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_consistency() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
output = llm.generate(["Please say foo:"])
|
||||
streaming_output = llm.generate(["Please say foo:"], stream=True)
|
||||
async_output = await llm.agenerate(["Please say foo:"])
|
||||
assert output.generations[0][0].text == streaming_output.generations[0][0].text
|
||||
assert output.generations[0][0].text == async_output.generations[0][0].text
|
||||
|
||||
|
||||
def test_model_garden() -> None:
|
||||
"""In order to run this test, you should provide an endpoint name.
|
||||
|
||||
@ -37,7 +68,7 @@ def test_model_garden() -> None:
|
||||
assert llm._llm_type == "vertexai_model_garden"
|
||||
|
||||
|
||||
def test_model_garden_batch() -> None:
|
||||
def test_model_garden_generate() -> None:
|
||||
"""In order to run this test, you should provide an endpoint name.
|
||||
|
||||
Example:
|
||||
@ -47,6 +78,16 @@ def test_model_garden_batch() -> None:
|
||||
endpoint_id = os.environ["ENDPOINT_ID"]
|
||||
project = os.environ["PROJECT"]
|
||||
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
|
||||
output = llm._generate(["What is the meaning of life?", "How much is 2+2"])
|
||||
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_garden_agenerate() -> None:
|
||||
endpoint_id = os.environ["ENDPOINT_ID"]
|
||||
project = os.environ["PROJECT"]
|
||||
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
|
||||
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 2
|
||||
|
Loading…
Reference in New Issue
Block a user