file_manager.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import base64
  2. from configs import dify_config
  3. from core.file import file_repository
  4. from core.helper import ssrf_proxy
  5. from core.model_runtime.entities import (
  6. AudioPromptMessageContent,
  7. DocumentPromptMessageContent,
  8. ImagePromptMessageContent,
  9. VideoPromptMessageContent,
  10. )
  11. from extensions.ext_database import db
  12. from extensions.ext_storage import storage
  13. from . import helpers
  14. from .enums import FileAttribute
  15. from .models import File, FileTransferMethod, FileType
  16. from .tool_file_parser import ToolFileParser
  17. def get_attr(*, file: File, attr: FileAttribute):
  18. match attr:
  19. case FileAttribute.TYPE:
  20. return file.type.value
  21. case FileAttribute.SIZE:
  22. return file.size
  23. case FileAttribute.NAME:
  24. return file.filename
  25. case FileAttribute.MIME_TYPE:
  26. return file.mime_type
  27. case FileAttribute.TRANSFER_METHOD:
  28. return file.transfer_method.value
  29. case FileAttribute.URL:
  30. return file.remote_url
  31. case FileAttribute.EXTENSION:
  32. return file.extension
  33. def to_prompt_message_content(
  34. f: File,
  35. /,
  36. *,
  37. image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
  38. ):
  39. match f.type:
  40. case FileType.IMAGE:
  41. image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  42. if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
  43. data = _to_url(f)
  44. else:
  45. data = _to_base64_data_string(f)
  46. return ImagePromptMessageContent(data=data, detail=image_detail_config)
  47. case FileType.AUDIO:
  48. encoded_string = _get_encoded_string(f)
  49. if f.extension is None:
  50. raise ValueError("Missing file extension")
  51. return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
  52. case FileType.VIDEO:
  53. if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
  54. data = _to_url(f)
  55. else:
  56. data = _to_base64_data_string(f)
  57. if f.extension is None:
  58. raise ValueError("Missing file extension")
  59. return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
  60. case FileType.DOCUMENT:
  61. data = _get_encoded_string(f)
  62. if f.mime_type is None:
  63. raise ValueError("Missing file mime_type")
  64. return DocumentPromptMessageContent(
  65. encode_format="base64",
  66. mime_type=f.mime_type,
  67. data=data,
  68. )
  69. case _:
  70. raise ValueError(f"file type {f.type} is not supported")
  71. def download(f: File, /):
  72. if f.transfer_method == FileTransferMethod.TOOL_FILE:
  73. tool_file = file_repository.get_tool_file(session=db.session(), file=f)
  74. return _download_file_content(tool_file.file_key)
  75. elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
  76. upload_file = file_repository.get_upload_file(session=db.session(), file=f)
  77. return _download_file_content(upload_file.key)
  78. # remote file
  79. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  80. response.raise_for_status()
  81. return response.content
  82. def _download_file_content(path: str, /):
  83. """
  84. Download and return the contents of a file as bytes.
  85. This function loads the file from storage and ensures it's in bytes format.
  86. Args:
  87. path (str): The path to the file in storage.
  88. Returns:
  89. bytes: The contents of the file as a bytes object.
  90. Raises:
  91. ValueError: If the loaded file is not a bytes object.
  92. """
  93. data = storage.load(path, stream=False)
  94. if not isinstance(data, bytes):
  95. raise ValueError(f"file {path} is not a bytes object")
  96. return data
  97. def _get_encoded_string(f: File, /):
  98. match f.transfer_method:
  99. case FileTransferMethod.REMOTE_URL:
  100. response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
  101. response.raise_for_status()
  102. data = response.content
  103. case FileTransferMethod.LOCAL_FILE:
  104. upload_file = file_repository.get_upload_file(session=db.session(), file=f)
  105. data = _download_file_content(upload_file.key)
  106. case FileTransferMethod.TOOL_FILE:
  107. tool_file = file_repository.get_tool_file(session=db.session(), file=f)
  108. data = _download_file_content(tool_file.file_key)
  109. encoded_string = base64.b64encode(data).decode("utf-8")
  110. return encoded_string
  111. def _to_base64_data_string(f: File, /):
  112. encoded_string = _get_encoded_string(f)
  113. return f"data:{f.mime_type};base64,{encoded_string}"
  114. def _to_url(f: File, /):
  115. if f.transfer_method == FileTransferMethod.REMOTE_URL:
  116. if f.remote_url is None:
  117. raise ValueError("Missing file remote_url")
  118. return f.remote_url
  119. elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
  120. if f.related_id is None:
  121. raise ValueError("Missing file related_id")
  122. return helpers.get_signed_file_url(upload_file_id=f.related_id)
  123. elif f.transfer_method == FileTransferMethod.TOOL_FILE:
  124. # add sign url
  125. if f.related_id is None or f.extension is None:
  126. raise ValueError("Missing file related_id or extension")
  127. return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
  128. else:
  129. raise ValueError(f"Unsupported transfer method: {f.transfer_method}")