test_vector_store.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import uuid
  2. from unittest.mock import MagicMock
  3. import pytest
  4. from core.rag.models.document import Document
  5. from extensions import ext_redis
  6. from models.dataset import Dataset
  7. def get_sample_text() -> str:
  8. return 'test_text'
  9. def get_sample_embedding() -> list[float]:
  10. return [1.1, 2.2, 3.3]
  11. def get_sample_query_vector() -> list[float]:
  12. return get_sample_embedding()
  13. def get_sample_document(sample_dataset_id: str) -> Document:
  14. doc = Document(
  15. page_content=get_sample_text(),
  16. metadata={
  17. "doc_id": sample_dataset_id,
  18. "doc_hash": sample_dataset_id,
  19. "document_id": sample_dataset_id,
  20. "dataset_id": sample_dataset_id,
  21. }
  22. )
  23. return doc
  24. @pytest.fixture
  25. def setup_mock_redis() -> None:
  26. # get
  27. ext_redis.redis_client.get = MagicMock(return_value=None)
  28. # set
  29. ext_redis.redis_client.set = MagicMock(return_value=None)
  30. # lock
  31. mock_redis_lock = MagicMock()
  32. mock_redis_lock.__enter__ = MagicMock()
  33. mock_redis_lock.__exit__ = MagicMock()
  34. ext_redis.redis_client.lock = mock_redis_lock
  35. class AbstractTestVector:
  36. def __init__(self):
  37. self.vector = None
  38. self.dataset_id = str(uuid.uuid4())
  39. self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
  40. def create_vector(self) -> None:
  41. self.vector.create(
  42. texts=[get_sample_document(self.dataset_id)],
  43. embeddings=[get_sample_embedding()],
  44. )
  45. def search_by_vector(self):
  46. hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector())
  47. assert len(hits_by_vector) >= 1
  48. def search_by_full_text(self):
  49. hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
  50. assert len(hits_by_full_text) >= 1
  51. def delete_vector(self):
  52. self.vector.delete()
  53. def delete_by_ids(self):
  54. self.vector.delete_by_ids([self.dataset_id])
  55. def add_texts(self):
  56. self.vector.add_texts(
  57. documents=[
  58. get_sample_document(str(uuid.uuid4())),
  59. get_sample_document(str(uuid.uuid4())),
  60. ],
  61. embeddings=[
  62. get_sample_embedding(),
  63. get_sample_embedding(),
  64. ],
  65. )
  66. def text_exists(self):
  67. self.vector.text_exists(self.dataset_id)
  68. def delete_document_by_id(self):
  69. with pytest.raises(NotImplementedError):
  70. self.vector.delete_by_document_id(self.dataset_id)
  71. def get_ids_by_metadata_field(self):
  72. with pytest.raises(NotImplementedError):
  73. self.vector.get_ids_by_metadata_field('key', 'value')
  74. def run_all_tests(self):
  75. self.create_vector()
  76. self.search_by_vector()
  77. self.search_by_full_text()
  78. self.text_exists()
  79. self.get_ids_by_metadata_field()
  80. self.add_texts()
  81. self.delete_document_by_id()
  82. self.delete_by_ids()
  83. self.delete_vector()