|
@@ -49,7 +49,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
retrieval_thread = threading.Thread(
|
|
|
target=self._retriever,
|
|
|
kwargs={
|
|
|
- "flask_app": current_app._get_current_object(),
|
|
|
+ "flask_app": current_app._get_current_object(), # type: ignore
|
|
|
"dataset_id": dataset_id,
|
|
|
"query": query,
|
|
|
"all_documents": all_documents,
|
|
@@ -77,11 +77,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
|
document_score_list = {}
|
|
|
for item in all_documents:
|
|
|
+ assert item.metadata
|
|
|
if item.metadata.get("score"):
|
|
|
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
|
|
|
|
|
document_context_list = []
|
|
|
- index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
|
|
+ index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
|
|
segments = DocumentSegment.query.filter(
|
|
|
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
|
|
DocumentSegment.completed_at.isnot(None),
|
|
@@ -140,6 +141,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
|
return str("\n".join(document_context_list))
|
|
|
|
|
|
+ raise RuntimeError("not segments found")
|
|
|
+
|
|
|
def _retriever(
|
|
|
self,
|
|
|
flask_app: Flask,
|