Pārlūkot izejas kodu

feat: move audio and webscraper back to dify

Yeuoly 5 mēneši atpakaļ
vecāks
revīzija
dcf19549cb

+ 3 - 3
api/core/agent/cot_agent_runner.py

@@ -309,13 +309,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         )
 
         # publish files
-        for message_file_id, save_as in message_files:
+        for message_file_id in message_files:
             # publish message file
             self.queue_manager.publish(
-                QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
             )
             # add message file ids
-            message_file_ids.append(message_file_id.id)
+            message_file_ids.append(message_file_id)
 
         return tool_invoke_response, tool_invoke_meta
 

+ 3 - 3
api/core/agent/fc_agent_runner.py

@@ -246,13 +246,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         conversation_id=self.conversation.id,
                     )
                     # publish files
-                    for message_file_id, save_as in message_files:
+                    for message_file_id in message_files:
                         # publish message file
                         self.queue_manager.publish(
-                            QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
+                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                         )
                         # add message file ids
-                        message_file_ids.append(message_file_id.id)
+                        message_file_ids.append(message_file_id)
 
                     tool_response = {
                         "tool_call_id": tool_call_id,

+ 1 - 1
api/core/app/apps/agent_chat/app_generator.py

@@ -172,7 +172,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             target=self._generate_worker,
             kwargs={
                 "flask_app": current_app._get_current_object(),  # type: ignore
-                "contexts": contextvars.copy_context(),
+                "context": contextvars.copy_context(),
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,

+ 11 - 9
api/core/tools/__base/tool.py

@@ -157,7 +157,10 @@ class Tool(ABC):
 
         return parameters
 
-    def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage:
+    def create_image_message(
+        self,
+        image: str,
+    ) -> ToolInvokeMessage:
         """
         create an image message
 
@@ -165,7 +168,7 @@ class Tool(ABC):
         :return: the image message
         """
         return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as
+            type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
         )
 
     def create_file_message(self, file: "File") -> ToolInvokeMessage:
@@ -173,10 +176,9 @@ class Tool(ABC):
             type=ToolInvokeMessage.MessageType.FILE,
             message=ToolInvokeMessage.FileMessage(),
             meta={"file": file},
-            save_as="",
         )
 
-    def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
+    def create_link_message(self, link: str) -> ToolInvokeMessage:
         """
         create a link message
 
@@ -184,10 +186,10 @@ class Tool(ABC):
         :return: the link message
         """
         return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as
+            type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
         )
 
-    def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage:
+    def create_text_message(self, text: str) -> ToolInvokeMessage:
         """
         create a text message
 
@@ -195,10 +197,11 @@ class Tool(ABC):
         :return: the text message
         """
         return ToolInvokeMessage(
-            type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as
+            type=ToolInvokeMessage.MessageType.TEXT,
+            message=ToolInvokeMessage.TextMessage(text=text),
         )
 
-    def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
+    def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
         """
         create a blob message
 
@@ -209,7 +212,6 @@ class Tool(ABC):
             type=ToolInvokeMessage.MessageType.BLOB,
             message=ToolInvokeMessage.BlobMessage(blob=blob),
             meta=meta,
-            save_as=save_as,
         )
 
     def create_json_message(self, object: dict) -> ToolInvokeMessage:

Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 3 - 0
api/core/tools/builtin_tool/providers/audio/_assets/icon.svg


+ 6 - 0
api/core/tools/builtin_tool/providers/audio/audio.py

@@ -0,0 +1,6 @@
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
+
+
+class AudioToolProvider(BuiltinToolProviderController):
+    """Provider for the built-in audio tools; it performs no credential validation."""
+
+    def _validate_credentials(self, credentials: dict) -> None:
+        """Accept any credentials: this hook is intentionally a no-op."""
+        return None

+ 11 - 0
api/core/tools/builtin_tool/providers/audio/audio.yaml

@@ -0,0 +1,11 @@
+identity:
+  author: hjlarry
+  name: audio
+  label:
+    en_US: Audio
+  description:
+    en_US: A tool for tts and asr.
+    zh_Hans: 一个用于文本转语音和语音转文本的工具。
+  icon: icon.svg
+  tags:
+    - utilities

+ 71 - 0
api/core/tools/builtin_tool/providers/audio/tools/asr.py

@@ -0,0 +1,71 @@
+import io
+from collections.abc import Generator
+from typing import Any
+
+from core.file.enums import FileType
+from core.file.file_manager import download
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.tools.builtin_tool.tool import BuiltinTool
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from services.model_provider_service import ModelProviderService
+
+
+class ASRTool(BuiltinTool):
+    """Built-in tool that turns an audio file into text via a speech-to-text model."""
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        Transcribe the audio file given in ``tool_parameters``.
+
+        :param user_id: passed to the model invocation as ``user``
+        :param tool_parameters: reads ``audio_file`` and ``model`` (formatted as ``<provider>#<model>``)
+        :return: yields one text message: the transcription, or a notice that the file is not valid audio
+        :raises ValueError: when ``model`` is missing or is not in ``<provider>#<model>`` form
+        """
+        file = tool_parameters.get("audio_file")
+        # A missing file is reported like a non-audio file rather than failing
+        # with an AttributeError on ``None.type``.
+        if file is None or file.type != FileType.AUDIO:
+            yield self.create_text_message("not a valid audio file")
+            return
+
+        # Split on the first "#" only, so a model name containing "#" still parses.
+        provider, separator, model = str(tool_parameters.get("model") or "").partition("#")
+        if not separator:
+            raise ValueError("parameter 'model' is required and must look like '<provider>#<model>'")
+
+        audio_binary = io.BytesIO(download(file))  # type: ignore
+        audio_binary.name = "temp.mp3"
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=self.runtime.tenant_id,
+            provider=provider,
+            model_type=ModelType.SPEECH2TEXT,
+            model=model,
+        )
+        text = model_instance.invoke_speech2text(
+            file=audio_binary,
+            user=user_id,
+        )
+        yield self.create_text_message(text)
+
+    def get_available_models(self) -> list[tuple[str, str]]:
+        """Return ``(provider, model)`` pairs for the tenant's ``speech2text`` models."""
+        provider_models = ModelProviderService().get_models_by_model_type(
+            tenant_id=self.runtime.tenant_id, model_type="speech2text"
+        )
+        return [
+            (provider_model.provider, model.model)
+            for provider_model in provider_models
+            for model in provider_model.models
+        ]
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        """Return the required ``model`` select parameter, with one option per available model."""
+        options = [
+            ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
+            for provider, model in self.get_available_models()
+        ]
+        return [
+            ToolParameter(
+                name="model",
+                label=I18nObject(en_US="Model", zh_Hans="Model"),
+                human_description=I18nObject(
+                    en_US="All available ASR models. You can config model in the Model Provider of Settings.",
+                    zh_Hans="所有可用的 ASR 模型。你可以在设置中的模型供应商里配置。",
+                ),
+                type=ToolParameter.ToolParameterType.SELECT,
+                form=ToolParameter.ToolParameterForm.FORM,
+                required=True,
+                options=options,
+            )
+        ]

+ 22 - 0
api/core/tools/builtin_tool/providers/audio/tools/asr.yaml

@@ -0,0 +1,22 @@
+identity:
+  name: asr
+  author: hjlarry
+  label:
+    en_US: Speech To Text
+description:
+  human:
+    en_US: Convert audio file to text.
+    zh_Hans: 将音频文件转换为文本。
+  llm: Convert audio file to text.
+parameters:
+  - name: audio_file
+    type: file
+    required: true
+    label:
+      en_US: Audio File
+      zh_Hans: 音频文件
+    human_description:
+      en_US: The audio file to be converted.
+      zh_Hans: 要转换的音频文件。
+    llm_description: The audio file to be converted.
+    form: llm

+ 87 - 0
api/core/tools/builtin_tool/providers/audio/tools/tts.py

@@ -0,0 +1,87 @@
+import io
+from collections.abc import Generator
+from typing import Any
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from core.tools.builtin_tool.tool import BuiltinTool
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from services.model_provider_service import ModelProviderService
+
+
+class TTSTool(BuiltinTool):
+    """Built-in tool that turns text into audio via a text-to-speech model."""
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        Synthesize speech for the ``text`` parameter.
+
+        :param user_id: passed to the model invocation as ``user``
+        :param tool_parameters: reads ``model`` (``<provider>#<model>``), ``text`` and the
+            per-model voice parameter ``voice#<provider>#<model>``
+        :return: yields a text confirmation followed by a blob message holding the audio bytes
+        :raises ValueError: when ``model`` is missing or is not in ``<provider>#<model>`` form
+        """
+        # Split on the first "#" only, so a model name containing "#" still parses;
+        # a missing value becomes a clear error instead of an AttributeError.
+        provider, separator, model = str(tool_parameters.get("model") or "").partition("#")
+        if not separator:
+            raise ValueError("parameter 'model' is required and must look like '<provider>#<model>'")
+
+        voice = tool_parameters.get(f"voice#{provider}#{model}")
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=self.runtime.tenant_id,
+            provider=provider,
+            model_type=ModelType.TTS,
+            model=model,
+        )
+        tts = model_instance.invoke_tts(
+            content_text=tool_parameters.get("text"),  # type: ignore
+            user=user_id,
+            tenant_id=self.runtime.tenant_id,
+            voice=voice,  # type: ignore
+        )
+        # Collect every streamed chunk into one bytes object.
+        wav_bytes = b"".join(tts)
+        yield self.create_text_message("Audio generated successfully")
+        yield self.create_blob_message(
+            blob=wav_bytes,
+            meta={"mime_type": "audio/x-wav"},
+        )
+
+    def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
+        """Return ``(provider, model, voices)`` triples for the tenant's ``tts`` models."""
+        provider_models = ModelProviderService().get_models_by_model_type(
+            tenant_id=self.runtime.tenant_id, model_type="tts"
+        )
+        return [
+            (provider_model.provider, model.model, model.model_properties.get(ModelPropertyKey.VOICES, []))
+            for provider_model in provider_models
+            for model in provider_model.models
+        ]
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        """
+        Return the form parameters: the required ``model`` select first, then one
+        ``voice#<provider>#<model>`` select per available model.
+        """
+        model_options = []
+        voice_parameters = []
+        for provider, model, voices in self.get_available_models():
+            model_options.append(
+                ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
+            )
+            voice_parameters.append(
+                ToolParameter(
+                    name=f"voice#{provider}#{model}",
+                    label=I18nObject(en_US=f"Voice of {model}({provider})"),
+                    type=ToolParameter.ToolParameterType.SELECT,
+                    form=ToolParameter.ToolParameterForm.FORM,
+                    options=[
+                        ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
+                        for voice in voices
+                    ],
+                )
+            )
+
+        model_parameter = ToolParameter(
+            name="model",
+            label=I18nObject(en_US="Model", zh_Hans="Model"),
+            human_description=I18nObject(
+                en_US="All available TTS models. You can config model in the Model Provider of Settings.",
+                zh_Hans="所有可用的 TTS 模型。你可以在设置中的模型供应商里配置。",
+            ),
+            type=ToolParameter.ToolParameterType.SELECT,
+            form=ToolParameter.ToolParameterForm.FORM,
+            required=True,
+            options=model_options,
+        )
+        return [model_parameter, *voice_parameters]

+ 22 - 0
api/core/tools/builtin_tool/providers/audio/tools/tts.yaml

@@ -0,0 +1,22 @@
+identity:
+  name: tts
+  author: hjlarry
+  label:
+    en_US: Text To Speech
+description:
+  human:
+    en_US: Convert text to audio file.
+    zh_Hans: 将文本转换为音频文件。
+  llm: Convert text to audio file.
+parameters:
+  - name: text
+    type: string
+    required: true
+    label:
+      en_US: Text
+      zh_Hans: 文本
+    human_description:
+      en_US: The text to be converted.
+      zh_Hans: 要转换的文本。
+    llm_description: The text to be converted.
+    form: llm

Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 3 - 0
api/core/tools/builtin_tool/providers/webscraper/_assets/icon.svg


+ 36 - 0
api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.py

@@ -0,0 +1,36 @@
+from collections.abc import Generator
+from typing import Any
+
+from core.tools.builtin_tool.tool import BuiltinTool
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolInvokeError
+from core.tools.utils.web_reader_tool import get_url
+
+
+class WebscraperTool(BuiltinTool):
+    """Built-in tool that fetches a web page and returns its text or a summary of it."""
+
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        Scrape the page at ``tool_parameters["url"]``.
+
+        :param user_id: forwarded to ``self.summary`` when a summary is requested
+        :param tool_parameters: reads ``url``, ``user_agent`` and ``generate_summary``
+        :return: yields one text message: the page content, its summary, or a prompt for a URL
+        :raises ToolInvokeError: wraps any exception raised while fetching or summarizing
+        """
+        try:
+            url = tool_parameters.get("url", "")
+            user_agent = tool_parameters.get("user_agent", "")
+            if not url:
+                yield self.create_text_message("Please input url")
+                return
+
+            # get webpage
+            result = get_url(url, user_agent=user_agent)
+
+            if tool_parameters.get("generate_summary"):
+                # summarize and return
+                yield self.create_text_message(self.summary(user_id=user_id, content=result))
+            else:
+                # return full webpage
+                yield self.create_text_message(result)
+        except Exception as e:
+            # Chain the original exception so its traceback stays attached as the cause.
+            raise ToolInvokeError(str(e)) from e

+ 60 - 0
api/core/tools/builtin_tool/providers/webscraper/tools/webscraper.yaml

@@ -0,0 +1,60 @@
+identity:
+  name: webscraper
+  author: Dify
+  label:
+    en_US: Web Scraper
+    zh_Hans: 网页爬虫
+    pt_BR: Web Scraper
+description:
+  human:
+    en_US: A tool for scraping webpages.
+    zh_Hans: 一个用于爬取网页的工具。
+    pt_BR: A tool for scraping webpages.
+  llm: A tool for scraping webpages. Input should be a URL.
+parameters:
+  - name: url
+    type: string
+    required: true
+    label:
+      en_US: URL
+      zh_Hans: 网页链接
+      pt_BR: URL
+    human_description:
+      en_US: used for linking to webpages
+      zh_Hans: 用于链接到网页
+      pt_BR: used for linking to webpages
+    llm_description: url for scraping
+    form: llm
+  - name: user_agent
+    type: string
+    required: false
+    label:
+      en_US: User Agent
+      zh_Hans: User Agent
+      pt_BR: User Agent
+    human_description:
+      en_US: used for identifying the browser.
+      zh_Hans: 用于识别浏览器。
+      pt_BR: used for identifying the browser.
+    form: form
+    default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36
+  - name: generate_summary
+    type: boolean
+    required: false
+    label:
+      en_US: Whether to generate summary
+      zh_Hans: 是否生成摘要
+    human_description:
+      en_US: If true, the crawler will only return the page summary content.
+      zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
+    form: form
+    options:
+      - value: "true"
+        label:
+          en_US: "Yes"
+          zh_Hans: 是
+      - value: "false"
+        label:
+          en_US: "No"
+          zh_Hans: 否
+    default: "false"

+ 8 - 0
api/core/tools/builtin_tool/providers/webscraper/webscraper.py

@@ -0,0 +1,8 @@
+from typing import Any
+
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
+
+
+class WebscraperProvider(BuiltinToolProviderController):
+    """Provider for the built-in web scraper tool; it performs no credential validation."""
+
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        """Accept any credentials: this hook is intentionally a no-op."""
+        return None

+ 15 - 0
api/core/tools/builtin_tool/providers/webscraper/webscraper.yaml

@@ -0,0 +1,15 @@
+identity:
+  author: Dify
+  name: webscraper
+  label:
+    en_US: WebScraper
+    zh_Hans: 网页抓取
+    pt_BR: WebScraper
+  description:
+    en_US: Web Scrapper tool kit is used to scrape web
+    zh_Hans: 一个用于抓取网页的工具。
+    pt_BR: Web Scrapper tool kit is used to scrape web
+  icon: icon.svg
+  tags:
+    - productivity
+credentials_for_provider: []

+ 0 - 2
api/core/tools/entities/tool_entities.py

@@ -166,7 +166,6 @@ class ToolInvokeMessage(BaseModel):
     """
     message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
     meta: dict[str, Any] | None = None
-    save_as: str = ""
 
     @field_validator("message", mode="before")
     @classmethod
@@ -188,7 +187,6 @@ class ToolInvokeMessage(BaseModel):
 class ToolInvokeMessageBinary(BaseModel):
     mimetype: str = Field(..., description="The mimetype of the binary")
     url: str = Field(..., description="The url of the binary")
-    save_as: str = ""
     file_var: Optional[dict[str, Any]] = None
 
 

+ 4 - 7
api/core/tools/tool_engine.py

@@ -49,7 +49,7 @@ class ToolEngine:
         conversation_id: Optional[str] = None,
         app_id: Optional[str] = None,
         message_id: Optional[str] = None,
-    ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
+    ) -> tuple[str, list[str], ToolInvokeMeta]:
         """
         Agent invokes the tool with the given arguments.
         """
@@ -279,7 +279,6 @@ class ToolEngine:
                 yield ToolInvokeMessageBinary(
                     mimetype=response.meta.get("mime_type", "image/jpeg"),
                     url=cast(ToolInvokeMessage.TextMessage, response.message).text,
-                    save_as=response.save_as,
                 )
             elif response.type == ToolInvokeMessage.MessageType.BLOB:
                 if not response.meta:
@@ -288,7 +287,6 @@ class ToolEngine:
                 yield ToolInvokeMessageBinary(
                     mimetype=response.meta.get("mime_type", "octet/stream"),
                     url=cast(ToolInvokeMessage.TextMessage, response.message).text,
-                    save_as=response.save_as,
                 )
             elif response.type == ToolInvokeMessage.MessageType.LINK:
                 # check if there is a mime type in meta
@@ -296,7 +294,6 @@ class ToolEngine:
                     yield ToolInvokeMessageBinary(
                         mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream",
                         url=cast(ToolInvokeMessage.TextMessage, response.message).text,
-                        save_as=response.save_as,
                     )
 
     @staticmethod
@@ -305,12 +302,12 @@ class ToolEngine:
         agent_message: Message,
         invoke_from: InvokeFrom,
         user_id: str,
-    ) -> list[tuple[MessageFile, str]]:
+    ) -> list[str]:
         """
         Create message file
 
         :param messages: messages
-        :return: message files, should save as variable
+        :return: message file ids
         """
         result = []
 
@@ -347,7 +344,7 @@ class ToolEngine:
             db.session.commit()
             db.session.refresh(message_file)
 
-            result.append((message_file.id, message.save_as))
+            result.append(message_file.id)
 
         db.session.close()
 

+ 0 - 6
api/core/tools/utils/message_transformer.py

@@ -44,7 +44,6 @@ class ToolFileMessageTransformer:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         message=ToolInvokeMessage.TextMessage(text=url),
-                        save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 except Exception as e:
@@ -54,7 +53,6 @@ class ToolFileMessageTransformer:
                             text=f"Failed to download image: {message.message.text}: {e}"
                         ),
                         meta=message.meta.copy() if message.meta is not None else {},
-                        save_as=message.save_as,
                     )
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
@@ -83,14 +81,12 @@ class ToolFileMessageTransformer:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         message=ToolInvokeMessage.TextMessage(text=url),
-                        save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 else:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.LINK,
                         message=ToolInvokeMessage.TextMessage(text=url),
-                        save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
             elif message.type == ToolInvokeMessage.MessageType.FILE:
@@ -104,14 +100,12 @@ class ToolFileMessageTransformer:
                             yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                                 message=ToolInvokeMessage.TextMessage(text=url),
-                                save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                             )
                         else:
                             yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.LINK,
                                 message=ToolInvokeMessage.TextMessage(text=url),
-                                save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                             )
                     else:

+ 374 - 0
api/core/tools/utils/web_reader_tool.py

@@ -0,0 +1,374 @@
+import hashlib
+import json
+import mimetypes
+import os
+import re
+import site
+import subprocess
+import tempfile
+import unicodedata
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Optional
+from urllib.parse import unquote
+
+import chardet
+import cloudscraper
+from bs4 import BeautifulSoup, CData, Comment, NavigableString
+from regex import regex
+
+from core.helper import ssrf_proxy
+from core.rag.extractor import extract_processor
+from core.rag.extractor.extract_processor import ExtractProcessor
+
+# Layout of the text returned by get_url() for an HTML page; filled in with
+# str.format() using the title/authors/publish_date/top_image/text fields.
+FULL_TEMPLATE = """
+TITLE: {title}
+AUTHORS: {authors}
+PUBLISH DATE: {publish_date}
+TOP_IMAGE_URL: {top_image}
+TEXT:
+
+{text}
+"""
+
+
+def page_result(text: str, cursor: int, max_length: int) -> str:
+    """Return the window of `text` that starts at `cursor` and spans at most `max_length` characters."""
+    window_end = cursor + max_length
+    return text[cursor:window_end]
+
+
+def get_url(url: str, user_agent: Optional[str] = None) -> str:
+    """
+    Fetch `url` and return its content as text.
+
+    A HEAD request is issued first. On 200 the content type decides the path:
+    types handled by the extract processor are loaded through it, "text/html"
+    is downloaded and reduced to readable text, and anything else yields an
+    "unsupported" message. On 403 the page is requested again through
+    cloudscraper. Any final non-200 status is reported as a message string.
+
+    :param url: address of the page or document to fetch
+    :param user_agent: optional User-Agent header value; a desktop Chrome string is used when empty
+    :return: the formatted page text, extracted document text, an empty string when
+        no readable text was found, or a human-readable failure message
+    """
+    default_user_agent = (
+        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
+        " Chrome/91.0.4472.124 Safari/537.36"
+    )
+    headers = {"User-Agent": user_agent or default_user_agent}
+
+    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
+    response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))
+
+    if response.status_code == 403:
+        # Forbidden on HEAD: retry with cloudscraper, routing its requests through the SSRF proxy helper
+        scraper = cloudscraper.create_scraper()
+        scraper.perform_request = ssrf_proxy.make_request
+        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
+    elif response.status_code == 200:
+        main_content_type = None
+        header_content_type = response.headers.get("Content-Type")
+        if header_content_type:
+            # Keep only the media type, dropping parameters such as charset
+            main_content_type = header_content_type.split(";")[0].strip()
+        else:
+            # No Content-Type header: guess from the attachment file name, if any
+            content_disposition = response.headers.get("Content-Disposition", "")
+            filename_match = re.search(r'filename="([^"]+)"', content_disposition)
+            if filename_match:
+                filename = unquote(filename_match.group(1))
+                if re.search(r"\.(\w+)$", filename):
+                    main_content_type = mimetypes.guess_type(filename)[0]
+
+        if main_content_type not in supported_content_types:
+            return "Unsupported content-type [{}] of URL.".format(main_content_type)
+
+        if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
+            return ExtractProcessor.load_from_url(url, return_text=True)
+
+        response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
+
+    if response.status_code != 200:
+        return "URL returned status code {}.".format(response.status_code)
+
+    # Prefer the encoding detected by chardet; fall back to the response's own decoding
+    encoding = chardet.detect(response.content)["encoding"]
+    if not encoding:
+        content = response.text
+    else:
+        try:
+            content = response.content.decode(encoding)
+        except (UnicodeDecodeError, TypeError):
+            content = response.text
+
+    article = extract_using_readabilipy(content)
+    plain_text = article["plain_text"]
+
+    if not plain_text or not plain_text.strip():
+        return ""
+
+    return FULL_TEMPLATE.format(
+        title=article["title"],
+        authors=article["byline"],
+        publish_date=article["date"],
+        top_image="",
+        text=plain_text,
+    )
+
+
+def extract_using_readabilipy(html):
+    """
+    Run Mozilla's Readability.js (shipped with readabilipy) on `html` via node.
+
+    The HTML is written to a temporary file, `ExtractArticle.js` writes its
+    result to a sibling ".json" file, and that JSON is mapped onto a dict.
+
+    :param html: HTML document as a string
+    :return: dict with keys "title", "byline", "date", "content",
+        "plain_content" and "plain_text"; values stay None when Readability
+        did not provide the corresponding field
+    :raises subprocess.CalledProcessError: when the node process exits non-zero
+    """
+    # The JSON output is read back as UTF-8, so write the input as UTF-8 too
+    # rather than depending on the platform's default encoding.
+    f_html = tempfile.NamedTemporaryFile(delete=False, mode="w+", encoding="utf-8")
+    html_path = f_html.name
+    article_json_path = html_path + ".json"
+
+    try:
+        with f_html:
+            f_html.write(html)
+
+        # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
+        jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
+        with chdir(jsdir):
+            subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
+
+        # Read output of call to Readability.parse() from JSON file and return as Python dictionary
+        input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
+    finally:
+        # Always remove both temporary files, including when node or JSON parsing fails
+        Path(article_json_path).unlink(missing_ok=True)
+        Path(html_path).unlink(missing_ok=True)
+
+    article_json = {
+        "title": None,
+        "byline": None,
+        "date": None,
+        "content": None,
+        "plain_content": None,
+        "plain_text": None,
+    }
+    # Populate article fields from readability fields where present
+    if input_json:
+        for key in ("title", "byline", "date"):
+            if input_json.get(key):
+                article_json[key] = input_json[key]
+        if input_json.get("content"):
+            article_json["content"] = input_json["content"]
+            article_json["plain_content"] = plain_content(article_json["content"], False, False)
+            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
+        if input_json.get("textContent"):
+            # Readability's own text wins over the block list; collapse blank lines
+            article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", input_json["textContent"])
+
+    return article_json
+
+
+def find_module_path(module_name):
+    """Return the first `<site-packages>/<module_name>` path that exists, or None when there is none."""
+    candidates = (os.path.join(package_path, module_name) for package_path in site.getsitepackages())
+    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
+
+
+@contextmanager
+def chdir(path):
+    """Switch the working directory to `path` for the duration of the `with` block, then switch back."""
+    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
+    previous_dir = os.getcwd()
+    os.chdir(path)
+    try:
+        yield
+    finally:
+        # Restore the starting directory even when the body raised
+        os.chdir(previous_dir)
+
+
+def extract_text_blocks_as_plain_text(paragraph_html):
+    """
+    Split simplified article HTML into a list of plain-text blocks.
+
+    Every <ul>/<ol> is collapsed into a single <p> holding its "* item, "
+    entries; then each text node's parent is converted with
+    plain_text_leaf_node and blocks without text are dropped.
+    """
+    # Load article as DOM
+    soup = BeautifulSoup(paragraph_html, "html.parser")
+    # Turn every list into one paragraph made of its prefixed items
+    for list_element in soup.find_all(["ul", "ol"]):
+        item_texts = [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]
+        list_element.string = "".join(text for text in item_texts if text)
+        list_element.name = "p"
+    # One block per text node, taken from the node's parent element
+    blocks = [plain_text_leaf_node(text_node.parent) for text_node in soup.find_all(string=True)]
+    # Drop empty paragraphs
+    return [block for block in blocks if block["text"] is not None]
+
+
+def plain_text_leaf_node(element):
+    """
+    Return `{"text": ...}` for `element`, with "node_index" included first when
+    the element carries a "data-node-index" attribute.
+
+    The text is the element's normalized text; <li> text is wrapped as
+    "* <text>, " and empty text becomes None.
+    """
+    # Extract all text, stripped of any child HTML elements and normalize it
+    text = normalize_text(element.get_text())
+    if text == "":
+        text = None
+    elif element.name == "li":
+        text = "* {}, ".format(text)
+
+    if "data-node-index" in element.attrs:
+        return {"node_index": element["data-node-index"], "text": text}
+    return {"text": text}
+
+
+def plain_content(readability_content, content_digests, node_indexes):
+    """
+    Return `readability_content` as an HTML string whose top-level nodes have been
+    passed through plain_elements, with node index attributes added when
+    `node_indexes` is truthy.
+    """
+    # Load article as DOM
+    soup = BeautifulSoup(readability_content, "html.parser")
+    # Make all elements plain
+    plain_nodes = plain_elements(soup.contents, content_digests, node_indexes)
+    # Replace article contents with the plain (and optionally indexed) elements
+    soup.contents = [add_node_indexes(node) for node in plain_nodes] if node_indexes else plain_nodes
+    return str(soup)
+
+
+def plain_elements(elements, content_digests, node_indexes):
+    """Apply plain_element to each item, then add_content_digest to each when `content_digests` is truthy."""
+    # Get plain content versions of all elements
+    plain_nodes = [plain_element(node, content_digests, node_indexes) for node in elements]
+    if not content_digests:
+        return plain_nodes
+    # Add content digest attribute to nodes
+    return [add_content_digest(node) for node in plain_nodes]
+
+
+def plain_element(element, content_digests, node_indexes):
+    """Return a plain-text version of ``element``.
+
+    * Leaf elements (see ``is_leaf``) have their content replaced by their
+      normalized text, discarding any child markup.
+    * Text nodes are rebuilt as a new node of the same type: empty for
+      non-printing nodes (see ``is_non_printing``), normalized text otherwise.
+    * Any other element has its children processed recursively through
+      ``plain_elements``.
+    """
+    if is_leaf(element):
+        # Collapse the leaf to its normalized text content
+        element.string = normalize_text(element.get_text())
+        return element
+    if not is_text(element):
+        # Container element: recurse into the children and replace them
+        element.contents = plain_elements(element.contents, content_digests, node_indexes)
+        return element
+    node_type = type(element)
+    if is_non_printing(element):
+        # The simplified HTML may have come from Readability.js and so may hold
+        # non-printing text (e.g. Comment or CData). Keep the node type, but
+        # make sure its string is empty.
+        return node_type("")
+    return node_type(normalize_text(element.string))
+
+
+def add_node_indexes(element, node_index="0"):
+    """Set ``data-node-index`` on ``element`` and, recursively, its descendants.
+
+    Text nodes are returned untouched because they cannot carry attributes.
+    Non-text children are numbered from 1 in document order, and each child's
+    index is the parent's index followed by ``.<position>`` (e.g. ``0.2.1``).
+    The same ``element`` object is returned.
+    """
+    if not is_text(element):
+        element["data-node-index"] = node_index
+        # Only non-text children take part in the numbering
+        tag_children = [child for child in element.contents if not is_text(child)]
+        position = 0
+        for tag_child in tag_children:
+            position += 1
+            add_node_indexes(
+                tag_child,
+                node_index="{stem}.{local}".format(stem=node_index, local=position),
+            )
+    return element
+
+
+def normalize_text(text):
+    """Strip control characters, then normalize unicode and whitespace."""
+    # Order matters: control characters are removed first, unicode is brought to
+    # a standard form next, and whitespace is collapsed last.
+    return normalize_whitespace(normalize_unicode(strip_control_characters(text)))
+
+
+def strip_control_characters(text):
+    """Strip out unicode control characters which might break the parsing.
+
+    Characters in the following unicode categories are removed, except for
+    tab, line feed, carriage return and form feed, which are always kept:
+
+      Cc: Other, Control (includes new lines)
+      Cf: Other, Format
+      Cn: Other, Not Assigned
+      Co: Other, Private Use
+      Cs: Other, Surrogate
+    """
+    dropped_categories = ("Cc", "Cf", "Cn", "Co", "Cs")
+    kept_whitespace = ("\t", "\n", "\r", "\f")
+
+    survivors = []
+    for character in text:
+        # Keep anything outside the dropped categories, plus retained whitespace
+        if unicodedata.category(character) not in dropped_categories or character in kept_whitespace:
+            survivors.append(character)
+    return "".join(survivors)
+
+
+def normalize_unicode(text):
+    """Return ``text`` in NFKC normal form.
+
+    The aim is for strings that are visually equivalent to map to the same
+    unicode string where possible.
+    """
+    return unicodedata.normalize("NFKC", text)
+
+
+def normalize_whitespace(text):
+    """Collapse whitespace runs to one space and trim both ends.
+
+    This mirrors what happens to whitespace when HTML text is displayed.
+    """
+    collapsed = regex.sub(r"\s+", " ", text)
+    return collapsed.strip()
+
+
+def is_leaf(element):
+    """Return True when ``element`` is named ``p`` or ``li``."""
+    tag_name = element.name
+    return tag_name in ("p", "li")
+
+
+def is_text(element):
+    """Return True when ``element`` is a ``NavigableString`` instance."""
+    text_node = isinstance(element, NavigableString)
+    return text_node
+
+
+def is_non_printing(element):
+    """Return True when ``element`` is a ``Comment`` or ``CData`` instance."""
+    return isinstance(element, (Comment, CData))
+
+
+def add_content_digest(element):
+    """Store ``content_digest(element)`` in the ``data-content-digest`` attribute.
+
+    Text nodes are returned unchanged; in every case the same ``element``
+    object is returned.
+    """
+    if is_text(element):
+        return element
+    element["data-content-digest"] = content_digest(element)
+    return element
+
+
+def content_digest(element):
+    """Return a SHA-256 hex digest summarizing the content of ``element``.
+
+    * Text node: digest of its stripped, UTF-8 encoded string, or ``""`` when
+      the stripped string is empty.
+    * Element without children: ``""``.
+    * Element with exactly one child: the digest of that child.
+    * Element with several children: digest built by feeding the non-empty
+      digests of the children, in order, into a single hash.
+    """
+    if is_text(element):
+        stripped = element.string.strip()
+        if stripped == "":
+            return ""
+        return hashlib.sha256(stripped.encode("utf-8")).hexdigest()
+
+    children = element.contents
+    child_count = len(children)
+    if child_count == 0:
+        # No hash when no child elements exist
+        return ""
+    if child_count == 1:
+        # A single child passes its digest straight through
+        return content_digest(children[0])
+
+    # Combine the "non-empty" digests of all children into one hash
+    hasher = hashlib.sha256()
+    for child_digest in [content_digest(child) for child in children]:
+        if child_digest != "":
+            hasher.update(child_digest.encode("utf-8"))
+    return hasher.hexdigest()
+
+
+def get_image_upload_file_ids(content):
+    """Extract upload file ids from markdown image links found in ``content``.
+
+    A link is recognized when it has the form ``![image](<url>)`` where the URL
+    starts with ``http://`` or ``https://`` and ends with ``file-preview`` or
+    ``image-preview``. The id is the path segment between ``files/`` and the
+    preview suffix. Ids are returned in order of appearance, duplicates kept.
+    """
+    # ``https?`` makes the "s" optional. The previous ``http?`` made the second
+    # "p" optional instead, so https links were never matched.
+    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
+    image_upload_file_ids = []
+    for url, preview_kind in re.findall(pattern, content):
+        if preview_kind == "file-preview":
+            content_pattern = r"files/([^/]+)/file-preview"
+        else:
+            content_pattern = r"files/([^/]+)/image-preview"
+        content_match = re.search(content_pattern, url)
+        if content_match:
+            image_upload_file_ids.append(content_match.group(1))
+    return image_upload_file_ids

+ 1 - 3
api/core/workflow/nodes/tool/tool_node.py

@@ -1,5 +1,4 @@
 from collections.abc import Generator, Mapping, Sequence
-from os import path
 from typing import Any, cast
 
 from sqlalchemy import select
@@ -236,8 +235,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                         type=FileType.IMAGE,
                         transfer_method=FileTransferMethod.TOOL_FILE,
                         related_id=tool_file_id,
-                        filename=message.save_as,
-                        extension=path.splitext(message.save_as)[1],
+                        extension=None,
                         mime_type=message.meta.get("mime_type", "application/octet-stream"),
                     )
                 )