| 
					
				 | 
			
			
				@@ -1,4 +1,5 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import Optional 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -24,6 +25,7 @@ from core.model_runtime.errors.invoke import ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+logger = logging.getLogger(__name__) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class BedrockTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -53,17 +55,19 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         embeddings = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         token_usage = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model_prefix = model.split('.')[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if model_prefix == "amazon": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for text in texts: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                body = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "inputText": text, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                embeddings.extend([response_body.get('embedding')]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                token_usage += response_body.get('inputTextTokenCount') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            result = TextEmbeddingResult( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if model_prefix == "amazon" : 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           for text in texts: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              body = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                 "inputText": text, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              embeddings.extend([response_body.get('embedding')]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              token_usage += response_body.get('inputTextTokenCount') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           logger.warning(f'Total Tokens: {token_usage}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           result = TextEmbeddingResult( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 model=model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 embeddings=embeddings, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 usage=self._calc_response_usage( 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -71,11 +75,32 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     credentials=credentials, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     tokens=token_usage 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           return result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if model_prefix == "cohere" : 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           input_type = 'search_document' if len(texts) > 1 else 'search_query' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           for text in texts: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              body = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                 "texts": [text], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                 "input_type": input_type, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              embeddings.extend(response_body.get('embeddings')) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              token_usage += len(text) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           result = TextEmbeddingResult( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                model=model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                embeddings=embeddings, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                usage=self._calc_response_usage( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    model=model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    credentials=credentials, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    tokens=token_usage 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+           return result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #others 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: 
			 |