| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 | 
							- from typing import Any, Dict, Optional, Sequence
 
- from langchain.schema import Document
 
- from sqlalchemy import func
 
- from core.model_providers.model_factory import ModelFactory
 
- from extensions.ext_database import db
 
- from models.dataset import Dataset, DocumentSegment
 
- class DatesetDocumentStore:
 
-     def __init__(
 
-         self,
 
-         dataset: Dataset,
 
-         user_id: str,
 
-         document_id: Optional[str] = None,
 
-     ):
 
-         self._dataset = dataset
 
-         self._user_id = user_id
 
-         self._document_id = document_id
 
-     @classmethod
 
-     def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore":
 
-         return cls(**config_dict)
 
-     def to_dict(self) -> Dict[str, Any]:
 
-         """Serialize to dict."""
 
-         return {
 
-             "dataset_id": self._dataset.id,
 
-         }
 
-     @property
 
-     def dateset_id(self) -> Any:
 
-         return self._dataset.id
 
-     @property
 
-     def user_id(self) -> Any:
 
-         return self._user_id
 
-     @property
 
-     def docs(self) -> Dict[str, Document]:
 
-         document_segments = db.session.query(DocumentSegment).filter(
 
-             DocumentSegment.dataset_id == self._dataset.id
 
-         ).all()
 
-         output = {}
 
-         for document_segment in document_segments:
 
-             doc_id = document_segment.index_node_id
 
-             output[doc_id] = Document(
 
-                 page_content=document_segment.content,
 
-                 metadata={
 
-                     "doc_id": document_segment.index_node_id,
 
-                     "doc_hash": document_segment.index_node_hash,
 
-                     "document_id": document_segment.document_id,
 
-                     "dataset_id": document_segment.dataset_id,
 
-                 }
 
-             )
 
-         return output
 
-     def add_documents(
 
-         self, docs: Sequence[Document], allow_update: bool = True
 
-     ) -> None:
 
-         max_position = db.session.query(func.max(DocumentSegment.position)).filter(
 
-             DocumentSegment.document_id == self._document_id
 
-         ).scalar()
 
-         if max_position is None:
 
-             max_position = 0
 
-         embedding_model = ModelFactory.get_embedding_model(
 
-             tenant_id=self._dataset.tenant_id
 
-         )
 
-         for doc in docs:
 
-             if not isinstance(doc, Document):
 
-                 raise ValueError("doc must be a Document")
 
-             segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
 
-             # NOTE: doc could already exist in the store, but we overwrite it
 
-             if not allow_update and segment_document:
 
-                 raise ValueError(
 
-                     f"doc_id {doc.metadata['doc_id']} already exists. "
 
-                     "Set allow_update to True to overwrite."
 
-                 )
 
-             # calc embedding use tokens
 
-             tokens = embedding_model.get_num_tokens(doc.page_content)
 
-             if not segment_document:
 
-                 max_position += 1
 
-                 segment_document = DocumentSegment(
 
-                     tenant_id=self._dataset.tenant_id,
 
-                     dataset_id=self._dataset.id,
 
-                     document_id=self._document_id,
 
-                     index_node_id=doc.metadata['doc_id'],
 
-                     index_node_hash=doc.metadata['doc_hash'],
 
-                     position=max_position,
 
-                     content=doc.page_content,
 
-                     word_count=len(doc.page_content),
 
-                     tokens=tokens,
 
-                     created_by=self._user_id,
 
-                 )
 
-                 if 'answer' in doc.metadata and doc.metadata['answer']:
 
-                     segment_document.answer = doc.metadata.pop('answer', '')
 
-                 db.session.add(segment_document)
 
-             else:
 
-                 segment_document.content = doc.page_content
 
-                 if 'answer' in doc.metadata and doc.metadata['answer']:
 
-                     segment_document.answer = doc.metadata.pop('answer', '')
 
-                 segment_document.index_node_hash = doc.metadata['doc_hash']
 
-                 segment_document.word_count = len(doc.page_content)
 
-                 segment_document.tokens = tokens
 
-             db.session.commit()
 
-     def document_exists(self, doc_id: str) -> bool:
 
-         """Check if document exists."""
 
-         result = self.get_document_segment(doc_id)
 
-         return result is not None
 
-     def get_document(
 
-         self, doc_id: str, raise_error: bool = True
 
-     ) -> Optional[Document]:
 
-         document_segment = self.get_document_segment(doc_id)
 
-         if document_segment is None:
 
-             if raise_error:
 
-                 raise ValueError(f"doc_id {doc_id} not found.")
 
-             else:
 
-                 return None
 
-         return Document(
 
-             page_content=document_segment.content,
 
-             metadata={
 
-                 "doc_id": document_segment.index_node_id,
 
-                 "doc_hash": document_segment.index_node_hash,
 
-                 "document_id": document_segment.document_id,
 
-                 "dataset_id": document_segment.dataset_id,
 
-             }
 
-         )
 
-     def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
 
-         document_segment = self.get_document_segment(doc_id)
 
-         if document_segment is None:
 
-             if raise_error:
 
-                 raise ValueError(f"doc_id {doc_id} not found.")
 
-             else:
 
-                 return None
 
-         db.session.delete(document_segment)
 
-         db.session.commit()
 
-     def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
 
-         """Set the hash for a given doc_id."""
 
-         document_segment = self.get_document_segment(doc_id)
 
-         if document_segment is None:
 
-             return None
 
-         document_segment.index_node_hash = doc_hash
 
-         db.session.commit()
 
-     def get_document_hash(self, doc_id: str) -> Optional[str]:
 
-         """Get the stored hash for a document, if it exists."""
 
-         document_segment = self.get_document_segment(doc_id)
 
-         if document_segment is None:
 
-             return None
 
-         return document_segment.index_node_hash
 
-     def get_document_segment(self, doc_id: str) -> DocumentSegment:
 
-         document_segment = db.session.query(DocumentSegment).filter(
 
-             DocumentSegment.dataset_id == self._dataset.id,
 
-             DocumentSegment.index_node_id == doc_id
 
-         ).first()
 
-         return document_segment
 
 
  |