diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py index 1713af650c6..8bc14726333 100644 --- a/libs/langchain/langchain/llms/bedrock.py +++ b/libs/langchain/langchain/llms/bedrock.py @@ -65,6 +65,7 @@ class LLMInputOutputAdapter: provider_to_output_key_map = { "anthropic": "completion", "amazon": "outputText", + "cohere": "text", } @classmethod @@ -74,7 +75,7 @@ class LLMInputOutputAdapter: input_body = {**model_kwargs} if provider == "anthropic": input_body["prompt"] = _human_assistant_format(prompt) - elif provider == "ai21": + elif provider == "ai21" or provider == "cohere": input_body["prompt"] = prompt elif provider == "amazon": input_body = dict() @@ -98,6 +99,8 @@ class LLMInputOutputAdapter: if provider == "ai21": return response_body.get("completions")[0].get("data").get("text") + elif provider == "cohere": + return response_body.get("generations")[0].get("text") else: return response_body.get("results")[0].get("outputText") @@ -119,6 +122,12 @@ class LLMInputOutputAdapter: chunk = event.get("chunk") if chunk: chunk_obj = json.loads(chunk.get("bytes").decode()) + if provider == "cohere" and ( + chunk_obj["is_finished"] + or chunk_obj[cls.provider_to_output_key_map[provider]] + == "" + ): + return # chunk obj format varies with provider yield GenerationChunk( @@ -159,6 +168,7 @@ class BedrockBase(BaseModel, ABC): "anthropic": "stop_sequences", "amazon": "stopSequences", "ai21": "stop_sequences", + "cohere": "stop_sequences", } @root_validator() @@ -259,9 +269,10 @@ class BedrockBase(BaseModel, ABC): # stop sequence from _generate() overrides # stop sequences in the class attribute - _model_kwargs[ - self.provider_stop_sequence_key_name_map.get(provider), - ] = stop + _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop + + if provider == "cohere": + _model_kwargs["stream"] = True params = {**_model_kwargs, **kwargs} input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)