소스 검색

fix: Incorrect order of embedded documents in CacheEmbedding (#1671)

Wen O.Y 1 년 전
부모
커밋
a1cd043fdc
1개의 변경된 파일11개의 추가작업 그리고 15개의 파일을 삭제
  1. 11 15
      api/core/embedding/cached_embedding.py

+ 11 - 15
api/core/embedding/cached_embedding.py

@@ -18,31 +18,30 @@ class CacheEmbedding(Embeddings):
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Embed search docs."""
         """Embed search docs."""
         # use doc embedding cache or store if not exists
         # use doc embedding cache or store if not exists
-        text_embeddings = []
-        embedding_queue_texts = []
-        for text in texts:
+        text_embeddings = [None for _ in range(len(texts))]
+        embedding_queue_indices = []
+        for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
             hash = helper.generate_text_hash(text)
             embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
             embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
             if embedding:
             if embedding:
-                text_embeddings.append(embedding.get_embedding())
+                text_embeddings[i] = embedding.get_embedding()
             else:
             else:
-                embedding_queue_texts.append(text)
+                embedding_queue_indices.append(i)
 
 
-        if embedding_queue_texts:
+        if embedding_queue_indices:
             try:
             try:
-                embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
+                embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices])
             except Exception as ex:
             except Exception as ex:
                 raise self._embeddings.handle_exceptions(ex)
                 raise self._embeddings.handle_exceptions(ex)
-            i = 0
-            normalized_embedding_results = []
-            for text in embedding_queue_texts:
-                hash = helper.generate_text_hash(text)
+
+            for i, indice in enumerate(embedding_queue_indices):
+                hash = helper.generate_text_hash(texts[indice])
 
 
                 try:
                 try:
                     embedding = Embedding(model_name=self._embeddings.name, hash=hash)
                     embedding = Embedding(model_name=self._embeddings.name, hash=hash)
                     vector = embedding_results[i]
                     vector = embedding_results[i]
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
-                    normalized_embedding_results.append(normalized_embedding)
+                    text_embeddings[indice] = normalized_embedding
                     embedding.set_embedding(normalized_embedding)
                     embedding.set_embedding(normalized_embedding)
                     db.session.add(embedding)
                     db.session.add(embedding)
                     db.session.commit()
                     db.session.commit()
@@ -52,10 +51,7 @@ class CacheEmbedding(Embeddings):
                 except:
                 except:
                     logging.exception('Failed to add embedding to db')
                     logging.exception('Failed to add embedding to db')
                     continue
                     continue
-                finally:
-                    i += 1
 
 
-            text_embeddings.extend(normalized_embedding_results)
         return text_embeddings
         return text_embeddings
 
 
     def embed_query(self, text: str) -> List[float]:
     def embed_query(self, text: str) -> List[float]: