瀏覽代碼

fix: azure embedding not support batch (#188)

John Wang 2 年之前
父節點
當前提交
d93365d429
共有 1 個文件被更改,包括 14 次插入0 次删除
  1. 14 0
      api/core/embedding/openai_embedding.py

+ 14 - 0
api/core/embedding/openai_embedding.py

@@ -173,6 +173,13 @@ class OpenAIEmbedding(BaseEmbedding):
         Can be overriden for batch queries.
 
         """
+        if self.openai_api_type and self.openai_api_type == 'azure':
+            embeddings = []
+            for text in texts:
+                embeddings.append(self._get_text_embedding(text))
+
+            return embeddings
+
         if self.deployment_name is not None:
             engine = self.deployment_name
         else:
@@ -187,6 +194,13 @@ class OpenAIEmbedding(BaseEmbedding):
 
     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Asynchronously get text embeddings."""
+        if self.openai_api_type and self.openai_api_type == 'azure':
+            embeddings = []
+            for text in texts:
+                embeddings.append(await self._aget_text_embedding(text))
+
+            return embeddings
+
         if self.deployment_name is not None:
             engine = self.deployment_name
         else: