import json
import threading
from typing import Type, Optional, List

from flask import current_app
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""
    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    top_k: int = 2
    score_threshold: Optional[float] = None
    conversation_message_task: ConversationMessageTask
    return_resource: bool
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return ''
        # get retrieval model , if the model is not setting , using default
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

        if dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
            return str("\n".join([document.page_content for document in documents]))
        else:

            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except LLMBadRequestError:
                return ''
            except ProviderTokenNotInitError:
                return ''

            embeddings = CacheEmbedding(embedding_model)

            documents = []
            threads = []
            if self.top_k > 0:
                # retrieval source with semantic
                if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
                    embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'dataset_id': str(dataset.id),
                        'query': query,
                        'top_k': self.top_k,
                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
                            'score_threshold_enable'] else None,
                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                            'reranking_enable'] else None,
                        'all_documents': documents,
                        'search_method': retrieval_model['search_method'],
                        'embeddings': embeddings
                    })
                    threads.append(embedding_thread)
                    embedding_thread.start()

                # retrieval_model source with full text
                if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
                    full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'dataset_id': str(dataset.id),
                        'query': query,
                        'search_method': retrieval_model['search_method'],
                        'embeddings': embeddings,
                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
                            'score_threshold_enable'] else None,
                        'top_k': self.top_k,
                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                            'reranking_enable'] else None,
                        'all_documents': documents
                    })
                    threads.append(full_text_index_thread)
                    full_text_index_thread.start()

                for thread in threads:
                    thread.join()
                # hybrid search: rerank after all documents have been searched
                if retrieval_model['search_method'] == 'hybrid_search':
                    hybrid_rerank = ModelFactory.get_reranking_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
                        model_name=retrieval_model['reranking_model']['reranking_model_name']
                    )
                    documents = hybrid_rerank.rerank(query, documents,
                                                     retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
                                                     self.top_k)
            else:
                documents = []

            hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
            hit_callback.on_tool_end(documents)
            document_score_list = {}
            if dataset.indexing_technique != "economy":
                for item in documents:
                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
                                                    DocumentSegment.completed_at.isnot(None),
                                                    DocumentSegment.status == 'completed',
                                                    DocumentSegment.enabled == True,
                                                    DocumentSegment.index_node_id.in_(index_node_ids)
                                                    ).all()

            if segments:
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(segments,
                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                           float('inf')))
                for segment in sorted_segments:
                    if segment.answer:
                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)
                if self.return_resource:
                    context_list = []
                    resource_number = 1
                    for segment in sorted_segments:
                        context = {}
                        document = Document.query.filter(Document.id == segment.document_id,
                                                         Document.enabled == True,
                                                         Document.archived == False,
                                                         ).first()
                        if dataset and document:
                            source = {
                                'position': resource_number,
                                'dataset_id': dataset.id,
                                'dataset_name': dataset.name,
                                'document_id': document.id,
                                'document_name': document.name,
                                'data_source_type': document.data_source_type,
                                'segment_id': segment.id,
                                'retriever_from': self.retriever_from,
                                'score': document_score_list.get(segment.index_node_id, None)

                            }
                            if self.retriever_from == 'dev':
                                source['hit_count'] = segment.hit_count
                                source['word_count'] = segment.word_count
                                source['segment_position'] = segment.position
                                source['index_node_hash'] = segment.index_node_hash
                            if segment.answer:
                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                            else:
                                source['content'] = segment.content
                            context_list.append(source)
                        resource_number += 1
                    hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()