|  | @@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
 | 
	
		
			
				|  |  |  from core.tools.errors import ToolProviderCredentialValidationError
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from typing import Any, Dict, List, Union
 | 
	
		
			
				|  |  | -from httpx import post
 | 
	
		
			
				|  |  | +from httpx import post, get
 | 
	
		
			
				|  |  |  from os.path import join
 | 
	
		
			
				|  |  |  from base64 import b64decode, b64encode
 | 
	
		
			
				|  |  |  from PIL import Image
 | 
	
	
		
			
				|  | @@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
 | 
	
		
			
				|  |  |      "alwayson_scripts": {}
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class StableDiffusionTool(BuiltinTool):
 | 
	
		
			
				|  |  |      def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
 | 
	
		
			
				|  |  |          -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
 | 
	
	
		
			
				|  | @@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool):
 | 
	
		
			
				|  |  |                               width=width,
 | 
	
		
			
				|  |  |                               height=height,
 | 
	
		
			
				|  |  |                               steps=steps)
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +            validate models
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            base_url = self.runtime.credentials.get('base_url', None)
 | 
	
		
			
				|  |  | +            if not base_url:
 | 
	
		
			
				|  |  | +                raise ToolProviderCredentialValidationError('Please input base_url')
 | 
	
		
			
				|  |  | +            model = self.runtime.credentials.get('model', None)
 | 
	
		
			
				|  |  | +            if not model:
 | 
	
		
			
				|  |  | +                raise ToolProviderCredentialValidationError('Please input model')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
 | 
	
		
			
				|  |  | +            if response.status_code != 200:
 | 
	
		
			
				|  |  | +                raise ToolProviderCredentialValidationError('Failed to get models')
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                models = [d['model_name'] for d in response.json()]
 | 
	
		
			
				|  |  | +                if len([d for d in models if d == model]) > 0:
 | 
	
		
			
				|  |  | +                    return self.create_text_message(json.dumps(models))
 | 
	
		
			
				|  |  | +                else:
 | 
	
		
			
				|  |  | +                    raise ToolProviderCredentialValidationError(f'model {model} does not exist')
 | 
	
		
			
				|  |  | +        except Exception as e:
 | 
	
		
			
				|  |  | +            raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def img2img(self, base_url: str, lora: str, image_binary: bytes, 
 | 
	
		
			
				|  |  |                  prompt: str, negative_prompt: str,
 | 
	
		
			
				|  |  |                  width: int, height: int, steps: int) \
 | 
	
	
		
			
				|  | @@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool):
 | 
	
		
			
				|  |  |          except Exception as e:
 | 
	
		
			
				|  |  |              return self.create_text_message('Failed to generate image')
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      def get_runtime_parameters(self) -> List[ToolParameter]:
 | 
	
		
			
				|  |  |          parameters = [
 | 
	
		
			
				|  |  | -            ToolParameter(name='prompt', 
 | 
	
		
			
				|  |  | +            ToolParameter(name='prompt',
 | 
	
		
			
				|  |  |                           label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
 | 
	
		
			
				|  |  |                           human_description=I18nObject(
 | 
	
		
			
				|  |  |                               en_US='Image prompt, you can check the official documentation of Stable Diffusion',
 | 
	
	
		
			
				|  | @@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool):
 | 
	
		
			
				|  |  |          ]
 | 
	
		
			
				|  |  |          if len(self.list_default_image_variables()) != 0:
 | 
	
		
			
				|  |  |              parameters.append(
 | 
	
		
			
				|  |  | -                ToolParameter(name='image_id', 
 | 
	
		
			
				|  |  | +                ToolParameter(name='image_id',
 | 
	
		
			
				|  |  |                               label=I18nObject(en_US='image_id', zh_Hans='image_id'),
 | 
	
		
			
				|  |  |                               human_description=I18nObject(
 | 
	
		
			
				|  |  |                                  en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
 |