ソースを参照

pref: change ollama embedded api request (#6876)

灰灰 8 ヶ月 前
コミット
56af1a0adf
共有1 個のファイルを変更した21 個の追加25 個の削除を含む
  1. 21 25
      api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py

+ 21 - 25
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py

@@ -59,7 +59,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
         if not endpoint_url.endswith('/'):
             endpoint_url += '/'
 
-        endpoint_url = urljoin(endpoint_url, 'api/embeddings')
+        endpoint_url = urljoin(endpoint_url, 'api/embed')
 
         # get model properties
         context_size = self._get_context_size(model, credentials)
@@ -78,32 +78,28 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
             else:
                 inputs.append(text)
 
-        batched_embeddings = []
-
-        for text in inputs:
-            # Prepare the payload for the request
-            payload = {
-                'prompt': text,
-                'model': model,
-            }
-
-            # Make the request to the OpenAI API
-            response = requests.post(
-                endpoint_url,
-                headers=headers,
-                data=json.dumps(payload),
-                timeout=(10, 300)
-            )
+        # Prepare the payload for the request
+        payload = {
+            'input': inputs,
+            'model': model,
+        }
+
+        # Make the request to the OpenAI API
+        response = requests.post(
+            endpoint_url,
+            headers=headers,
+            data=json.dumps(payload),
+            timeout=(10, 300)
+        )
 
-            response.raise_for_status()  # Raise an exception for HTTP errors
-            response_data = response.json()
+        response.raise_for_status()  # Raise an exception for HTTP errors
+        response_data = response.json()
 
-            # Extract embeddings and used tokens from the response
-            embeddings = response_data['embedding']
-            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])
+        # Extract embeddings and used tokens from the response
+        embeddings = response_data['embeddings']
+        embedding_used_tokens = self.get_num_tokens(model, credentials, inputs)
 
-            used_tokens += embedding_used_tokens
-            batched_embeddings.append(embeddings)
+        used_tokens += embedding_used_tokens
 
         # calc usage
         usage = self._calc_response_usage(
@@ -113,7 +109,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
         )
 
         return TextEmbeddingResult(
-            embeddings=batched_embeddings,
+            embeddings=embeddings,
             usage=usage,
             model=model
         )