spark.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import json
  2. from core.tools.entities.values import ToolLabelEnum
  3. from core.tools.errors import ToolProviderCredentialValidationError
  4. from core.tools.provider.builtin.spark.tools.spark_img_generation import spark_response
  5. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  6. class SparkProvider(BuiltinToolProviderController):
  7. def _validate_credentials(self, credentials: dict) -> None:
  8. try:
  9. if "APPID" not in credentials or not credentials.get("APPID"):
  10. raise ToolProviderCredentialValidationError("APPID is required.")
  11. if "APISecret" not in credentials or not credentials.get("APISecret"):
  12. raise ToolProviderCredentialValidationError("APISecret is required.")
  13. if "APIKey" not in credentials or not credentials.get("APIKey"):
  14. raise ToolProviderCredentialValidationError("APIKey is required.")
  15. appid = credentials.get("APPID")
  16. apisecret = credentials.get("APISecret")
  17. apikey = credentials.get("APIKey")
  18. prompt = "a cute black dog"
  19. try:
  20. response = spark_response(prompt, appid, apikey, apisecret)
  21. data = json.loads(response)
  22. code = data["header"]["code"]
  23. if code == 0:
  24. # 0 success,
  25. pass
  26. else:
  27. raise ToolProviderCredentialValidationError(
  28. "image generate error, code:{}".format(code)
  29. )
  30. except Exception as e:
  31. raise ToolProviderCredentialValidationError(
  32. "APPID APISecret APIKey is invalid. {}".format(e)
  33. )
  34. except Exception as e:
  35. raise ToolProviderCredentialValidationError(str(e))
  36. def _get_tool_labels(self) -> list[ToolLabelEnum]:
  37. return [
  38. ToolLabelEnum.IMAGE
  39. ]