Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-29 18:08:36 +00:00)
Be able to use Codey models on Vertex AI (#6354)
Added the functionality to leverage 3 new Codey models from Vertex AI:

- code-bison: code generation, using the existing LLM integration
- code-gecko: code completion, using the existing LLM integration
- codechat-bison: code chat, using the existing chat_model integration

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent 0fce8ef178
commit 456ca3d587
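For orientation, a minimal sketch of how the three models are reached through the existing integrations (a sketch, assuming google-cloud-aiplatform>=1.26.0 is installed and GCP credentials are configured; the code-gecko call is an assumption that follows the code-bison pattern, since the notebooks below only exercise code-bison and codechat-bison):

from langchain.llms import VertexAI
from langchain.chat_models import ChatVertexAI
from langchain.schema import HumanMessage

# code-bison and code-gecko go through the existing LLM integration.
generation_llm = VertexAI(model_name="code-bison")
completion_llm = VertexAI(model_name="code-gecko")  # assumption: same pattern as code-bison

# codechat-bison goes through the existing chat_models integration.
code_chat = ChatVertexAI(model_name="codechat-bison")

print(generation_llm("Write a python function that checks whether a number is prime."))
print(completion_llm("def is_prime(n):"))
print(code_chat([HumanMessage(content="How do I test a prime-checking function?")]))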
@@ -141,6 +141,73 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:09:25.423568Z",
+     "iopub.status.busy": "2023-06-17T21:09:25.423213Z",
+     "iopub.status.idle": "2023-06-17T21:09:25.429641Z",
+     "shell.execute_reply": "2023-06-17T21:09:25.429060Z",
+     "shell.execute_reply.started": "2023-06-17T21:09:25.423546Z"
+    },
+    "tags": []
+   },
+   "source": [
+    "You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n",
+    "- codechat-bison: for code assistance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:43.974841Z",
+     "iopub.status.busy": "2023-06-17T21:30:43.974431Z",
+     "iopub.status.idle": "2023-06-17T21:30:44.248119Z",
+     "shell.execute_reply": "2023-06-17T21:30:44.247362Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:43.974820Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatVertexAI(model_name=\"codechat-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:30:45.146093Z",
+     "iopub.status.busy": "2023-06-17T21:30:45.145752Z",
+     "iopub.status.idle": "2023-06-17T21:30:47.449126Z",
+     "shell.execute_reply": "2023-06-17T21:30:47.448609Z",
+     "shell.execute_reply.started": "2023-06-17T21:30:45.146069Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    " HumanMessage(content=\"How do I create a python function to identify all prime numbers?\")\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
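Since calling the chat model with a list of messages returns an AIMessage, a follow-up turn can be sketched by growing the same list (the second question is illustrative, not part of the notebook):

first_reply = chat(messages)  # the AIMessage shown in the output above
messages = messages + [
    first_reply,
    HumanMessage(content="Can you add unit tests for that function?"),  # illustrative follow-up
]
second_reply = chat(messages)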
@@ -101,11 +101,80 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "You can now leverage the Codey API for code generation within Vertex AI. The model names are:\n",
+    "- code-bison: for code suggestion\n",
+    "- code-gecko: for code completion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:16:53.149438Z",
+     "iopub.status.busy": "2023-06-17T21:16:53.149065Z",
+     "iopub.status.idle": "2023-06-17T21:16:53.421824Z",
+     "shell.execute_reply": "2023-06-17T21:16:53.421136Z",
+     "shell.execute_reply.started": "2023-06-17T21:16:53.149415Z"
+    },
+    "tags": []
+   },
    "outputs": [],
-   "source": []
+   "source": [
+    "llm = VertexAI(model_name=\"code-bison\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:17:11.179077Z",
+     "iopub.status.busy": "2023-06-17T21:17:11.178686Z",
+     "iopub.status.idle": "2023-06-17T21:17:11.182499Z",
+     "shell.execute_reply": "2023-06-17T21:17:11.181895Z",
+     "shell.execute_reply.started": "2023-06-17T21:17:11.179052Z"
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "execution": {
+     "iopub.execute_input": "2023-06-17T21:18:47.024785Z",
+     "iopub.status.busy": "2023-06-17T21:18:47.024230Z",
+     "iopub.status.idle": "2023-06-17T21:18:49.352249Z",
+     "shell.execute_reply": "2023-06-17T21:18:49.351695Z",
+     "shell.execute_reply.started": "2023-06-17T21:18:47.024762Z"
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"Write a python function that identifies if the number is a prime number?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
   }
  ],
  "metadata": {
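The markdown cell above mentions code-gecko, but the notebook only demonstrates code-bison; a minimal completion sketch under the same LLM integration (the prompt and output handling are assumptions, not notebook content):

completion_llm = VertexAI(model_name="code-gecko")

# code-gecko completes a code prefix rather than answering an instruction.
print(completion_llm("def reverse_string(s):"))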
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms.vertexai import _VertexAICommon
+from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -42,7 +42,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
 
     A sequence should be either (SystemMessage, HumanMessage, AIMessage,
     HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
-    AIMessage, ...).
+    AIMessage, ...). CodeChat does not support SystemMessage.
 
     Args:
         history: The list of messages to re-create the history of the chat.
@@ -82,10 +82,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
         try:
-            from vertexai.preview.language_models import ChatModel
+            if is_codey_model(values["model_name"]):
+                from vertexai.preview.language_models import CodeChatModel
+
+                values["client"] = CodeChatModel.from_pretrained(values["model_name"])
+            else:
+                from vertexai.preview.language_models import ChatModel
+
+                values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
             raise_vertex_import_error()
-        values["client"] = ChatModel.from_pretrained(values["model_name"])
         return values
 
     def _generate(
@@ -98,9 +104,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         """Generate next turn in the conversation.
 
         Args:
-            messages: The history of the conversation as a list of messages.
+            messages: The history of the conversation as a list of messages. Code chat
+                does not support context.
             stop: The list of stop words (optional).
-            run_manager: The Callbackmanager for LLM run, it's not used at the moment.
+            run_manager: The CallbackManager for LLM run, it's not used at the moment.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -121,10 +128,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
         params = {**self._default_params, **kwargs}
-        chat = self.client.start_chat(context=context, **params)
+        if not self.is_codey_model:
+            params["context"] = context
+        chat = self.client.start_chat(**params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
-        response = chat.send_message(question.content, **self._default_params)
+        response = chat.send_message(question.content, **params)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
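A practical consequence of this branch: the content of a SystemMessage is passed as the Vertex AI chat context only for non-Codey models; for codechat-bison it is dropped, since CodeChat does not support it (per the docstring change above). A sketch (the questions are illustrative):

from langchain.schema import HumanMessage, SystemMessage

# chat-bison: the SystemMessage content becomes the chat context.
chat = ChatVertexAI(model_name="chat-bison")
chat([
    SystemMessage(content="You answer in one short sentence."),
    HumanMessage(content="What is a prime number?"),
])

# codechat-bison: no context parameter, so start with a HumanMessage.
code_chat = ChatVertexAI(model_name="codechat-bison")
code_chat([HumanMessage(content="How do I memoize a recursive function?")])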
@@ -15,6 +15,10 @@ if TYPE_CHECKING:
     from vertexai.language_models._language_models import _LanguageModel
 
 
+def is_codey_model(model_name: str) -> bool:
+    return "code" in model_name
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
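The helper is a plain substring test, so it matches all three new model names; a quick illustration:

assert is_codey_model("code-bison")
assert is_codey_model("code-gecko")
assert is_codey_model("codechat-bison")
assert not is_codey_model("text-bison")  # default text model
assert not is_codey_model("chat-bison")  # default chat model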
@@ -25,10 +29,10 @@ class _VertexAICommon(BaseModel):
     "Token limit determines the maximum amount of text output from one prompt."
     top_p: float = 0.95
     "Tokens are selected from most probable to least until the sum of their "
-    "probabilities equals the top-p value."
+    "probabilities equals the top-p value. Top-p is ignored for Codey models."
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
-    "among the top-k most probable tokens."
+    "among the top-k most probable tokens. Top-k is ignored for Codey models."
     stop: Optional[List[str]] = None
     "Optional list of stop words to use when generating."
     project: Optional[str] = None
@@ -40,15 +44,24 @@ class _VertexAICommon(BaseModel):
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
 
+    @property
+    def is_codey_model(self) -> bool:
+        return is_codey_model(self.model_name)
+
     @property
     def _default_params(self) -> Dict[str, Any]:
-        base_params = {
-            "temperature": self.temperature,
-            "max_output_tokens": self.max_output_tokens,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-        }
-        return {**base_params}
+        if self.is_codey_model:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+            }
+        else:
+            return {
+                "temperature": self.temperature,
+                "max_output_tokens": self.max_output_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+            }
 
     def _predict(
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
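The effect of the new branch, sketched: for Codey models the sampling keys are omitted entirely rather than set to None (temperature and max_output_tokens keep their field defaults):

codey = VertexAI(model_name="code-bison")
assert "top_k" not in codey._default_params
assert "top_p" not in codey._default_params

text = VertexAI(model_name="text-bison")
assert text._default_params["top_k"] == 40    # field default shown above
assert text._default_params["top_p"] == 0.95  # field default shown above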
@@ -80,22 +93,32 @@ class VertexAI(_VertexAICommon, LLM):
     """Wrapper around Google Vertex AI large language models."""
 
     model_name: str = "text-bison"
     "The name of the Vertex AI large language model."
     tuned_model_name: Optional[str] = None
-    "The name of a tuned model, if it's provided, model_name is ignored."
+    "The name of a tuned model. If provided, model_name is ignored."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
         cls._try_init_vertexai(values)
+        tuned_model_name = values.get("tuned_model_name")
+        model_name = values["model_name"]
         try:
-            from vertexai.preview.language_models import TextGenerationModel
+            if tuned_model_name or not is_codey_model(model_name):
+                from vertexai.preview.language_models import TextGenerationModel
+
+                if tuned_model_name:
+                    values["client"] = TextGenerationModel.get_tuned_model(
+                        tuned_model_name
+                    )
+                else:
+                    values["client"] = TextGenerationModel.from_pretrained(model_name)
+            else:
+                from vertexai.preview.language_models import CodeGenerationModel
+
+                values["client"] = CodeGenerationModel.from_pretrained(model_name)
         except ImportError:
             raise_vertex_import_error()
-        tuned_model_name = values.get("tuned_model_name")
-        if tuned_model_name:
-            values["client"] = TextGenerationModel.get_tuned_model(tuned_model_name)
-        else:
-            values["client"] = TextGenerationModel.from_pretrained(values["model_name"])
         return values
 
     def _call(
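Note the order of the condition: a tuned_model_name always routes through TextGenerationModel.get_tuned_model, even if its name happens to contain "code". A sketch with a hypothetical model resource name:

# Loaded via TextGenerationModel.get_tuned_model, never CodeGenerationModel.
llm = VertexAI(tuned_model_name="projects/my-project/locations/us-central1/models/1234")  # hypothetical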
@@ -11,7 +11,7 @@ def raise_vertex_import_error() -> None:
     Raises:
         ImportError: an ImportError that mentions a required version of the SDK.
     """
-    sdk = "'google-cloud-aiplatform>=1.25.0'"
+    sdk = "'google-cloud-aiplatform>=1.26.0'"
     raise ImportError(
         "Could not import VertexAI. Please, install it with " f"pip install {sdk}"
     )