diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py index e8a998cde63..70a6ab663fc 100644 --- a/libs/langchain/langchain/llms/vertexai.py +++ b/libs/langchain/langchain/llms/vertexai.py @@ -299,7 +299,12 @@ class VertexAI(_VertexAICommon, BaseLLM): res = completion_with_retry( self, prompt, run_manager=run_manager, **params ) - generations.append([_response_to_generation(r) for r in res.candidates]) + if self.is_codey_model: + generations.append([_response_to_generation(res)]) + else: + generations.append( + [_response_to_generation(r) for r in res.candidates] + ) return LLMResult(generations=generations) async def _agenerate( diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py index 2bee6af181c..20cfb6d85ba 100644 --- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py @@ -36,6 +36,15 @@ def test_vertex_generate() -> None: assert len(output.generations[0]) == 2 +@pytest.mark.scheduled +def test_vertex_generate_code() -> None: + llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001") + output = llm.generate(["generate a python method that says foo:"]) + assert isinstance(output, LLMResult) + assert len(output.generations) == 1 + assert len(output.generations[0]) == 2 + + @pytest.mark.scheduled @pytest.mark.asyncio async def test_vertex_agenerate() -> None: