rerank.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from typing import Optional
  2. import httpx
  3. from dify_plugin import RerankModel
  4. from dify_plugin.entities import I18nObject
  5. from dify_plugin.entities.model import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
  6. from dify_plugin.errors.model import (
  7. CredentialsValidateFailedError,
  8. InvokeAuthorizationError,
  9. InvokeBadRequestError,
  10. InvokeConnectionError,
  11. InvokeError,
  12. InvokeRateLimitError,
  13. InvokeServerUnavailableError,
  14. )
  15. from dify_plugin.entities.model.rerank import (
  16. RerankDocument,
  17. RerankResult,
  18. )
  19. class {{ .PluginName | SnakeToCamel }}RerankModel(RerankModel):
  20. """
  21. Model class for {{ .PluginName | SnakeToCamel }} rerank model.
  22. """
  23. def _invoke(
  24. self,
  25. model: str,
  26. credentials: dict,
  27. query: str,
  28. docs: list[str],
  29. score_threshold: Optional[float] = None,
  30. top_n: Optional[int] = None,
  31. user: Optional[str] = None,
  32. ) -> RerankResult:
  33. """
  34. Invoke rerank model
  35. :param model: model name
  36. :param credentials: model credentials
  37. :param query: search query
  38. :param docs: docs for reranking
  39. :param score_threshold: score threshold
  40. :param top_n: top n documents to return
  41. :param user: unique user id
  42. :return: rerank result
  43. """
  44. pass
  45. def validate_credentials(self, model: str, credentials: dict) -> None:
  46. """
  47. Validate model credentials
  48. :param model: model name
  49. :param credentials: model credentials
  50. :return:
  51. """
  52. try:
  53. self._invoke(
  54. model=model,
  55. credentials=credentials,
  56. query="What is the capital of the United States?",
  57. docs=[
  58. "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
  59. "Census, Carson City had a population of 55,274.",
  60. "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
  61. "are a political division controlled by the United States. Its capital is Saipan.",
  62. ],
  63. score_threshold=0.8,
  64. )
  65. except Exception as ex:
  66. raise CredentialsValidateFailedError(str(ex))
  67. @property
  68. def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
  69. """
  70. Map model invoke error to unified error
  71. """
  72. return {
  73. InvokeConnectionError: [httpx.ConnectError],
  74. InvokeServerUnavailableError: [httpx.RemoteProtocolError],
  75. InvokeRateLimitError: [],
  76. InvokeAuthorizationError: [httpx.HTTPStatusError],
  77. InvokeBadRequestError: [httpx.RequestError],
  78. }
  79. def get_customizable_model_schema(
  80. self, model: str, credentials: dict
  81. ) -> AIModelEntity:
  82. """
  83. generate custom model entities from credentials
  84. """
  85. entity = AIModelEntity(
  86. model=model,
  87. label=I18nObject(en_US=model),
  88. model_type=ModelType.RERANK,
  89. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  90. model_properties={
  91. ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size") or 0)
  92. },
  93. )
  94. return entity