|  | @@ -0,0 +1,66 @@
 | 
	
		
			
				|  |  | +from typing import Any, Dict, List, Union
 | 
	
		
			
				|  |  | +from core.tools.entities.tool_entities import ToolInvokeMessage
 | 
	
		
			
				|  |  | +from core.tools.tool.builtin_tool import BuiltinTool
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from base64 import b64decode
 | 
	
		
			
				|  |  | +from os.path import join
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from openai import AzureOpenAI
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class DallE3Tool(BuiltinTool):
 | 
	
		
			
				|  |  | +    def _invoke(self, 
 | 
	
		
			
				|  |  | +                user_id: str, 
 | 
	
		
			
				|  |  | +               tool_paramters: Dict[str, Any], 
 | 
	
		
			
				|  |  | +        ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +            invoke tools
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        client = AzureOpenAI(
 | 
	
		
			
				|  |  | +            api_version=self.runtime.credentials['azure_openai_api_version'],
 | 
	
		
			
				|  |  | +            azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
 | 
	
		
			
				|  |  | +            api_key=self.runtime.credentials['azure_openai_api_key'],
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        SIZE_MAPPING = {
 | 
	
		
			
				|  |  | +            'square': '1024x1024',
 | 
	
		
			
				|  |  | +            'vertical': '1024x1792',
 | 
	
		
			
				|  |  | +            'horizontal': '1792x1024',
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # prompt
 | 
	
		
			
				|  |  | +        prompt = tool_paramters.get('prompt', '')
 | 
	
		
			
				|  |  | +        if not prompt:
 | 
	
		
			
				|  |  | +            return self.create_text_message('Please input prompt')
 | 
	
		
			
				|  |  | +        # get size
 | 
	
		
			
				|  |  | +        size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
 | 
	
		
			
				|  |  | +        # get n
 | 
	
		
			
				|  |  | +        n = tool_paramters.get('n', 1)
 | 
	
		
			
				|  |  | +        # get quality
 | 
	
		
			
				|  |  | +        quality = tool_paramters.get('quality', 'standard')
 | 
	
		
			
				|  |  | +        if quality not in ['standard', 'hd']:
 | 
	
		
			
				|  |  | +            return self.create_text_message('Invalid quality')
 | 
	
		
			
				|  |  | +        # get style
 | 
	
		
			
				|  |  | +        style = tool_paramters.get('style', 'vivid')
 | 
	
		
			
				|  |  | +        if style not in ['natural', 'vivid']:
 | 
	
		
			
				|  |  | +            return self.create_text_message('Invalid style')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # call openapi dalle3
 | 
	
		
			
				|  |  | +        model=self.runtime.credentials['azure_openai_api_model_name']
 | 
	
		
			
				|  |  | +        response = client.images.generate(
 | 
	
		
			
				|  |  | +            prompt=prompt,
 | 
	
		
			
				|  |  | +            model=model,
 | 
	
		
			
				|  |  | +            size=size,
 | 
	
		
			
				|  |  | +            n=n,
 | 
	
		
			
				|  |  | +            style=style,
 | 
	
		
			
				|  |  | +            quality=quality,
 | 
	
		
			
				|  |  | +            response_format='b64_json'
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        result = []
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for image in response.data:
 | 
	
		
			
				|  |  | +            result.append(self.create_blob_message(blob=b64decode(image.b64_json), 
 | 
	
		
			
				|  |  | +                                                   meta={ 'mime_type': 'image/png' },
 | 
	
		
			
				|  |  | +                                                    save_as=self.VARIABLE_KEY.IMAGE.value))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return result
 |