import binascii
from collections.abc import Generator, Sequence
from typing import IO, Optional

from core.model_runtime.entities.llm_entities import LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import (
    PluginBasicBooleanResponse,
    PluginDaemonInnerError,
    PluginLLMNumTokensResponse,
    PluginModelProviderEntity,
    PluginModelSchemaEntity,
    PluginStringResultResponse,
    PluginTextEmbeddingNumTokensResponse,
    PluginVoicesResponse,
)
from core.plugin.manager.base import BasePluginManager


class PluginModelManager(BasePluginManager):
    def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
        """
        Fetch model providers for the given tenant.
"""        response = self._request_with_plugin_daemon_response(            "GET",            f"plugin/{tenant_id}/management/models",            list[PluginModelProviderEntity],            params={"page": 1, "page_size": 256},        )        return response    def get_model_schema(        self,        tenant_id: str,        user_id: str,        plugin_id: str,        provider: str,        model_type: str,        model: str,        credentials: dict,    ) -> AIModelEntity | None:        """        Get model schema        """        response = self._request_with_plugin_daemon_response_stream(            "POST",            f"plugin/{tenant_id}/dispatch/model/schema",            PluginModelSchemaEntity,            data={                "user_id": user_id,                "data": {                    "provider": provider,                    "model_type": model_type,                    "model": model,                    "credentials": credentials,                },            },            headers={                "X-Plugin-ID": plugin_id,                "Content-Type": "application/json",            },        )        for resp in response:            return resp.model_schema        return None    def validate_provider_credentials(        self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict    ) -> bool:        """        validate the credentials of the provider        """        response = self._request_with_plugin_daemon_response_stream(            "POST",            f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",            PluginBasicBooleanResponse,            data={                "user_id": user_id,                "data": {                    "provider": provider,                    "credentials": credentials,                },            },            headers={                "X-Plugin-ID": plugin_id,                "Content-Type": "application/json",            },        )        for resp in response:            if resp.credentials and isinstance(resp.credentials, dict):                credentials.update(resp.credentials)            return resp.result        return False    def validate_model_credentials(        self,        tenant_id: str,        user_id: str,        plugin_id: str,        provider: str,        model_type: str,        model: str,        credentials: dict,    ) -> bool:        """        validate the credentials of the provider        """        response = self._request_with_plugin_daemon_response_stream(            "POST",            f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",            PluginBasicBooleanResponse,            data={                "user_id": user_id,                "data": {                    "provider": provider,                    "model_type": model_type,                    "model": model,                    "credentials": credentials,                },            },            headers={                "X-Plugin-ID": plugin_id,                "Content-Type": "application/json",            },        )        for resp in response:            if resp.credentials and isinstance(resp.credentials, dict):                credentials.update(resp.credentials)            return resp.result        return False    def invoke_llm(        self,        tenant_id: str,        user_id: str,        plugin_id: str,        provider: str,        model: str,        credentials: dict,        prompt_messages: list[PromptMessage],        model_parameters: Optional[dict] = None,        tools: 
        stop: Optional[list[str]] = None,
        stream: bool = True,
    ) -> Generator[LLMResultChunk, None, None]:
        """
        Invoke llm
        """
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/llm/invoke",
            type=LLMResultChunk,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                        "provider": provider,
                        "model_type": "llm",
                        "model": model,
                        "credentials": credentials,
                        "prompt_messages": prompt_messages,
                        "model_parameters": model_parameters,
                        "tools": tools,
                        "stop": stop,
                        "stream": stream,
                    },
                }
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )

        try:
            yield from response
        except PluginDaemonInnerError as e:
            raise ValueError(e.message + str(e.code))

    def get_llm_num_tokens(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model_type: str,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for llm
        """
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
            type=PluginLLMNumTokensResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                        "provider": provider,
                        "model_type": model_type,
                        "model": model,
                        "credentials": credentials,
                        "prompt_messages": prompt_messages,
                        "tools": tools,
                    },
                }
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp.num_tokens

        return 0

    def invoke_text_embedding(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model: str,
        credentials: dict,
        texts: list[str],
        input_type: str,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding
        """
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
            type=TextEmbeddingResult,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                        "provider": provider,
                        "model_type": "text-embedding",
                        "model": model,
                        "credentials": credentials,
                        "texts": texts,
                        "input_type": input_type,
                    },
                }
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
"application/json",            },        )        for resp in response:            return resp        raise ValueError("Failed to invoke text embedding")    def get_text_embedding_num_tokens(        self,        tenant_id: str,        user_id: str,        plugin_id: str,        provider: str,        model: str,        credentials: dict,        texts: list[str],    ) -> list[int]:        """        Get number of tokens for text embedding        """        response = self._request_with_plugin_daemon_response_stream(            method="POST",            path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",            type=PluginTextEmbeddingNumTokensResponse,            data=jsonable_encoder(                {                    "user_id": user_id,                    "data": {                        "provider": provider,                        "model_type": "text-embedding",                        "model": model,                        "credentials": credentials,                        "texts": texts,                    },                }            ),            headers={                "X-Plugin-ID": plugin_id,                "Content-Type": "application/json",            },        )        for resp in response:            return resp.num_tokens        return []    def invoke_rerank(        self,        tenant_id: str,        user_id: str,        plugin_id: str,        provider: str,        model: str,        credentials: dict,        query: str,        docs: list[str],        score_threshold: Optional[float] = None,        top_n: Optional[int] = None,    ) -> RerankResult:        """        Invoke rerank        """        response = self._request_with_plugin_daemon_response_stream(            method="POST",            path=f"plugin/{tenant_id}/dispatch/rerank/invoke",            type=RerankResult,            data=jsonable_encoder(                {                    "user_id": user_id,                    "data": {                        "provider": provider,                        "model_type": "rerank",                        "model": model,                        "credentials": credentials,                        "query": query,                        "docs": docs,                        "score_threshold": score_threshold,                        "top_n": top_n,                    },                }            ),            headers={                "X-Plugin-ID": plugin_id,                "Content-Type": "application/json",            },        )        for resp in response:            return resp        raise ValueError("Failed to invoke rerank")    def invoke_tts(        self,        tenant_id: str,        user_id: str,        plugin_id: str,        provider: str,        model: str,        credentials: dict,        content_text: str,        voice: str,    ) -> Generator[bytes, None, None]:        """        Invoke tts        """        response = self._request_with_plugin_daemon_response_stream(            method="POST",            path=f"plugin/{tenant_id}/dispatch/tts/invoke",            type=PluginStringResultResponse,            data=jsonable_encoder(                {                    "user_id": user_id,                    "data": {                        "provider": provider,                        "model_type": "tts",                        "model": model,                        "credentials": credentials,                        "tenant_id": tenant_id,                        "content_text": content_text,                        "voice": voice,                    },      
                }
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )

        try:
            for result in response:
                # the daemon streams audio chunks as hex-encoded strings; decode back to raw bytes
                hex_str = result.result
                yield binascii.unhexlify(hex_str)
        except PluginDaemonInnerError as e:
            raise ValueError(e.message + str(e.code))

    def get_tts_model_voices(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model: str,
        credentials: dict,
        language: Optional[str] = None,
    ) -> list[dict]:
        """
        Get tts model voices
        """
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
            type=PluginVoicesResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                        "provider": provider,
                        "model_type": "tts",
                        "model": model,
                        "credentials": credentials,
                        "language": language,
                    },
                }
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            voices = []
            for voice in resp.voices:
                voices.append({"name": voice.name, "value": voice.value})
            return voices

        return []

    def invoke_speech_to_text(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model: str,
        credentials: dict,
        file: IO[bytes],
    ) -> str:
        """
        Invoke speech to text
        """
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
            type=PluginStringResultResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                        "provider": provider,
                        "model_type": "speech2text",
                        "model": model,
                        "credentials": credentials,
                        # audio bytes are hex-encoded so they can travel inside the JSON payload
                        "file": binascii.hexlify(file.read()).decode(),
                    },
                }
            ),
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )

        for resp in response:
            return resp.result

        raise ValueError("Failed to invoke speech to text")

    def invoke_moderation(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model: str,
        credentials: dict,
        text: str,
    ) -> bool:
        """
        Invoke moderation
        """
        response = self._request_with_plugin_daemon_response_stream(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
            type=PluginBasicBooleanResponse,
            data=jsonable_encoder(
                {
                    "user_id": user_id,
                    "data": {
                        "provider": provider,
                        "model_type": "moderation",
                        "model": model,
                        "credentials": credentials,
"text": text,                    },                }            ),            headers={                "X-Plugin-ID": plugin_id,                "Content-Type": "application/json",            },        )        for resp in response:            return resp.result        raise ValueError("Failed to invoke moderation")