123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- import base64
- from configs import dify_config
- from core.file import file_repository
- from core.helper import ssrf_proxy
- from core.model_runtime.entities import (
- AudioPromptMessageContent,
- DocumentPromptMessageContent,
- ImagePromptMessageContent,
- VideoPromptMessageContent,
- )
- from extensions.ext_database import db
- from extensions.ext_storage import storage
- from . import helpers
- from .enums import FileAttribute
- from .models import File, FileTransferMethod, FileType
- from .tool_file_parser import ToolFileParser
- def get_attr(*, file: File, attr: FileAttribute):
- match attr:
- case FileAttribute.TYPE:
- return file.type.value
- case FileAttribute.SIZE:
- return file.size
- case FileAttribute.NAME:
- return file.filename
- case FileAttribute.MIME_TYPE:
- return file.mime_type
- case FileAttribute.TRANSFER_METHOD:
- return file.transfer_method.value
- case FileAttribute.URL:
- return file.remote_url
- case FileAttribute.EXTENSION:
- return file.extension
- def to_prompt_message_content(
- f: File,
- /,
- *,
- image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
- ):
- match f.type:
- case FileType.IMAGE:
- image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
- if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
- data = _to_url(f)
- else:
- data = _to_base64_data_string(f)
- return ImagePromptMessageContent(data=data, detail=image_detail_config)
- case FileType.AUDIO:
- encoded_string = _get_encoded_string(f)
- if f.extension is None:
- raise ValueError("Missing file extension")
- return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
- case FileType.VIDEO:
- if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
- data = _to_url(f)
- else:
- data = _to_base64_data_string(f)
- if f.extension is None:
- raise ValueError("Missing file extension")
- return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
- case FileType.DOCUMENT:
- data = _get_encoded_string(f)
- if f.mime_type is None:
- raise ValueError("Missing file mime_type")
- return DocumentPromptMessageContent(
- encode_format="base64",
- mime_type=f.mime_type,
- data=data,
- )
- case _:
- raise ValueError(f"file type {f.type} is not supported")
- def download(f: File, /):
- if f.transfer_method == FileTransferMethod.TOOL_FILE:
- tool_file = file_repository.get_tool_file(session=db.session(), file=f)
- return _download_file_content(tool_file.file_key)
- elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
- upload_file = file_repository.get_upload_file(session=db.session(), file=f)
- return _download_file_content(upload_file.key)
- # remote file
- response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
- response.raise_for_status()
- return response.content
- def _download_file_content(path: str, /):
- """
- Download and return the contents of a file as bytes.
- This function loads the file from storage and ensures it's in bytes format.
- Args:
- path (str): The path to the file in storage.
- Returns:
- bytes: The contents of the file as a bytes object.
- Raises:
- ValueError: If the loaded file is not a bytes object.
- """
- data = storage.load(path, stream=False)
- if not isinstance(data, bytes):
- raise ValueError(f"file {path} is not a bytes object")
- return data
- def _get_encoded_string(f: File, /):
- match f.transfer_method:
- case FileTransferMethod.REMOTE_URL:
- response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
- response.raise_for_status()
- data = response.content
- case FileTransferMethod.LOCAL_FILE:
- upload_file = file_repository.get_upload_file(session=db.session(), file=f)
- data = _download_file_content(upload_file.key)
- case FileTransferMethod.TOOL_FILE:
- tool_file = file_repository.get_tool_file(session=db.session(), file=f)
- data = _download_file_content(tool_file.file_key)
- encoded_string = base64.b64encode(data).decode("utf-8")
- return encoded_string
- def _to_base64_data_string(f: File, /):
- encoded_string = _get_encoded_string(f)
- return f"data:{f.mime_type};base64,{encoded_string}"
- def _to_url(f: File, /):
- if f.transfer_method == FileTransferMethod.REMOTE_URL:
- if f.remote_url is None:
- raise ValueError("Missing file remote_url")
- return f.remote_url
- elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
- if f.related_id is None:
- raise ValueError("Missing file related_id")
- return helpers.get_signed_file_url(upload_file_id=f.related_id)
- elif f.transfer_method == FileTransferMethod.TOOL_FILE:
- # add sign url
- if f.related_id is None or f.extension is None:
- raise ValueError("Missing file related_id or extension")
- return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
- else:
- raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|