Mirror of https://github.com/hwchase17/langchain.git, synced 2025-05-19 05:51:16 +00:00
add vertex prod features (#10910)
- chat vertex async
- vertex stream
- vertex full generation info
- vertex use server-side stopping
- model garden async
- update docs for all the above

In a follow-up will add:
- [ ] chat vertex full generation info
- [ ] chat vertex retries
- [ ] scheduled tests
parent dccc20b402 · commit cab55e9bc1
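A minimal sketch of how the features added in this commit are meant to be used (streaming, full generation info via `generate`, and async calls). It assumes a GCP project with the Vertex AI API enabled and `google-cloud-aiplatform` installed; the prompts are illustrative.

```python
import asyncio

from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.schema.messages import HumanMessage

llm = VertexAI(model_name="text-bison", temperature=0)

# Streaming: chunks are printed as the model produces them.
for chunk in llm.stream("Say foo:"):
    print(chunk, end="")

# Full generation info: `generate` returns metadata such as safety attributes,
# not just the completion text.
result = llm.generate(["Say foo:"])
print(result.generations[0][0].generation_info)

# Async: the chat model and the LLM both support `agenerate` now.
chat = ChatVertexAI(model_name="chat-bison")
print(asyncio.run(chat.agenerate([[HumanMessage(content="Hello!")]])))
```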
@ -5,7 +5,7 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Google Cloud Platform Vertex AI PaLM \n",
|
"# GCP Vertex AI \n",
|
||||||
"\n",
|
"\n",
|
||||||
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
|
"Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -31,7 +31,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"#!pip install google-cloud-aiplatform"
|
"#!pip install langchain google-cloud-aiplatform"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -41,12 +41,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.chat_models import ChatVertexAI\n",
|
"from langchain.chat_models import ChatVertexAI\n",
|
||||||
"from langchain.prompts.chat import (\n",
|
"from langchain.prompts import ChatPromptTemplate"
|
||||||
" ChatPromptTemplate,\n",
|
|
||||||
" SystemMessagePromptTemplate,\n",
|
|
||||||
" HumanMessagePromptTemplate,\n",
|
|
||||||
")\n",
|
|
||||||
"from langchain.schema import HumanMessage, SystemMessage"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -60,82 +55,78 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 34,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"system = \"You are a helpful assistant who translate English to French\"\n",
|
||||||
|
"human = \"Translate this sentence from English to French. I love programming.\"\n",
|
||||||
|
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||||
|
" [(\"system\", system), (\"human\", human)]\n",
|
||||||
|
")\n",
|
||||||
|
"messages = prompt.format_messages()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content='Sure, here is the translation of the sentence \"I love programming\" from English to French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
|
"AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"messages = [\n",
|
|
||||||
" SystemMessage(\n",
|
|
||||||
" content=\"You are a helpful assistant that translates English to French.\"\n",
|
|
||||||
" ),\n",
|
|
||||||
" HumanMessage(\n",
|
|
||||||
" content=\"Translate this sentence from English to French. I love programming.\"\n",
|
|
||||||
" ),\n",
|
|
||||||
"]\n",
|
|
||||||
"chat(messages)"
|
"chat(messages)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n",
|
"If we want to construct a simple chain that takes user specified parameters:"
|
||||||
"\n",
|
|
||||||
"For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"template = (\n",
|
"system = \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
|
||||||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
|
"human = \"{text}\"\n",
|
||||||
")\n",
|
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||||
"system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
|
" [(\"system\", system), (\"human\", human)]\n",
|
||||||
"human_template = \"{text}\"\n",
|
")"
|
||||||
"human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content='Sure, here is the translation of \"I love programming\" in French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)"
|
"AIMessage(content=' 私はプログラミングが大好きです。', additional_kwargs={}, example=False)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 7,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"chat_prompt = ChatPromptTemplate.from_messages(\n",
|
"chain = prompt | chat\n",
|
||||||
" [system_message_prompt, human_message_prompt]\n",
|
"chain.invoke(\n",
|
||||||
")\n",
|
" {\"input_language\": \"English\", \"output_language\": \"Japanese\", \"text\": \"I love programming\"}\n",
|
||||||
"\n",
|
|
||||||
"# get a chat completion from the formatted messages\n",
|
|
||||||
"chat(\n",
|
|
||||||
" chat_prompt.format_prompt(\n",
|
|
||||||
" input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n",
|
|
||||||
" ).to_messages()\n",
|
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -153,60 +144,129 @@
|
|||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
|
"## Code generation chat models\n",
|
||||||
"You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
|
"You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
|
||||||
"- codechat-bison: for code assistance"
|
"- codechat-bison: for code assistance"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 18,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"execution": {
|
|
||||||
"iopub.execute_input": "2023-06-17T21:30:43.974841Z",
|
|
||||||
"iopub.status.busy": "2023-06-17T21:30:43.974431Z",
|
|
||||||
"iopub.status.idle": "2023-06-17T21:30:44.248119Z",
|
|
||||||
"shell.execute_reply": "2023-06-17T21:30:44.247362Z",
|
|
||||||
"shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
|
|
||||||
},
|
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"chat = ChatVertexAI(model_name=\"codechat-bison\")"
|
"chat = ChatVertexAI(\n",
|
||||||
|
" model_name=\"codechat-bison\",\n",
|
||||||
|
" max_output_tokens=1000,\n",
|
||||||
|
" temperature=0.5\n",
|
||||||
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 20,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"execution": {
|
|
||||||
"iopub.execute_input": "2023-06-17T21:30:45.146093Z",
|
|
||||||
"iopub.status.busy": "2023-06-17T21:30:45.145752Z",
|
|
||||||
"iopub.status.idle": "2023-06-17T21:30:47.449126Z",
|
|
||||||
"shell.execute_reply": "2023-06-17T21:30:47.448609Z",
|
|
||||||
"shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
|
|
||||||
},
|
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" ```python\n",
|
||||||
|
"def is_prime(x): \n",
|
||||||
|
" if (x <= 1): \n",
|
||||||
|
" return False\n",
|
||||||
|
" for i in range(2, x): \n",
|
||||||
|
" if (x % i == 0): \n",
|
||||||
|
" return False\n",
|
||||||
|
" return True\n",
|
||||||
|
"```\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# For simple string in string out usage, we can use the `predict` method:\n",
|
||||||
|
"print(chat.predict(\"Write a Python function to identify all prime numbers\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Asynchronous calls\n",
|
||||||
|
"\n",
|
||||||
|
"We can make asynchronous calls via the `agenerate` and `ainvoke` methods."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import asyncio\n",
|
||||||
|
"# import nest_asyncio\n",
|
||||||
|
"# nest_asyncio.apply()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 35,
|
||||||
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
|
"LLMResult(generations=[[ChatGeneration(text=\" J'aime la programmation.\", generation_info=None, message=AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('223599ef-38f8-4c79-ac6d-a5013060eb9d'))])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 35,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"messages = [\n",
|
"chat = ChatVertexAI(\n",
|
||||||
" HumanMessage(\n",
|
" model_name=\"chat-bison\",\n",
|
||||||
" content=\"How do I create a python function to identify all prime numbers?\"\n",
|
" max_output_tokens=1000,\n",
|
||||||
" )\n",
|
" temperature=0.7,\n",
|
||||||
"]\n",
|
" top_p=0.95,\n",
|
||||||
"chat(messages)"
|
" top_k=40,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"asyncio.run(chat.agenerate([messages]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 36,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content=' अहं प्रोग्रामिंग प्रेमामि', additional_kwargs={}, example=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 36,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"asyncio.run(chain.ainvoke({\"input_language\": \"English\", \"output_language\": \"Sanskrit\", \"text\": \"I love programming\"}))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Streaming calls\n",
|
||||||
|
"\n",
|
||||||
|
"We can also stream outputs via the `stream` method:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -214,14 +274,51 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": [
|
||||||
|
"import sys"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 32,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 1. China (1,444,216,107)\n",
|
||||||
|
"2. India (1,393,409,038)\n",
|
||||||
|
"3. United States (332,403,650)\n",
|
||||||
|
"4. Indonesia (273,523,615)\n",
|
||||||
|
"5. Pakistan (220,892,340)\n",
|
||||||
|
"6. Brazil (212,559,409)\n",
|
||||||
|
"7. Nigeria (206,139,589)\n",
|
||||||
|
"8. Bangladesh (164,689,383)\n",
|
||||||
|
"9. Russia (145,934,462)\n",
|
||||||
|
"10. Mexico (128,932,488)\n",
|
||||||
|
"11. Japan (126,476,461)\n",
|
||||||
|
"12. Ethiopia (115,063,982)\n",
|
||||||
|
"13. Philippines (109,581,078)\n",
|
||||||
|
"14. Egypt (102,334,404)\n",
|
||||||
|
"15. Vietnam (97,338,589)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"prompt = ChatPromptTemplate.from_messages([(\"human\", \"List out the 15 most populous countries in the world\")])\n",
|
||||||
|
"messages = prompt.format_messages()\n",
|
||||||
|
"for chunk in chat.stream(messages):\n",
|
||||||
|
" sys.stdout.write(chunk.content)\n",
|
||||||
|
" sys.stdout.flush()"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "poetry-venv",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "poetry-venv"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
@ -26,7 +26,7 @@ ChatLiteLLM|✅|✅|✅|✅
|
|||||||
ChatMLflowAIGateway|✅|❌|❌|❌
|
ChatMLflowAIGateway|✅|❌|❌|❌
|
||||||
ChatOllama|✅|❌|✅|❌
|
ChatOllama|✅|❌|✅|❌
|
||||||
ChatOpenAI|✅|✅|✅|✅
|
ChatOpenAI|✅|✅|✅|✅
|
||||||
ChatVertexAI|✅|❌|✅|❌
|
ChatVertexAI|✅|✅|✅|❌
|
||||||
ErnieBotChat|✅|❌|❌|❌
|
ErnieBotChat|✅|❌|❌|❌
|
||||||
JinaChat|✅|✅|✅|✅
|
JinaChat|✅|✅|✅|✅
|
||||||
MiniMaxChat|✅|✅|❌|❌
|
MiniMaxChat|✅|✅|❌|❌
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Google Vertex AI PaLM \n",
|
"# GCP Vertex AI\n",
|
||||||
"\n",
|
"\n",
|
||||||
"**Note:** This is separate from the `Google PaLM` integration, it exposes [Vertex AI PaLM API](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview) on `Google Cloud`. \n"
|
"**Note:** This is separate from the `Google PaLM` integration, it exposes [Vertex AI PaLM API](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview) on `Google Cloud`. \n"
|
||||||
]
|
]
|
||||||
@ -41,32 +41,56 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"#!pip install google-cloud-aiplatform"
|
"#!pip install langchain google-cloud-aiplatform"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.llms import VertexAI"
|
"from langchain.llms import VertexAI"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" Python is a widely used, interpreted, object-oriented, and high-level programming language with dynamic semantics, used for general-purpose programming. It is known for its readability, simplicity, and versatility. Here are some of the pros and cons of Python:\n",
|
||||||
|
"\n",
|
||||||
|
"**Pros:**\n",
|
||||||
|
"\n",
|
||||||
|
"- **Easy to learn:** Python is known for its simple and intuitive syntax, making it easy for beginners to learn. It has a relatively shallow learning curve compared to other programming languages.\n",
|
||||||
|
"\n",
|
||||||
|
"- **Versatile:** Python is a general-purpose programming language, meaning it can be used for a wide variety of tasks, including web development, data science, machine\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"llm = VertexAI()\n",
|
||||||
|
"print(llm(\"What are some of the pros and cons of Python as a programming language?\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Question-answering example"
|
"## Using in a chain"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain"
|
"from langchain.prompts import PromptTemplate"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -78,17 +102,7 @@
|
|||||||
"template = \"\"\"Question: {question}\n",
|
"template = \"\"\"Question: {question}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Answer: Let's think step by step.\"\"\"\n",
|
"Answer: Let's think step by step.\"\"\"\n",
|
||||||
"\n",
|
"prompt = PromptTemplate.from_template(template)"
|
||||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"llm = VertexAI()"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -97,29 +111,26 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
"chain = prompt | llm"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"name": "stdout",
|
||||||
"text/plain": [
|
"output_type": "stream",
|
||||||
"'Justin Bieber was born on March 1, 1994. The Super Bowl in 1994 was won by the San Francisco 49ers.\\nThe final answer: San Francisco 49ers.'"
|
"text": [
|
||||||
]
|
" Justin Bieber was born on March 1, 1994. Bill Clinton was the president of the United States from January 20, 1993, to January 20, 2001.\n",
|
||||||
},
|
"The final answer is Bill Clinton\n"
|
||||||
"execution_count": 8,
|
]
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
|
"question = \"Who was the president in the year Justin Beiber was born?\"\n",
|
||||||
"\n",
|
"print(chain.invoke({\"question\": question}))"
|
||||||
"llm_chain.run(question)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -140,78 +151,200 @@
|
|||||||
"- `code-gecko`: for code completion"
|
"- `code-gecko`: for code completion"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"metadata": {
|
|
||||||
"execution": {
|
|
||||||
"iopub.execute_input": "2023-06-17T21:16:53.149438Z",
|
|
||||||
"iopub.status.busy": "2023-06-17T21:16:53.149065Z",
|
|
||||||
"iopub.status.idle": "2023-06-17T21:16:53.421824Z",
|
|
||||||
"shell.execute_reply": "2023-06-17T21:16:53.421136Z",
|
|
||||||
"shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
|
|
||||||
},
|
|
||||||
"tags": []
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"llm = VertexAI(model_name=\"code-bison\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {
|
|
||||||
"execution": {
|
|
||||||
"iopub.execute_input": "2023-06-17T21:17:11.179077Z",
|
|
||||||
"iopub.status.busy": "2023-06-17T21:17:11.178686Z",
|
|
||||||
"iopub.status.idle": "2023-06-17T21:17:11.182499Z",
|
|
||||||
"shell.execute_reply": "2023-06-17T21:17:11.181895Z",
|
|
||||||
"shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
|
|
||||||
},
|
|
||||||
"tags": []
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 15,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"execution": {
|
"tags": []
|
||||||
"iopub.execute_input": "2023-06-17T21:18:47.024785Z",
|
},
|
||||||
"iopub.status.busy": "2023-06-17T21:18:47.024230Z",
|
"outputs": [],
|
||||||
"iopub.status.idle": "2023-06-17T21:18:49.352249Z",
|
"source": [
|
||||||
"shell.execute_reply": "2023-06-17T21:18:49.351695Z",
|
"llm = VertexAI(model_name=\"code-bison\", max_output_tokens=1000, temperature=0.3)"
|
||||||
"shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"question = \"Write a python function that checks if a string is a valid email address\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"name": "stdout",
|
||||||
"text/plain": [
|
"output_type": "stream",
|
||||||
"'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
|
"text": [
|
||||||
]
|
"```python\n",
|
||||||
},
|
"import re\n",
|
||||||
"execution_count": 15,
|
"\n",
|
||||||
"metadata": {},
|
"def is_valid_email(email):\n",
|
||||||
"output_type": "execute_result"
|
" pattern = re.compile(r\"[^@]+@[^@]+\\.[^@]+\")\n",
|
||||||
|
" return pattern.match(email)\n",
|
||||||
|
"```\n"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"question = \"Write a python function that identifies if the number is a prime number?\"\n",
|
"print(llm(question))"
|
||||||
"\n",
|
|
||||||
"llm_chain.run(question)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Using models deployed on Vertex Model Garden"
|
"## Full generation info\n",
|
||||||
|
"\n",
|
||||||
|
"We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just text completions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result = llm.generate([question])\n",
|
||||||
|
"result.generations"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Asynchronous calls\n",
|
||||||
|
"\n",
|
||||||
|
"With `agenerate` we can make asynchronous calls"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# If running in a Jupyter notebook you'll need to install nest_asyncio\n",
|
||||||
|
"\n",
|
||||||
|
"# !pip install nest_asyncio"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import asyncio\n",
|
||||||
|
"# import nest_asyncio\n",
|
||||||
|
"# nest_asyncio.apply()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"LLMResult(generations=[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]], llm_output=None, run=[RunInfo(run_id=UUID('caf74e91-aefb-48ac-8031-0c505fcbbcc6'))])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 25,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"asyncio.run(llm.agenerate([question]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Streaming calls\n",
|
||||||
|
"\n",
|
||||||
|
"With `stream` we can stream results from the model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import sys"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 28,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"```python\n",
|
||||||
|
"import re\n",
|
||||||
|
"\n",
|
||||||
|
"def is_valid_email(email):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Checks if a string is a valid email address.\n",
|
||||||
|
"\n",
|
||||||
|
" Args:\n",
|
||||||
|
" email: The string to check.\n",
|
||||||
|
"\n",
|
||||||
|
" Returns:\n",
|
||||||
|
" True if the string is a valid email address, False otherwise.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
" # Check for a valid email address format.\n",
|
||||||
|
" if not re.match(r\"^[A-Za-z0-9\\.\\+_-]+@[A-Za-z0-9\\._-]+\\.[a-zA-Z]*$\", email):\n",
|
||||||
|
" return False\n",
|
||||||
|
"\n",
|
||||||
|
" # Check if the domain name exists.\n",
|
||||||
|
" try:\n",
|
||||||
|
" domain = email.split(\"@\")[1]\n",
|
||||||
|
" socket.gethostbyname(domain)\n",
|
||||||
|
" except socket.gaierror:\n",
|
||||||
|
" return False\n",
|
||||||
|
"\n",
|
||||||
|
" return True\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"for chunk in llm.stream(question):\n",
|
||||||
|
" sys.stdout.write(chunk)\n",
|
||||||
|
" sys.stdout.flush()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Vertex Model Garden"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -248,7 +381,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"llm(\"What is the meaning of life?\")"
|
"print(llm(\"What is the meaning of life?\"))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -264,8 +397,6 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.prompts import PromptTemplate\n",
|
|
||||||
"\n",
|
|
||||||
"prompt = PromptTemplate.from_template(\"What is the meaning of {thing}?\")"
|
"prompt = PromptTemplate.from_template(\"What is the meaning of {thing}?\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -275,9 +406,8 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"llm_oss_chain = prompt | llm\n",
|
"chian = prompt | llm\n",
|
||||||
"\n",
|
"print(chain.invoke({\"thing\": \"life\"}))"
|
||||||
"llm_oss_chain.invoke({\"thing\": \"life\"})"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -83,8 +83,8 @@ TitanTakeoff|✅|❌|✅|❌|❌|❌
|
|||||||
Tongyi|✅|❌|❌|❌|❌|❌
|
Tongyi|✅|❌|❌|❌|❌|❌
|
||||||
VLLM|✅|❌|❌|❌|✅|❌
|
VLLM|✅|❌|❌|❌|✅|❌
|
||||||
VLLMOpenAI|✅|✅|✅|✅|✅|✅
|
VLLMOpenAI|✅|✅|✅|✅|✅|✅
|
||||||
VertexAI|✅|✅|❌|❌|❌|❌
|
VertexAI|✅|✅|✅|❌|✅|✅
|
||||||
VertexAIModelGarden|✅|✅|❌|❌|❌|❌
|
VertexAIModelGarden|✅|✅|❌|❌|✅|✅
|
||||||
Writer|✅|❌|❌|❌|❌|❌
|
Writer|✅|❌|❌|❌|❌|❌
|
||||||
Xinference|✅|❌|❌|❌|❌|❌
|
Xinference|✅|❌|❌|❌|❌|❌
|
||||||
|
|
||||||
|
@ -2,6 +2,35 @@
|
|||||||
|
|
||||||
All functionality related to Google Platform
|
All functionality related to Google Platform
|
||||||
|
|
||||||
|
## LLMs
|
||||||
|
|
||||||
|
### Vertex AI
|
||||||
|
|
||||||
|
Access PaLM LLMs like `text-bison` and `code-bison` via Google Cloud.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.llms import VertexAI
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Garden
|
||||||
|
|
||||||
|
Access PaLM and hundreds of OSS models via Vertex AI Model Garden.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.llms import VertexAIModelGarden
|
||||||
|
```
|
||||||
|
|
||||||
|
## Chat models
|
||||||
|
|
||||||
|
### Vertex AI
|
||||||
|
|
||||||
|
Access PaLM chat models like `chat-bison` and `codechat-bison` via Google Cloud.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.chat_models import ChatVertexAI
|
||||||
|
```
|
||||||
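Beyond the bare import added to the docs page, a minimal usage sketch (illustrative, not part of the diff; the prompt text is made up) looks like:

```python
from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import HumanMessage

# chat-bison is the PaLM chat model exposed through Vertex AI.
chat = ChatVertexAI(model_name="chat-bison")
response = chat([HumanMessage(content="Translate 'I love programming' to French.")])
print(response.content)
```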
|
|
||||||
|
|
||||||
## Document Loader
|
## Document Loader
|
||||||
### Google BigQuery
|
### Google BigQuery
|
||||||
|
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
"""Wrapper around Google VertexAI chat-based models."""
|
"""Wrapper around Google VertexAI chat-based models."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
|
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
|
||||||
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
|
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
|
||||||
from langchain.pydantic_v1 import root_validator
|
from langchain.pydantic_v1 import root_validator
|
||||||
@ -30,6 +34,8 @@ if TYPE_CHECKING:
|
|||||||
InputOutputTextPair,
|
InputOutputTextPair,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _ChatHistory:
|
class _ChatHistory:
|
||||||
@ -116,7 +122,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
"""`Vertex AI` Chat large language models API."""
|
"""`Vertex AI` Chat large language models API."""
|
||||||
|
|
||||||
model_name: str = "chat-bison"
|
model_name: str = "chat-bison"
|
||||||
streaming: bool = False
|
"Underlying model name."
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -177,6 +183,42 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
text = self._enforce_stop_words(response.text, stop)
|
text = self._enforce_stop_words(response.text, stop)
|
||||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""Asynchronously generate next turn in the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The history of the conversation as a list of messages. Code chat
|
||||||
|
does not support context.
|
||||||
|
stop: The list of stop words (optional).
|
||||||
|
run_manager: The CallbackManager for the LLM run; it is not used at the moment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ChatResult that contains outputs generated by the model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the last message in the list is not from human.
|
||||||
|
"""
|
||||||
|
if "stream" in kwargs:
|
||||||
|
kwargs.pop("stream")
|
||||||
|
logger.warning("ChatVertexAI does not currently support async streaming.")
|
||||||
|
question = _get_question(messages)
|
||||||
|
history = _parse_chat_history(messages[:-1])
|
||||||
|
params = {**self._default_params, **kwargs}
|
||||||
|
examples = kwargs.get("examples", None)
|
||||||
|
if examples:
|
||||||
|
params["examples"] = _parse_examples(examples)
|
||||||
|
|
||||||
|
chat = self._start_chat(history, params)
|
||||||
|
response = await chat.send_message_async(question.content)
|
||||||
|
text = self._enforce_stop_words(response.text, stop)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
|
@ -1,28 +1,58 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
ClassVar,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain.llms.base import LLM, create_base_retry_decorator
|
from langchain.llms.base import BaseLLM, create_base_retry_decorator
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
Generation,
|
Generation,
|
||||||
LLMResult,
|
LLMResult,
|
||||||
)
|
)
|
||||||
|
from langchain.schema.output import GenerationChunk
|
||||||
from langchain.utilities.vertexai import (
|
from langchain.utilities.vertexai import (
|
||||||
init_vertexai,
|
init_vertexai,
|
||||||
raise_vertex_import_error,
|
raise_vertex_import_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from google.cloud.aiplatform.gapic import PredictionServiceClient
|
from google.cloud.aiplatform.gapic import (
|
||||||
from vertexai.language_models._language_models import _LanguageModel
|
PredictionServiceAsyncClient,
|
||||||
|
PredictionServiceClient,
|
||||||
|
)
|
||||||
|
from vertexai.language_models._language_models import (
|
||||||
|
TextGenerationResponse,
|
||||||
|
_LanguageModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _response_to_generation(
|
||||||
|
response: TextGenerationResponse,
|
||||||
|
) -> GenerationChunk:
|
||||||
|
"""Convert a stream response to a generation chunk."""
|
||||||
|
try:
|
||||||
|
generation_info = {
|
||||||
|
"is_blocked": response.is_blocked,
|
||||||
|
"safety_attributes": response.safety_attributes,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
generation_info = None
|
||||||
|
return GenerationChunk(text=response.text, generation_info=generation_info)
|
||||||
|
|
||||||
|
|
||||||
def is_codey_model(model_name: str) -> bool:
|
def is_codey_model(model_name: str) -> bool:
|
||||||
@ -36,7 +66,13 @@ def is_codey_model(model_name: str) -> bool:
|
|||||||
return "code" in model_name
|
return "code" in model_name
|
||||||
|
|
||||||
|
|
||||||
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
|
def _create_retry_decorator(
|
||||||
|
llm: VertexAI,
|
||||||
|
*,
|
||||||
|
run_manager: Optional[
|
||||||
|
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||||
|
] = None,
|
||||||
|
) -> Callable[[Any], Any]:
|
||||||
import google.api_core
|
import google.api_core
|
||||||
|
|
||||||
errors = [
|
errors = [
|
||||||
@ -46,14 +82,19 @@ def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
|
|||||||
google.api_core.exceptions.DeadlineExceeded,
|
google.api_core.exceptions.DeadlineExceeded,
|
||||||
]
|
]
|
||||||
decorator = create_base_retry_decorator(
|
decorator = create_base_retry_decorator(
|
||||||
error_types=errors, max_retries=llm.max_retries # type: ignore
|
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||||
)
|
)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
|
def completion_with_retry(
|
||||||
|
llm: VertexAI,
|
||||||
|
*args: Any,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
"""Use tenacity to retry the completion call."""
|
"""Use tenacity to retry the completion call."""
|
||||||
retry_decorator = _create_retry_decorator(llm)
|
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||||
@ -62,6 +103,38 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
|
|||||||
return _completion_with_retry(*args, **kwargs)
|
return _completion_with_retry(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_completion_with_retry(
|
||||||
|
llm: VertexAI,
|
||||||
|
*args: Any,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
return llm.client.predict_streaming(*args, **kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def acompletion_with_retry(
|
||||||
|
llm: VertexAI,
|
||||||
|
*args: Any,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
return await llm.client.predict_async(*args, **kwargs)
|
||||||
|
|
||||||
|
return await _acompletion_with_retry(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class _VertexAIBase(BaseModel):
|
class _VertexAIBase(BaseModel):
|
||||||
project: Optional[str] = None
|
project: Optional[str] = None
|
||||||
"The default GCP project to use when making Vertex API calls."
|
"The default GCP project to use when making Vertex API calls."
|
||||||
@ -110,6 +183,11 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
"The default custom credentials (google.auth.credentials.Credentials) to use "
|
"The default custom credentials (google.auth.credentials.Credentials) to use "
|
||||||
"when making API calls. If not provided, credentials will be ascertained from "
|
"when making API calls. If not provided, credentials will be ascertained from "
|
||||||
"the environment."
|
"the environment."
|
||||||
|
streaming: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "vertexai"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_codey_model(self) -> bool:
|
def is_codey_model(self) -> bool:
|
||||||
@ -135,17 +213,6 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _predict(
|
|
||||||
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
|
|
||||||
) -> str:
|
|
||||||
params = {**self._default_params, **kwargs}
|
|
||||||
res = completion_with_retry(self, prompt, **params) # type: ignore
|
|
||||||
return self._enforce_stop_words(res.text, stop)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _llm_type(self) -> str:
|
|
||||||
return "vertexai"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _try_init_vertexai(cls, values: Dict) -> None:
|
def _try_init_vertexai(cls, values: Dict) -> None:
|
||||||
allowed_params = ["project", "location", "credentials"]
|
allowed_params = ["project", "location", "credentials"]
|
||||||
@ -154,13 +221,14 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class VertexAI(_VertexAICommon, LLM):
|
class VertexAI(_VertexAICommon, BaseLLM):
|
||||||
"""Google Vertex AI large language models."""
|
"""Google Vertex AI large language models."""
|
||||||
|
|
||||||
model_name: str = "text-bison"
|
model_name: str = "text-bison"
|
||||||
"The name of the Vertex AI large language model."
|
"The name of the Vertex AI large language model."
|
||||||
tuned_model_name: Optional[str] = None
|
tuned_model_name: Optional[str] = None
|
||||||
"The name of a tuned model. If provided, model_name is ignored."
|
"The name of a tuned model. If provided, model_name is ignored."
|
||||||
|
streaming: bool = False
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -191,51 +259,78 @@ class VertexAI(_VertexAICommon, LLM):
|
|||||||
raise_vertex_import_error()
|
raise_vertex_import_error()
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _call(
|
def _generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
stop_sequences = stop or self.stop
|
||||||
|
should_stream = stream if stream is not None else self.streaming
|
||||||
|
|
||||||
|
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||||
|
generations = []
|
||||||
|
for prompt in prompts:
|
||||||
|
if should_stream:
|
||||||
|
generation = GenerationChunk(text="")
|
||||||
|
for chunk in self._stream(
|
||||||
|
prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
):
|
||||||
|
generation += chunk
|
||||||
|
generations.append([generation])
|
||||||
|
else:
|
||||||
|
res = completion_with_retry(
|
||||||
|
self, prompt, run_manager=run_manager, **params
|
||||||
|
)
|
||||||
|
generations.append([_response_to_generation(res)])
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResult:
|
||||||
|
stop_sequences = stop or self.stop
|
||||||
|
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||||
|
generations = []
|
||||||
|
for prompt in prompts:
|
||||||
|
res = await acompletion_with_retry(
|
||||||
|
self, prompt, run_manager=run_manager, **params
|
||||||
|
)
|
||||||
|
generations.append([_response_to_generation(res)])
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
|
def _stream(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> Iterator[GenerationChunk]:
|
||||||
"""Call Vertex model to get predictions based on the prompt.
|
stop_sequences = stop or self.stop
|
||||||
|
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||||
Args:
|
for stream_resp in stream_completion_with_retry(
|
||||||
prompt: The prompt to pass into the model.
|
self, prompt, run_manager=run_manager, **params
|
||||||
stop: A list of stop words (optional).
|
):
|
||||||
run_manager: A Callbackmanager for LLM run, optional.
|
chunk = _response_to_generation(stream_resp)
|
||||||
|
yield chunk
|
||||||
Returns:
|
if run_manager:
|
||||||
The string generated by the model.
|
run_manager.on_llm_new_token(
|
||||||
"""
|
chunk.text,
|
||||||
return self._predict(prompt, stop, **kwargs)
|
chunk=chunk,
|
||||||
|
verbose=self.verbose,
|
||||||
async def _acall(
|
)
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str:
|
|
||||||
"""Call Vertex model to get predictions based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: The prompt to pass into the model.
|
|
||||||
stop: A list of stop words (optional).
|
|
||||||
run_manager: A callback manager for async interaction with LLMs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The string generated by the model.
|
|
||||||
"""
|
|
||||||
return await asyncio.wrap_future(
|
|
||||||
self._get_task_executor().submit(self._call, prompt, stop)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VertexAIModelGarden(_VertexAIBase, LLM):
|
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
|
||||||
"""Large language models served from Vertex AI Model Garden."""
|
"""Large language models served from Vertex AI Model Garden."""
|
||||||
|
|
||||||
client: "PredictionServiceClient" = None #: :meta private:
|
client: "PredictionServiceClient" = None #: :meta private:
|
||||||
|
async_client: "PredictionServiceAsyncClient" = None #: :meta private:
|
||||||
endpoint_id: str
|
endpoint_id: str
|
||||||
"A name of an endpoint where the model has been deployed."
|
"A name of an endpoint where the model has been deployed."
|
||||||
allowed_model_args: Optional[List[str]] = None
|
allowed_model_args: Optional[List[str]] = None
|
||||||
@ -247,7 +342,11 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
|
|||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that the python package exists in environment."""
|
"""Validate that the python package exists in environment."""
|
||||||
try:
|
try:
|
||||||
from google.cloud.aiplatform.gapic import PredictionServiceClient
|
from google.api_core.client_options import ClientOptions
|
||||||
|
from google.cloud.aiplatform.gapic import (
|
||||||
|
PredictionServiceAsyncClient,
|
||||||
|
PredictionServiceClient,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise_vertex_import_error()
|
raise_vertex_import_error()
|
||||||
|
|
||||||
@ -256,38 +355,19 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
|
|||||||
"A GCP project should be provided to run inference on Model Garden!"
|
"A GCP project should be provided to run inference on Model Garden!"
|
||||||
)
|
)
|
||||||
|
|
||||||
client_options = {
|
client_options = ClientOptions(
|
||||||
"api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
|
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
|
||||||
}
|
)
|
||||||
values["client"] = PredictionServiceClient(client_options=client_options)
|
values["client"] = PredictionServiceClient(client_options=client_options)
|
||||||
|
values["async_client"] = PredictionServiceAsyncClient(
|
||||||
|
client_options=client_options
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "vertexai_model_garden"
|
return "vertexai_model_garden"
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str:
|
|
||||||
"""Call Vertex model to get predictions based on the prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: The prompt to pass into the model.
|
|
||||||
stop: A list of stop words (optional).
|
|
||||||
run_manager: A Callbackmanager for LLM run, optional.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The string generated by the model.
|
|
||||||
"""
|
|
||||||
result = self._generate(
|
|
||||||
prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return result.generations[0][0].text
|
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
@ -331,23 +411,47 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
|
|||||||
)
|
)
|
||||||
return LLMResult(generations=generations)
|
return LLMResult(generations=generations)
|
||||||
|
|
||||||
async def _acall(
|
async def _agenerate(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompts: List[str],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> LLMResult:
|
||||||
"""Call Vertex model to get predictions based on the prompt.
|
"""Run the LLM on the given prompt and input."""
|
||||||
|
try:
|
||||||
|
from google.protobuf import json_format
|
||||||
|
from google.protobuf.struct_pb2 import Value
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"protobuf package not found, please install it with"
|
||||||
|
" `pip install protobuf`"
|
||||||
|
)
|
||||||
|
|
||||||
Args:
|
instances = []
|
||||||
prompt: The prompt to pass into the model.
|
for prompt in prompts:
|
||||||
stop: A list of stop words (optional).
|
if self.allowed_model_args:
|
||||||
run_manager: A callback manager for async interaction with LLMs.
|
instance = {
|
||||||
|
k: v for k, v in kwargs.items() if k in self.allowed_model_args
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
instance = {}
|
||||||
|
instance[self.prompt_arg] = prompt
|
||||||
|
instances.append(instance)
|
||||||
|
|
||||||
Returns:
|
predict_instances = [
|
||||||
The string generated by the model.
|
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
|
||||||
"""
|
]
|
||||||
return await asyncio.wrap_future(
|
|
||||||
self._get_task_executor().submit(self._call, prompt, stop)
|
endpoint = self.async_client.endpoint_path(
|
||||||
|
project=self.project, location=self.location, endpoint=self.endpoint_id
|
||||||
)
|
)
|
||||||
|
response = await self.async_client.predict(
|
||||||
|
endpoint=endpoint, instances=predict_instances
|
||||||
|
)
|
||||||
|
generations: List[List[Generation]] = []
|
||||||
|
for result in response.predictions:
|
||||||
|
generations.append(
|
||||||
|
[Generation(text=prediction[self.result_arg]) for prediction in result]
|
||||||
|
)
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
@ -13,6 +13,7 @@ import pytest
|
|||||||
|
|
||||||
from langchain.chat_models import ChatVertexAI
|
from langchain.chat_models import ChatVertexAI
|
||||||
from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
|
from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
|
||||||
|
from langchain.schema import LLMResult
|
||||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
@ -26,10 +27,22 @@ def test_vertexai_single_call(model_name: str) -> None:
|
|||||||
response = model([message])
|
response = model([message])
|
||||||
assert isinstance(response, AIMessage)
|
assert isinstance(response, AIMessage)
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
assert model._llm_type == "vertexai"
|
assert model._llm_type == "chat-vertexai"
|
||||||
assert model.model_name == model.client._model_id
|
assert model.model_name == model.client._model_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vertexai_agenerate() -> None:
|
||||||
|
model = ChatVertexAI(temperature=0)
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = await model.agenerate([[message]])
|
||||||
|
assert isinstance(response, LLMResult)
|
||||||
|
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
|
||||||
|
|
||||||
|
sync_response = model.generate([[message]])
|
||||||
|
assert response.generations[0][0] == sync_response.generations[0][0]
|
||||||
|
|
||||||
|
|
||||||
def test_vertexai_single_call_with_context() -> None:
|
def test_vertexai_single_call_with_context() -> None:
|
||||||
model = ChatVertexAI()
|
model = ChatVertexAI()
|
||||||
raw_context = (
|
raw_context = (
|
||||||
|
@ -14,7 +14,6 @@ def test_embedding_documents() -> None:
|
|||||||
output = model.embed_documents(documents)
|
output = model.embed_documents(documents)
|
||||||
assert len(output) == 1
|
assert len(output) == 1
|
||||||
assert len(output[0]) == 768
|
assert len(output[0]) == 768
|
||||||
assert model._llm_type == "vertexai"
|
|
||||||
assert model.model_name == model.client._model_id
|
assert model.model_name == model.client._model_id
|
||||||
|
|
||||||
|
|
||||||
@ -40,5 +39,4 @@ def test_paginated_texts() -> None:
|
|||||||
output = model.embed_documents(documents)
|
output = model.embed_documents(documents)
|
||||||
assert len(output) == 8
|
assert len(output) == 8
|
||||||
assert len(output[0]) == 768
|
assert len(output[0]) == 768
|
||||||
assert model._llm_type == "vertexai"
|
|
||||||
assert model.model_name == model.client._model_id
|
assert model.model_name == model.client._model_id
|
||||||
|
@ -9,18 +9,49 @@ Your end-user credentials would be used to make the calls (make sure you've run
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.llms import VertexAI, VertexAIModelGarden
|
from langchain.llms import VertexAI, VertexAIModelGarden
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
|
||||||
def test_vertex_call() -> None:
|
def test_vertex_call() -> None:
|
||||||
llm = VertexAI()
|
llm = VertexAI(temperature=0)
|
||||||
output = llm("Say foo:")
|
output = llm("Say foo:")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
assert llm._llm_type == "vertexai"
|
assert llm._llm_type == "vertexai"
|
||||||
assert llm.model_name == llm.client._model_id
|
assert llm.model_name == llm.client._model_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertex_generate() -> None:
|
||||||
|
llm = VertexAI(temperature=0)
|
||||||
|
output = llm.generate(["Please say foo:"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vertex_agenerate() -> None:
|
||||||
|
llm = VertexAI(temperature=0)
|
||||||
|
output = await llm.agenerate(["Please say foo:"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertex_stream() -> None:
|
||||||
|
llm = VertexAI(temperature=0)
|
||||||
|
outputs = list(llm.stream("Please say foo:"))
|
||||||
|
assert isinstance(outputs[0], str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vertex_consistency() -> None:
|
||||||
|
llm = VertexAI(temperature=0)
|
||||||
|
output = llm.generate(["Please say foo:"])
|
||||||
|
streaming_output = llm.generate(["Please say foo:"], stream=True)
|
||||||
|
async_output = await llm.agenerate(["Please say foo:"])
|
||||||
|
assert output.generations[0][0].text == streaming_output.generations[0][0].text
|
||||||
|
assert output.generations[0][0].text == async_output.generations[0][0].text
|
||||||
|
|
||||||
|
|
||||||
def test_model_garden() -> None:
|
def test_model_garden() -> None:
|
||||||
"""In order to run this test, you should provide an endpoint name.
|
"""In order to run this test, you should provide an endpoint name.
|
||||||
|
|
||||||
@ -37,7 +68,7 @@ def test_model_garden() -> None:
|
|||||||
assert llm._llm_type == "vertexai_model_garden"
|
assert llm._llm_type == "vertexai_model_garden"
|
||||||
|
|
||||||
|
|
||||||
def test_model_garden_batch() -> None:
|
def test_model_garden_generate() -> None:
|
||||||
"""In order to run this test, you should provide an endpoint name.
|
"""In order to run this test, you should provide an endpoint name.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -47,6 +78,16 @@ def test_model_garden_batch() -> None:
|
|||||||
endpoint_id = os.environ["ENDPOINT_ID"]
|
endpoint_id = os.environ["ENDPOINT_ID"]
|
||||||
project = os.environ["PROJECT"]
|
project = os.environ["PROJECT"]
|
||||||
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
|
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
|
||||||
output = llm._generate(["What is the meaning of life?", "How much is 2+2"])
|
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
|
||||||
|
assert isinstance(output, LLMResult)
|
||||||
|
assert len(output.generations) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_garden_agenerate() -> None:
|
||||||
|
endpoint_id = os.environ["ENDPOINT_ID"]
|
||||||
|
project = os.environ["PROJECT"]
|
||||||
|
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
|
||||||
|
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
|
||||||
assert isinstance(output, LLMResult)
|
assert isinstance(output, LLMResult)
|
||||||
assert len(output.generations) == 2
|
assert len(output.generations) == 2
|
||||||
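For reference, a minimal sketch of the asynchronous Model Garden call the tests above exercise; `ENDPOINT_ID` and `PROJECT` are placeholders for a real deployed endpoint and project.

```python
import asyncio

from langchain.llms import VertexAIModelGarden

# ENDPOINT_ID and PROJECT are placeholders; a model must already be deployed
# behind the endpoint for this call to succeed.
llm = VertexAIModelGarden(endpoint_id="ENDPOINT_ID", project="PROJECT")
result = asyncio.run(llm.agenerate(["What is the meaning of life?"]))
print(result.generations[0][0].text)
```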
|