Mirror of https://github.com/hwchase17/langchain.git
Be able to use Codey models on Vertex AI (#6354)
Added the functionality to leverage 3 new Codey models from Vertex AI:
- code-bison: Code generation using the existing LLM integration
- code-gecko: Code completion using the existing LLM integration
- codechat-bison: Code chat using the existing chat_model integration

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
parent 0fce8ef178
commit 456ca3d587
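Taken together, the change makes all three Codey models usable through the existing LangChain entry points. A minimal end-to-end sketch based on the notebook cells in this diff; it assumes a Google Cloud project with Vertex AI enabled and google-cloud-aiplatform>=1.26.0 installed:

    from langchain.chat_models import ChatVertexAI
    from langchain.llms import VertexAI
    from langchain.schema import HumanMessage

    # Code chat through the chat_model integration.
    chat = ChatVertexAI(model_name="codechat-bison")
    reply = chat([HumanMessage(content="How do I create a python function to identify all prime numbers?")])
    print(reply.content)

    # Code generation through the LLM integration.
    llm = VertexAI(model_name="code-bison")
    print(llm("Write a python function that identifies if the number is a prime number."))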
Docs notebook for the Vertex AI chat integration (adds a code chat section):

@@ -141,6 +141,73 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:09:25.423568Z",
+     "iopub.status.busy": "2023-06-17T21:09:25.423213Z",
+     "iopub.status.idle": "2023-06-17T21:09:25.429641Z",
+     "shell.execute_reply": "2023-06-17T21:09:25.429060Z",
+     "shell.execute_reply.started": "2023-06-17T21:09:25.423546Z"
+    },
+    "tags": []
+   },
+   "source": [
+    "You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
+    "- codechat-bison: for code assistance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:43.974841Z",
+     "iopub.status.busy": "2023-06-17T21:30:43.974431Z",
+     "iopub.status.idle": "2023-06-17T21:30:44.248119Z",
+     "shell.execute_reply": "2023-06-17T21:30:44.247362Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatVertexAI(model_name=\"codechat-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:45.146093Z",
+     "iopub.status.busy": "2023-06-17T21:30:45.145752Z",
+     "iopub.status.idle": "2023-06-17T21:30:47.449126Z",
+     "shell.execute_reply": "2023-06-17T21:30:47.448609Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(content=\"How do I create a python function to identify all prime numbers?\")\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
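Multi-turn conversations work the same way: earlier Human/AI message pairs are replayed as chat history before the final question is sent. A hypothetical follow-up turn (the AIMessage content is abbreviated for illustration):

    from langchain.schema import AIMessage, HumanMessage

    messages = [
        HumanMessage(content="How do I create a python function to identify all prime numbers?"),
        AIMessage(content="..."),  # the model's previous answer, elided here
        HumanMessage(content="Can you rewrite it to use a sieve instead?"),
    ]
    chat(messages)  # chat = ChatVertexAI(model_name="codechat-bison") as above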
Docs notebook for the Vertex AI LLM integration (turns an empty cell into a Codey code generation section):

@@ -101,11 +101,80 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "You can now leverage the Codey API for code generation within Vertex AI. The model names are:\n",
+    "- code-bison: for code suggestion\n",
+    "- code-gecko: for code completion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:16:53.149438Z",
+     "iopub.status.busy": "2023-06-17T21:16:53.149065Z",
+     "iopub.status.idle": "2023-06-17T21:16:53.421824Z",
+     "shell.execute_reply": "2023-06-17T21:16:53.421136Z",
+     "shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
+    },
+    "tags": []
+   },
    "outputs": [],
-   "source": []
+   "source": [
+    "llm = VertexAI(model_name=\"code-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:17:11.179077Z",
+     "iopub.status.busy": "2023-06-17T21:17:11.178686Z",
+     "iopub.status.idle": "2023-06-17T21:17:11.182499Z",
+     "shell.execute_reply": "2023-06-17T21:17:11.181895Z",
+     "shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:18:47.024785Z",
+     "iopub.status.busy": "2023-06-17T21:18:47.024230Z",
+     "iopub.status.idle": "2023-06-17T21:18:49.352249Z",
+     "shell.execute_reply": "2023-06-17T21:18:49.351695Z",
+     "shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"Write a python function that identifies if the number is a prime number?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
   }
  ],
 "metadata": {
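The commit message also lists code-gecko for completion, but the notebook only demonstrates code-bison. A minimal, hypothetical completion sketch under the same LLM integration (prompt contents illustrative):

    from langchain.llms import VertexAI

    code_gecko = VertexAI(model_name="code-gecko")
    # code-gecko continues a partial snippet rather than answering a question.
    print(code_gecko("def is_prime(n):\n    if n < 2:\n        return False\n"))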
langchain/chat_models/vertexai.py:

@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms.vertexai import _VertexAICommon
+from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -42,7 +42,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
 
     A sequence should be either (SystemMessage, HumanMessage, AIMessage,
     HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
-    AIMessage, ...).
+    AIMessage, ...). CodeChat does not support SystemMessage.
 
     Args:
         history: The list of messages to re-create the history of the chat.
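To make the accepted shapes concrete, a sketch of a history that parses into (question, answer) pairs; for codechat-bison the leading SystemMessage would have to be omitted:

    from langchain.schema import AIMessage, HumanMessage, SystemMessage

    history = [
        SystemMessage(content="You are a helpful coding assistant."),  # not supported by CodeChat
        HumanMessage(content="Write a bubble sort."),   # question 1
        AIMessage(content="def bubble_sort(xs): ..."),  # answer 1
        HumanMessage(content="Now make it stable."),    # question 2, the current turn
    ]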
@@ -82,10 +82,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
         try:
-            from vertexai.preview.language_models import ChatModel
+            if is_codey_model(values["model_name"]):
+                from vertexai.preview.language_models import CodeChatModel
+
+                values["client"] = CodeChatModel.from_pretrained(values["model_name"])
+            else:
+                from vertexai.preview.language_models import ChatModel
+
+                values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
             raise_vertex_import_error()
-        values["client"] = ChatModel.from_pretrained(values["model_name"])
         return values
 
     def _generate(
@@ -98,9 +104,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Generate next turn in the conversation.
 
         Args:
-            messages: The history of the conversation as a list of messages.
+            messages: The history of the conversation as a list of messages. Code chat
+                does not support context.
             stop: The list of stop words (optional).
-            run_manager: The Callbackmanager for LLM run, it's not used at the moment.
+            run_manager: The CallbackManager for LLM run, it's not used at the moment.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -121,10 +128,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
         params = {**self._default_params, **kwargs}
-        chat = self.client.start_chat(context=context, **params)
+        if not self.is_codey_model:
+            params["context"] = context
+        chat = self.client.start_chat(**params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
-        response = chat.send_message(question.content, **self._default_params)
+        response = chat.send_message(question.content, **params)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
 
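The practical effect: a SystemMessage only reaches the model for the regular PaLM chat models, where it is forwarded as start_chat(context=...); for Codey models it is not forwarded at all. A sketch with illustrative message contents:

    from langchain.schema import HumanMessage, SystemMessage

    palm_chat = ChatVertexAI(model_name="chat-bison")
    palm_chat([
        SystemMessage(content="Answer in one sentence."),  # becomes the chat context
        HumanMessage(content="What is a Python generator?"),
    ])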
langchain/llms/vertexai.py:

@@ -15,6 +15,10 @@ if TYPE_CHECKING:
     from vertexai.language_models._language_models import _LanguageModel
 
 
+def is_codey_model(model_name: str) -> bool:
+    return "code" in model_name
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
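The helper routes purely on a substring match, so every current Codey model name qualifies while the PaLM text and chat models do not. A quick sanity check of the intended behavior:

    assert is_codey_model("code-bison")
    assert is_codey_model("code-gecko")
    assert is_codey_model("codechat-bison")
    assert not is_codey_model("text-bison")
    assert not is_codey_model("chat-bison")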
@@ -25,10 +29,10 @@ class _VertexAICommon(BaseModel):
     "Token limit determines the maximum amount of text output from one prompt."
     top_p: float = 0.95
     "Tokens are selected from most probable to least until the sum of their "
-    "probabilities equals the top-p value."
+    "probabilities equals the top-p value. Top-p is ignored for Codey models."
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
-    "among the top-k most probable tokens."
+    "among the top-k most probable tokens. Top-k is ignored for Codey models."
     stop: Optional[List[str]] = None
     "Optional list of stop words to use when generating."
     project: Optional[str] = None
@@ -40,15 +44,24 @@ class _VertexAICommon(BaseModel):
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
 
+    @property
+    def is_codey_model(self) -> bool:
+        return is_codey_model(self.model_name)
+
     @property
     def _default_params(self) -> Dict[str, Any]:
-        base_params = {
-            "temperature": self.temperature,
-            "max_output_tokens": self.max_output_tokens,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-        }
-        return {**base_params}
+        if self.is_codey_model:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+            }
+        else:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+            }
 
     def _predict(
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
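The upshot is that top_k and top_p are simply never sent for Codey models, matching the updated attribute docs above. A sketch of the difference, assuming the SDK is installed so the validator passes:

    VertexAI(model_name="code-bison")._default_params
    # -> only "temperature" and "max_output_tokens"
    VertexAI(model_name="text-bison")._default_params
    # -> additionally "top_k": 40 and "top_p": 0.95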
@@ -80,22 +93,32 @@ class VertexAI(_VertexAICommon, LLM):
     """Wrapper around Google Vertex AI large language models."""
 
     model_name: str = "text-bison"
+    "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
-    "The name of a tuned model, if it's provided, model_name is ignored."
+    "The name of a tuned model. If provided, model_name is ignored."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
+        tuned_model_name = values.get("tuned_model_name")
+        model_name = values["model_name"]
         try:
-            from vertexai.preview.language_models import TextGenerationModel
+            if tuned_model_name or not is_codey_model(model_name):
+                from vertexai.preview.language_models import TextGenerationModel
+
+                if tuned_model_name:
+                    values["client"] = TextGenerationModel.get_tuned_model(
+                        tuned_model_name
+                    )
+                else:
+                    values["client"] = TextGenerationModel.from_pretrained(model_name)
+            else:
+                from vertexai.preview.language_models import CodeGenerationModel
+
+                values["client"] = CodeGenerationModel.from_pretrained(model_name)
         except ImportError:
             raise_vertex_import_error()
-        tuned_model_name = values.get("tuned_model_name")
-        if tuned_model_name:
-            values["client"] = TextGenerationModel.get_tuned_model(tuned_model_name)
-        else:
-            values["client"] = TextGenerationModel.from_pretrained(values["model_name"])
         return values
 
     def _call(
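A routing summary as a sketch; the comments state which SDK client each configuration resolves to, and the tuned-model placeholder is hypothetical:

    VertexAI(model_name="text-bison")  # TextGenerationModel.from_pretrained("text-bison")
    VertexAI(model_name="code-gecko")  # CodeGenerationModel.from_pretrained("code-gecko")
    VertexAI(tuned_model_name="<your-tuned-model>")  # TextGenerationModel.get_tuned_model(...)
    # A tuned_model_name always takes the TextGenerationModel path,
    # even if its name happens to contain "code".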
langchain/utilities/vertexai.py:

@@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
     Raises:
         ImportError: an ImportError that mentions a required version of the SDK.
     """
-    sdk = "'google-cloud-aiplatform>=1.25.0'"
+    sdk = "'google-cloud-aiplatform>=1.26.0'"
     raise ImportError(
         "Could not import VertexAI. Please, install it with " f"pip install {sdk}"
     )