# opengauss.py — OpenGauss vector database adapter.
  1. import json
  2. import uuid
  3. from contextlib import contextmanager
  4. from typing import Any
  5. import psycopg2.extras # type: ignore
  6. import psycopg2.pool # type: ignore
  7. from pydantic import BaseModel, model_validator
  8. from configs import dify_config
  9. from core.rag.datasource.vdb.vector_base import BaseVector
  10. from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
  11. from core.rag.datasource.vdb.vector_type import VectorType
  12. from core.rag.embedding.embedding_base import Embeddings
  13. from core.rag.models.document import Document
  14. from extensions.ext_redis import redis_client
  15. from models.dataset import Dataset
  16. class OpenGaussConfig(BaseModel):
  17. host: str
  18. port: int
  19. user: str
  20. password: str
  21. database: str
  22. min_connection: int
  23. max_connection: int
  24. @model_validator(mode="before")
  25. @classmethod
  26. def validate_config(cls, values: dict) -> dict:
  27. if not values["host"]:
  28. raise ValueError("config OPENGAUSS_HOST is required")
  29. if not values["port"]:
  30. raise ValueError("config OPENGAUSS_PORT is required")
  31. if not values["user"]:
  32. raise ValueError("config OPENGAUSS_USER is required")
  33. if not values["password"]:
  34. raise ValueError("config OPENGAUSS_PASSWORD is required")
  35. if not values["database"]:
  36. raise ValueError("config OPENGAUSS_DATABASE is required")
  37. if not values["min_connection"]:
  38. raise ValueError("config OPENGAUSS_MIN_CONNECTION is required")
  39. if not values["max_connection"]:
  40. raise ValueError("config OPENGAUSS_MAX_CONNECTION is required")
  41. if values["min_connection"] > values["max_connection"]:
  42. raise ValueError("config OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION")
  43. return values
# DDL template for a per-collection embedding table.  {table_name} and
# {dimension} are filled in by str.format in OpenGauss._create_collection;
# the vector(...) column type comes from the database's vector extension.
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
);
"""
# HNSW cosine-distance index over the embedding column.  Only created for
# dimensions <= 2000 (see OpenGauss._create_collection) — presumably an
# index dimensionality limit of the vector extension; TODO confirm.
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
  56. class OpenGauss(BaseVector):
  57. def __init__(self, collection_name: str, config: OpenGaussConfig):
  58. super().__init__(collection_name)
  59. self.pool = self._create_connection_pool(config)
  60. self.table_name = f"embedding_{collection_name}"
  61. def get_type(self) -> str:
  62. return VectorType.OPENGAUSS
  63. def _create_connection_pool(self, config: OpenGaussConfig):
  64. return psycopg2.pool.SimpleConnectionPool(
  65. config.min_connection,
  66. config.max_connection,
  67. host=config.host,
  68. port=config.port,
  69. user=config.user,
  70. password=config.password,
  71. database=config.database,
  72. )
  73. @contextmanager
  74. def _get_cursor(self):
  75. conn = self.pool.getconn()
  76. cur = conn.cursor()
  77. try:
  78. yield cur
  79. finally:
  80. cur.close()
  81. conn.commit()
  82. self.pool.putconn(conn)
  83. def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
  84. dimension = len(embeddings[0])
  85. self._create_collection(dimension)
  86. return self.add_texts(texts, embeddings)
  87. def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
  88. values = []
  89. pks = []
  90. for i, doc in enumerate(documents):
  91. if doc.metadata is not None:
  92. doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
  93. pks.append(doc_id)
  94. values.append(
  95. (
  96. doc_id,
  97. doc.page_content,
  98. json.dumps(doc.metadata),
  99. embeddings[i],
  100. )
  101. )
  102. with self._get_cursor() as cur:
  103. psycopg2.extras.execute_values(
  104. cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
  105. )
  106. return pks
  107. def text_exists(self, id: str) -> bool:
  108. with self._get_cursor() as cur:
  109. cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
  110. return cur.fetchone() is not None
  111. def get_by_ids(self, ids: list[str]) -> list[Document]:
  112. with self._get_cursor() as cur:
  113. cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
  114. docs = []
  115. for record in cur:
  116. docs.append(Document(page_content=record[1], metadata=record[0]))
  117. return docs
  118. def delete_by_ids(self, ids: list[str]) -> None:
  119. # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
  120. # Scenario 1: extract a document fails, resulting in a table not being created.
  121. # Then clicking the retry button triggers a delete operation on an empty list.
  122. if not ids:
  123. return
  124. with self._get_cursor() as cur:
  125. cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
  126. def delete_by_metadata_field(self, key: str, value: str) -> None:
  127. with self._get_cursor() as cur:
  128. cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
  129. def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
  130. """
  131. Search the nearest neighbors to a vector.
  132. :param query_vector: The input vector to search for similar items.
  133. :param top_k: The number of nearest neighbors to return, default is 5.
  134. :return: List of Documents that are nearest to the query vector.
  135. """
  136. top_k = kwargs.get("top_k", 4)
  137. if not isinstance(top_k, int) or top_k <= 0:
  138. raise ValueError("top_k must be a positive integer")
  139. with self._get_cursor() as cur:
  140. cur.execute(
  141. f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
  142. f" ORDER BY distance LIMIT {top_k}",
  143. (json.dumps(query_vector),),
  144. )
  145. docs = []
  146. score_threshold = float(kwargs.get("score_threshold") or 0.0)
  147. for record in cur:
  148. metadata, text, distance = record
  149. score = 1 - distance
  150. metadata["score"] = score
  151. if score > score_threshold:
  152. docs.append(Document(page_content=text, metadata=metadata))
  153. return docs
  154. def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
  155. top_k = kwargs.get("top_k", 5)
  156. if not isinstance(top_k, int) or top_k <= 0:
  157. raise ValueError("top_k must be a positive integer")
  158. with self._get_cursor() as cur:
  159. cur.execute(
  160. f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
  161. FROM {self.table_name}
  162. WHERE to_tsvector(text) @@ plainto_tsquery(%s)
  163. ORDER BY score DESC
  164. LIMIT {top_k}""",
  165. # f"'{query}'" is required in order to account for whitespace in query
  166. (f"'{query}'", f"'{query}'"),
  167. )
  168. docs = []
  169. for record in cur:
  170. metadata, text, score = record
  171. metadata["score"] = score
  172. docs.append(Document(page_content=text, metadata=metadata))
  173. return docs
  174. def delete(self) -> None:
  175. with self._get_cursor() as cur:
  176. cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
  177. def _create_collection(self, dimension: int):
  178. cache_key = f"vector_indexing_{self._collection_name}"
  179. lock_name = f"{cache_key}_lock"
  180. with redis_client.lock(lock_name, timeout=20):
  181. collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
  182. if redis_client.get(collection_exist_cache_key):
  183. return
  184. with self._get_cursor() as cur:
  185. cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
  186. if dimension <= 2000:
  187. cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
  188. redis_client.set(collection_exist_cache_key, 1, ex=3600)
  189. class OpenGaussFactory(AbstractVectorFactory):
  190. def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenGauss:
  191. if dataset.index_struct_dict:
  192. class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
  193. collection_name = class_prefix
  194. else:
  195. dataset_id = dataset.id
  196. collection_name = Dataset.gen_collection_name_by_id(dataset_id)
  197. dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENGAUSS, collection_name))
  198. return OpenGauss(
  199. collection_name=collection_name,
  200. config=OpenGaussConfig(
  201. host=dify_config.OPENGAUSS_HOST or "localhost",
  202. port=dify_config.OPENGAUSS_PORT,
  203. user=dify_config.OPENGAUSS_USER or "postgres",
  204. password=dify_config.OPENGAUSS_PASSWORD or "",
  205. database=dify_config.OPENGAUSS_DATABASE or "dify",
  206. min_connection=dify_config.OPENGAUSS_MIN_CONNECTION,
  207. max_connection=dify_config.OPENGAUSS_MAX_CONNECTION,
  208. ),
  209. )