Parcourir la source

feat: Add tools for open weather search and image generation using the Spark API. (#2845)

Onelevenvy il y a 1 an
Parent
commit
cb79a90031

+ 1 - 1
api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py

@@ -124,7 +124,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
             elif err == 'insufficient_quota':
                 raise InsufficientAccountBalance(msg)
             elif err == 'invalid_authentication':
-                raise InvalidAuthenticationError(msg)
+                raise InvalidAuthenticationError(msg) 
             elif err and 'rate' in err:
                 raise RateLimitReachedError(msg)
             elif err and 'internal' in err:

Fichier diff supprimé car celui-ci est trop grand
+ 12 - 0
api/core/tools/provider/builtin/openweather/_assets/icon.svg


+ 36 - 0
api/core/tools/provider/builtin/openweather/openweather.py

@@ -0,0 +1,36 @@
+import requests
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None):
+
+    url = "https://api.openweathermap.org/data/2.5/weather"
+    params = {"q": city, "appid": api_key, "units": units, "lang": language}
+
+    return requests.get(url, params=params)
+
+
+class OpenweatherProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        try:
+            if "api_key" not in credentials or not credentials.get("api_key"):
+                raise ToolProviderCredentialValidationError(
+                    "Open weather API key is required."
+                )
+            apikey = credentials.get("api_key")
+            try:
+                response = query_weather(api_key=apikey)
+                if response.status_code == 200:
+                    pass
+                else:
+                    raise ToolProviderCredentialValidationError(
+                        (response.json()).get("info")
+                    )
+            except Exception as e:
+                raise ToolProviderCredentialValidationError(
+                    "Open weather API Key is invalid. {}".format(e)
+                )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 29 - 0
api/core/tools/provider/builtin/openweather/openweather.yaml

@@ -0,0 +1,29 @@
+identity:
+  author: Onelevenvy
+  name: openweather
+  label:
+    en_US: Open weather query
+    zh_Hans: Open Weather
+    pt_BR: Consulta de clima open weather
+  description:
+    en_US: Weather query toolkit based on Open Weather
+    zh_Hans: 基于open weather的天气查询工具包
+    pt_BR: Kit de consulta de clima baseado no Open Weather
+  icon: icon.svg
+credentials_for_provider:
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API Key
+      zh_Hans: API Key
+      pt_BR: Fogo a chave
+    placeholder:
+      en_US: Please enter your open weather API Key
+      zh_Hans: 请输入你的open weather API Key
+      pt_BR: Insira sua chave de API open weather
+    help:
+      en_US: Get your API Key from open weather
+      zh_Hans: 从open weather获取您的 API Key
+      pt_BR: Obtenha sua chave de API do open weather
+    url: https://openweathermap.org

+ 60 - 0
api/core/tools/provider/builtin/openweather/tools/weather.py

@@ -0,0 +1,60 @@
+import json
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class OpenweatherTool(BuiltinTool):
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        city = tool_parameters.get("city", "")
+        if not city:
+            return self.create_text_message("Please tell me your city")
+        if (
+            "api_key" not in self.runtime.credentials
+            or not self.runtime.credentials.get("api_key")
+        ):
+            return self.create_text_message("OpenWeather API key is required.")
+
+        units = tool_parameters.get("units", "metric")
+        lang = tool_parameters.get("lang", "zh_cn")
+        try:
+            # request URL
+            url = "https://api.openweathermap.org/data/2.5/weather"
+
+            # request parmas
+            params = {
+                "q": city,
+                "appid": self.runtime.credentials.get("api_key"),
+                "units": units,
+                "lang": lang,
+            }
+            response = requests.get(url, params=params)
+
+            if response.status_code == 200:
+
+                data = response.json()
+                return self.create_text_message(
+                    self.summary(
+                        user_id=user_id, content=json.dumps(data, ensure_ascii=False)
+                    )
+                )
+            else:
+                error_message = {
+                    "error": f"failed:{response.status_code}",
+                    "data": response.text,
+                }
+                # return error
+                return json.dumps(error_message)
+
+        except Exception as e:
+            return self.create_text_message(
+                "Openweather API Key is invalid. {}".format(e)
+            )

+ 80 - 0
api/core/tools/provider/builtin/openweather/tools/weather.yaml

@@ -0,0 +1,80 @@
+identity:
+  name: weather
+  author: Onelevenvy
+  label:
+    en_US: Open Weather Query
+    zh_Hans: 天气查询
+    pt_BR: Previsão do tempo
+  icon: icon.svg
+description:
+  human:
+    en_US: Weather forecast inquiry
+    zh_Hans: 天气查询
+    pt_BR: Inquérito sobre previsão meteorológica
+  llm: A tool when you want to ask about the weather or weather-related question
+parameters:
+  - name: city
+    type: string
+    required: true
+    label:
+      en_US: city
+      zh_Hans: 城市
+      pt_BR: cidade
+    human_description:
+      en_US: Target city for weather forecast query
+      zh_Hans: 天气预报查询的目标城市
+      pt_BR: Cidade de destino para consulta de previsão do tempo
+    llm_description: If you don't know you can extract the city name from the
+      question or you can reply:Please tell me your city. You have to extract
+      the Chinese city name from the question.If the input region is in Chinese
+      characters for China, it should be replaced with the corresponding English
+      name, such as '北京' for correct input is 'Beijing'
+    form: llm
+  - name: lang
+    type: select
+    required: true
+    human_description:
+      en_US: language
+      zh_Hans: 语言
+      pt_BR: language
+    label:
+      en_US: language
+      zh_Hans: 语言
+      pt_BR: language
+    form: form
+    options:
+      - value: zh_cn
+        label:
+          en_US: cn
+          zh_Hans: 中国
+          pt_BR: cn
+      - value: en_us
+        label:
+          en_US: usa
+          zh_Hans: 美国
+          pt_BR: usa
+    default: zh_cn
+  - name: units
+    type: select
+    required: true
+    human_description:
+      en_US: units for temperature
+      zh_Hans: 温度单位
+      pt_BR: units for temperature
+    label:
+      en_US: units
+      zh_Hans: 单位
+      pt_BR: units
+    form: form
+    options:
+      - value: metric
+        label:
+          en_US: metric
+          zh_Hans: ℃
+          pt_BR: metric
+      - value: imperial
+        label:
+          en_US: imperial
+          zh_Hans: ℉
+          pt_BR: imperial
+    default: metric

+ 0 - 0
api/core/tools/provider/builtin/spark/__init__.py


Fichier diff supprimé car celui-ci est trop grand
+ 5 - 0
api/core/tools/provider/builtin/spark/_assets/icon.svg


+ 40 - 0
api/core/tools/provider/builtin/spark/spark.py

@@ -0,0 +1,40 @@
+import json
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class SparkProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        try:
+            if "APPID" not in credentials or not credentials.get("APPID"):
+                raise ToolProviderCredentialValidationError("APPID is required.")
+            if "APISecret" not in credentials or not credentials.get("APISecret"):
+                raise ToolProviderCredentialValidationError("APISecret is required.")
+            if "APIKey" not in credentials or not credentials.get("APIKey"):
+                raise ToolProviderCredentialValidationError("APIKey is required.")
+
+            appid = credentials.get("APPID")
+            apisecret = credentials.get("APISecret")
+            apikey = credentials.get("APIKey")
+            prompt = "a cute black dog"
+
+            try:
+                response = spark_response(prompt, appid, apikey, apisecret)
+                data = json.loads(response)
+                code = data["header"]["code"]
+
+                if code == 0:
+                    #  0 success,
+                    pass
+                else:
+                    raise ToolProviderCredentialValidationError(
+                        "image generate error, code:{}".format(code)
+                    )
+            except Exception as e:
+                raise ToolProviderCredentialValidationError(
+                    "APPID APISecret APIKey is invalid. {}".format(e)
+                )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))

+ 59 - 0
api/core/tools/provider/builtin/spark/spark.yaml

@@ -0,0 +1,59 @@
+identity:
+  author: Onelevenvy
+  name: spark
+  label:
+    en_US: Spark
+    zh_Hans: 讯飞星火
+    pt_BR: Spark
+  description:
+    en_US: Spark Platform Toolkit
+    zh_Hans: 讯飞星火平台工具
+    pt_BR: Pacote de Ferramentas da Plataforma Spark
+  icon: icon.svg
+credentials_for_provider:
+  APPID:
+    type: secret-input
+    required: true
+    label:
+      en_US: Spark APPID
+      zh_Hans: APPID
+      pt_BR: Spark APPID
+    help:
+      en_US: Please input your  APPID
+      zh_Hans: 请输入你的 APPID
+      pt_BR: Please input your APPID
+    placeholder:
+      en_US: Please input your APPID
+      zh_Hans: 请输入你的 APPID
+      pt_BR: Please input your APPID
+  APISecret:
+    type: secret-input
+    required: true
+    label:
+      en_US: Spark APISecret
+      zh_Hans: APISecret
+      pt_BR: Spark APISecret
+    help:
+      en_US: Please input your Spark APISecret
+      zh_Hans: 请输入你的 APISecret
+      pt_BR: Please input your Spark APISecret
+    placeholder:
+      en_US: Please input your Spark APISecret
+      zh_Hans: 请输入你的 APISecret
+      pt_BR: Please input your Spark APISecret
+  APIKey:
+    type: secret-input
+    required: true
+    label:
+      en_US: Spark APIKey
+      zh_Hans: APIKey
+      pt_BR: Spark APIKey
+    help:
+      en_US: Please input your Spark APIKey
+      zh_Hans: 请输入你的 APIKey
+      pt_BR: Please input your Spark APIKey
+    placeholder:
+      en_US: Please input your Spark APIKey
+      zh_Hans: 请输入你的 APIKey
+      pt_BR: Please input Spark APIKey
+    url: https://console.xfyun.cn/services

+ 154 - 0
api/core/tools/provider/builtin/spark/tools/spark_img_generation.py

@@ -0,0 +1,154 @@
+import base64
+import hashlib
+import hmac
+import json
+from base64 import b64decode
+from datetime import datetime
+from time import mktime
+from typing import Any, Union
+from urllib.parse import urlencode
+from wsgiref.handlers import format_date_time
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class AssembleHeaderException(Exception):
+    def __init__(self, msg):
+        self.message = msg
+
+
+class Url:
+    def __init__(this, host, path, schema):
+        this.host = host
+        this.path = path
+        this.schema = schema
+
+
+# calculate sha256 and encode to base64
+def sha256base64(data):
+    sha256 = hashlib.sha256()
+    sha256.update(data)
+    digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
+    return digest
+
+
+def parse_url(requset_url):
+    stidx = requset_url.index("://")
+    host = requset_url[stidx + 3 :]
+    schema = requset_url[: stidx + 3]
+    edidx = host.index("/")
+    if edidx <= 0:
+        raise AssembleHeaderException("invalid request url:" + requset_url)
+    path = host[edidx:]
+    host = host[:edidx]
+    u = Url(host, path, schema)
+    return u
+
+def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
+    u = parse_url(requset_url)
+    host = u.host
+    path = u.path
+    now = datetime.now()
+    date = format_date_time(mktime(now.timetuple()))
+    signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
+        host, date, method, path
+    )
+    signature_sha = hmac.new(
+        api_secret.encode("utf-8"),
+        signature_origin.encode("utf-8"),
+        digestmod=hashlib.sha256,
+    ).digest()
+    signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
+    authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'
+
+    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
+        encoding="utf-8"
+    )
+    values = {"host": host, "date": date, "authorization": authorization}
+
+    return requset_url + "?" + urlencode(values)
+
+
+def get_body(appid, text):
+    body = {
+        "header": {"app_id": appid, "uid": "123456789"},
+        "parameter": {
+            "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}
+        },
+        "payload": {"message": {"text": [{"role": "user", "content": text}]}},
+    }
+    return body
+
+
+def spark_response(text, appid, apikey, apisecret):
+    host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti"
+    url = assemble_ws_auth_url(
+        host, method="POST", api_key=apikey, api_secret=apisecret
+    )
+    content = get_body(appid, text)
+    response = requests.post(
+        url, json=content, headers={"content-type": "application/json"}
+    ).text
+    return response
+
+
+class SparkImgGeneratorTool(BuiltinTool):
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+
+        if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get(
+            "APPID"
+        ):
+            return self.create_text_message("APPID  is required.")
+        if (
+            "APISecret" not in self.runtime.credentials
+            or not self.runtime.credentials.get("APISecret")
+        ):
+            return self.create_text_message("APISecret  is required.")
+        if (
+            "APIKey" not in self.runtime.credentials
+            or not self.runtime.credentials.get("APIKey")
+        ):
+            return self.create_text_message("APIKey  is required.")
+
+        prompt = tool_parameters.get("prompt", "")
+        if not prompt:
+            return self.create_text_message("Please input prompt")
+        res = self.img_generation(prompt)
+        result = []
+        for image in res:
+            result.append(
+                self.create_blob_message(
+                    blob=b64decode(image["base64_image"]),
+                    meta={"mime_type": "image/png"},
+                    save_as=self.VARIABLE_KEY.IMAGE.value,
+                )
+            )
+        return result
+
+    def img_generation(self, prompt):
+        response = spark_response(
+            text=prompt,
+            appid=self.runtime.credentials.get("APPID"),
+            apikey=self.runtime.credentials.get("APIKey"),
+            apisecret=self.runtime.credentials.get("APISecret"),
+        )
+        data = json.loads(response)
+        code = data["header"]["code"]
+        if code != 0:
+            return self.create_text_message(f"error: {code}, {data}")
+        else:
+            text = data["payload"]["choices"]["text"]
+            image_content = text[0]
+            image_base = image_content["content"]
+            json_data = {"base64_image": image_base}
+        return [json_data]

+ 36 - 0
api/core/tools/provider/builtin/spark/tools/spark_img_generation.yaml

@@ -0,0 +1,36 @@
+identity:
+  name: spark_img_generation
+  author: Onelevenvy
+  label:
+    en_US: Spark Image Generation
+    zh_Hans: 图片生成
+    pt_BR: Geração de imagens Spark
+  icon: icon.svg
+  description:
+    en_US: Spark Image Generation
+    zh_Hans: 图片生成
+    pt_BR: Geração de imagens Spark
+description:
+  human:
+    en_US: Generate images based on user input, with image generation API
+      provided by Spark
+    zh_Hans: 根据用户的输入生成图片,由讯飞星火提供图片生成api
+    pt_BR: Gerar imagens com base na entrada do usuário, com API de geração
+      de imagem fornecida pela Spark
+  llm: spark_img_generation is a tool used to generate images from text
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: Image prompt
+      zh_Hans: 图像提示词
+      pt_BR: Image prompt
+    llm_description: Image prompt of spark_img_generation tooll, you should
+      describe the image you want to generate as a list of words as possible
+      as detailed
+    form: llm

+ 1 - 1
sdks/python-client/dify_client/__init__.py

@@ -1 +1 @@
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import ChatClient, CompletionClient, DifyClient