file_factory.py 7.4 KB


import mimetypes
from collections.abc import Mapping, Sequence
from email.message import Message
from typing import Any
from urllib.parse import urlparse

from sqlalchemy import select

from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
from models.enums import CreatedByRole
  11. def build_from_message_files(
  12. *,
  13. message_files: Sequence["MessageFile"],
  14. tenant_id: str,
  15. config: FileExtraConfig,
  16. ) -> Sequence[File]:
  17. results = [
  18. build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
  19. for file in message_files
  20. if file.belongs_to != FileBelongsTo.ASSISTANT
  21. ]
  22. return results
  23. def build_from_message_file(
  24. *,
  25. message_file: "MessageFile",
  26. tenant_id: str,
  27. config: FileExtraConfig,
  28. ):
  29. mapping = {
  30. "transfer_method": message_file.transfer_method,
  31. "url": message_file.url,
  32. "id": message_file.id,
  33. "type": message_file.type,
  34. "upload_file_id": message_file.upload_file_id,
  35. }
  36. return build_from_mapping(
  37. mapping=mapping,
  38. tenant_id=tenant_id,
  39. user_id=message_file.created_by,
  40. role=CreatedByRole(message_file.created_by_role),
  41. config=config,
  42. )
  43. def build_from_mapping(
  44. *,
  45. mapping: Mapping[str, Any],
  46. tenant_id: str,
  47. user_id: str,
  48. role: "CreatedByRole",
  49. config: FileExtraConfig,
  50. ):
  51. transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
  52. match transfer_method:
  53. case FileTransferMethod.REMOTE_URL:
  54. file = _build_from_remote_url(
  55. mapping=mapping,
  56. tenant_id=tenant_id,
  57. config=config,
  58. transfer_method=transfer_method,
  59. )
  60. case FileTransferMethod.LOCAL_FILE:
  61. file = _build_from_local_file(
  62. mapping=mapping,
  63. tenant_id=tenant_id,
  64. user_id=user_id,
  65. role=role,
  66. config=config,
  67. transfer_method=transfer_method,
  68. )
  69. case FileTransferMethod.TOOL_FILE:
  70. file = _build_from_tool_file(
  71. mapping=mapping,
  72. tenant_id=tenant_id,
  73. user_id=user_id,
  74. config=config,
  75. transfer_method=transfer_method,
  76. )
  77. case _:
  78. raise ValueError(f"Invalid file transfer method: {transfer_method}")
  79. return file
  80. def build_from_mappings(
  81. *,
  82. mappings: Sequence[Mapping[str, Any]],
  83. config: FileExtraConfig | None,
  84. tenant_id: str,
  85. user_id: str,
  86. role: "CreatedByRole",
  87. ) -> Sequence[File]:
  88. if not config:
  89. return []
  90. files = [
  91. build_from_mapping(
  92. mapping=mapping,
  93. tenant_id=tenant_id,
  94. user_id=user_id,
  95. role=role,
  96. config=config,
  97. )
  98. for mapping in mappings
  99. ]
  100. if (
  101. # If image config is set.
  102. config.image_config
  103. # And the number of image files exceeds the maximum limit
  104. and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
  105. ):
  106. raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
  107. if config.number_limits and len(files) > config.number_limits:
  108. raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
  109. return files
  110. def _build_from_local_file(
  111. *,
  112. mapping: Mapping[str, Any],
  113. tenant_id: str,
  114. user_id: str,
  115. role: "CreatedByRole",
  116. config: FileExtraConfig,
  117. transfer_method: FileTransferMethod,
  118. ):
  119. # check if the upload file exists.
  120. file_type = FileType.value_of(mapping.get("type"))
  121. stmt = select(UploadFile).where(
  122. UploadFile.id == mapping.get("upload_file_id"),
  123. UploadFile.tenant_id == tenant_id,
  124. UploadFile.created_by == user_id,
  125. UploadFile.created_by_role == role,
  126. )
  127. if file_type == FileType.IMAGE:
  128. stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
  129. elif file_type == FileType.VIDEO:
  130. stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
  131. elif file_type == FileType.AUDIO:
  132. stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
  133. elif file_type == FileType.DOCUMENT:
  134. stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
  135. row = db.session.scalar(stmt)
  136. if row is None:
  137. raise ValueError("Invalid upload file")
  138. file = File(
  139. id=mapping.get("id"),
  140. filename=row.name,
  141. extension=row.extension,
  142. mime_type=row.mime_type,
  143. tenant_id=tenant_id,
  144. type=file_type,
  145. transfer_method=transfer_method,
  146. remote_url=None,
  147. related_id=mapping.get("upload_file_id"),
  148. _extra_config=config,
  149. size=row.size,
  150. )
  151. return file
  152. def _build_from_remote_url(
  153. *,
  154. mapping: Mapping[str, Any],
  155. tenant_id: str,
  156. config: FileExtraConfig,
  157. transfer_method: FileTransferMethod,
  158. ):
  159. url = mapping.get("url")
  160. if not url:
  161. raise ValueError("Invalid file url")
  162. resp = ssrf_proxy.head(url)
  163. resp.raise_for_status()
  164. # Try to extract filename from response headers or URL
  165. content_disposition = resp.headers.get("Content-Disposition")
  166. if content_disposition:
  167. filename = content_disposition.split("filename=")[-1].strip('"')
  168. else:
  169. filename = url.split("/")[-1].split("?")[0]
  170. # If filename is empty, set a default one
  171. if not filename:
  172. filename = "unknown_file"
  173. # Determine file extension
  174. extension = "." + filename.split(".")[-1] if "." in filename else ".bin"
  175. # Create the File object
  176. file_size = int(resp.headers.get("Content-Length", -1))
  177. mime_type = str(resp.headers.get("Content-Type", ""))
  178. if not mime_type:
  179. mime_type, _ = mimetypes.guess_type(url)
  180. file = File(
  181. id=mapping.get("id"),
  182. filename=filename,
  183. tenant_id=tenant_id,
  184. type=FileType.value_of(mapping.get("type")),
  185. transfer_method=transfer_method,
  186. remote_url=url,
  187. _extra_config=config,
  188. mime_type=mime_type,
  189. extension=extension,
  190. size=file_size,
  191. )
  192. return file
  193. def _build_from_tool_file(
  194. *,
  195. mapping: Mapping[str, Any],
  196. tenant_id: str,
  197. user_id: str,
  198. config: FileExtraConfig,
  199. transfer_method: FileTransferMethod,
  200. ):
  201. tool_file = (
  202. db.session.query(ToolFile)
  203. .filter(
  204. ToolFile.id == mapping.get("tool_file_id"),
  205. ToolFile.tenant_id == tenant_id,
  206. ToolFile.user_id == user_id,
  207. )
  208. .first()
  209. )
  210. if tool_file is None:
  211. raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
  212. path = tool_file.file_key
  213. if "." in path:
  214. extension = "." + path.split("/")[-1].split(".")[-1]
  215. else:
  216. extension = ".bin"
  217. file = File(
  218. id=mapping.get("id"),
  219. tenant_id=tenant_id,
  220. filename=tool_file.name,
  221. type=FileType.value_of(mapping.get("type")),
  222. transfer_method=transfer_method,
  223. remote_url=tool_file.original_url,
  224. related_id=tool_file.id,
  225. extension=extension,
  226. mime_type=tool_file.mimetype,
  227. size=tool_file.size,
  228. _extra_config=config,
  229. )
  230. return file