import uuid
from unittest.mock import MagicMock

import pytest

from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset


def get_example_text() -> str:
    """Return the fixed page content used by every example document."""
    return "test_text"


def get_example_document(doc_id: str) -> Document:
    """Build an example ``Document`` with page content ``get_example_text()``.

    The metadata keys ``doc_id``, ``doc_hash``, ``document_id`` and
    ``dataset_id`` are all set to the same value, *doc_id*.

    Args:
        doc_id: Identifier written into every metadata id field.

    Returns:
        The constructed ``Document``.
    """
    doc = Document(
        page_content=get_example_text(),
        metadata={
            "doc_id": doc_id,
            "doc_hash": doc_id,
            "document_id": doc_id,
            "dataset_id": doc_id,
        },
    )
    return doc


@pytest.fixture
def setup_mock_redis() -> None:
    """Replace ``get``/``set``/``lock`` on the global Redis client with mocks.

    ``get`` and ``set`` always return ``None``. Calling ``lock(...)`` returns a
    mock lock that can be used as a context manager; its ``__exit__`` returns
    ``False``, so exceptions raised inside the ``with`` body propagate.

    NOTE(review): the mocks are assigned directly onto
    ``ext_redis.redis_client`` and are not restored after the test, so they
    stay in place for the rest of the process.
    """
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)
    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)
    # lock: a ``with redis_client.lock(name):`` statement enters the object
    # *returned* by ``lock(...)``, so that is the object configured here.
    mock_lock = MagicMock()
    mock_lock.__enter__.return_value = mock_lock
    # Returning a falsy value from __exit__ keeps exceptions raised inside the
    # locked region visible; a truthy return would suppress them.
    mock_lock.__exit__.return_value = False
    ext_redis.redis_client.lock = MagicMock(return_value=mock_lock)


class AbstractVectorTest:
    """Reusable end-to-end scenario for a vector-store implementation.

    Every operation goes through ``self.vector``. That attribute is ``None``
    here, so a subclass must assign a concrete vector store before any test
    method runs (presumably in its own ``__init__`` — confirm per subclass).
    """

    def __init__(self):
        # Must be replaced with a real vector store by the subclass.
        self.vector = None
        self.dataset_id = str(uuid.uuid4())
        # "_test" suffix keeps the collection name distinct from the name
        # Dataset.gen_collection_name_by_id would produce for the same id.
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
        self.example_doc_id = str(uuid.uuid4())
        # A deterministic 128-dimensional embedding shared by all test documents.
        self.example_embedding = [1.001 * i for i in range(128)]

    def create_vector(self) -> None:
        """Create the store with a single example document and its embedding."""
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )

    def search_by_vector(self):
        """Assert that a vector search returns exactly the example document."""
        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1
        assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id

    def search_by_full_text(self):
        """Assert that a full-text search returns exactly the example document."""
        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 1
        assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id

    def delete_vector(self):
        """Delete the whole vector store."""
        self.vector.delete()

    def delete_by_ids(self, ids: list[str]):
        """Delete the documents whose ids are listed in *ids*."""
        self.vector.delete_by_ids(ids=ids)

    def add_texts(self) -> list[str]:
        """Add a batch of 100 documents with fresh random ids.

        Every document uses ``self.example_embedding``.

        Returns:
            The ``doc_id`` metadata values of the added documents.
        """
        batch_size = 100
        documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
        embeddings = [self.example_embedding] * batch_size
        self.vector.add_texts(documents=documents, embeddings=embeddings)
        return [doc.metadata["doc_id"] for doc in documents]

    def text_exists(self):
        """Assert that the store reports the example document as present."""
        assert self.vector.text_exists(self.example_doc_id)

    def get_ids_by_metadata_field(self):
        """Assert that metadata-field lookup raises ``NotImplementedError``.

        NOTE(review): stores that do implement this lookup presumably override
        this method.
        """
        with pytest.raises(NotImplementedError):
            self.vector.get_ids_by_metadata_field(key="key", value="value")

    def run_all_tests(self):
        """Run the full scenario in order against ``self.vector``."""
        self.create_vector()
        # The search checks expect exactly one hit, so they must run before
        # add_texts inserts 100 more documents with the same embedding.
        self.search_by_vector()
        self.search_by_full_text()
        self.text_exists()
        self.get_ids_by_metadata_field()
        added_doc_ids = self.add_texts()
        self.delete_by_ids(added_doc_ids)
        self.delete_vector()