diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py index 2426587f58d..aaa5efbecb7 100644 --- a/libs/langchain/langchain/llms/vertexai.py +++ b/libs/langchain/langchain/llms/vertexai.py @@ -169,7 +169,7 @@ class VertexAI(_VertexAICommon, LLM): tuned_model_name = values.get("tuned_model_name") model_name = values["model_name"] try: - if tuned_model_name or not is_codey_model(model_name): + if not is_codey_model(model_name): from vertexai.preview.language_models import TextGenerationModel if tuned_model_name: @@ -181,7 +181,12 @@ class VertexAI(_VertexAICommon, LLM): else: from vertexai.preview.language_models import CodeGenerationModel - values["client"] = CodeGenerationModel.from_pretrained(model_name) + if tuned_model_name: + values["client"] = CodeGenerationModel.get_tuned_model( + tuned_model_name + ) + else: + values["client"] = CodeGenerationModel.from_pretrained(model_name) except ImportError: raise_vertex_import_error() return values