Selaa lähdekoodia

fix: AnalyticdbVector retrieval scores (#8803)

8bitpd 9 kuukautta sitten
vanhempi
commit
4c1063e1c5
1 muutettua tiedostoa jossa 6 lisäystä ja 13 poistoa
  1. 6 13
      api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

+ 6 - 13
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -40,19 +40,8 @@ class AnalyticdbConfig(BaseModel):
 
 
 class AnalyticdbVector(BaseVector):
-    _instance = None
-    _init = False
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
     def __init__(self, collection_name: str, config: AnalyticdbConfig):
-        # collection_name must be updated every time
         self._collection_name = collection_name.lower()
-        if AnalyticdbVector._init:
-            return
         try:
             from alibabacloud_gpdb20160503.client import Client
             from alibabacloud_tea_openapi import models as open_api_models
@@ -62,7 +51,6 @@ class AnalyticdbVector(BaseVector):
         self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
         self._client = Client(self._client_config)
         self._initialize()
-        AnalyticdbVector._init = True
 
     def _initialize(self) -> None:
         cache_key = f"vector_indexing_{self.config.instance_id}"
@@ -257,11 +245,14 @@ class AnalyticdbVector(BaseVector):
         documents = []
         for match in response.body.matches.match:
             if match.score > score_threshold:
+                metadata = json.loads(match.metadata.get("metadata_"))
+                metadata["score"] = match.score
                 doc = Document(
                     page_content=match.metadata.get("page_content"),
-                    metadata=json.loads(match.metadata.get("metadata_")),
+                    metadata=metadata,
                 )
                 documents.append(doc)
+        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
         return documents
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -286,12 +277,14 @@ class AnalyticdbVector(BaseVector):
         for match in response.body.matches.match:
             if match.score > score_threshold:
                 metadata = json.loads(match.metadata.get("metadata_"))
+                metadata["score"] = match.score
                 doc = Document(
                     page_content=match.metadata.get("page_content"),
                     vector=match.metadata.get("vector"),
                     metadata=metadata,
                 )
                 documents.append(doc)
+        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
         return documents
 
     def delete(self) -> None: