
feat: support GLM-4V (#2124)

Yeuoly 1 year ago
parent
commit
8394bbd47f

+ 2 - 1
api/core/model_runtime/model_providers/zhipuai/_common.py

@@ -11,7 +11,8 @@ class _CommonZhipuaiAI:
         :return:
         """
         credentials_kwargs = {
-            "api_key": credentials['api_key'],
+            "api_key": credentials['api_key'] if 'api_key' in credentials
+                        else credentials.get('zhipuai_api_key'),
         }
 
         return credentials_kwargs
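
For reference, the fallback resolves the key in a fixed order. A minimal standalone sketch of the same logic (resolve_api_key and the sample dicts are hypothetical, not part of the commit):

    from typing import Optional

    # Hypothetical helper sketching the fallback above: prefer 'api_key',
    # then the legacy 'zhipuai_api_key', else None.
    def resolve_api_key(credentials: dict) -> Optional[str]:
        if 'api_key' in credentials:
            return credentials['api_key']
        return credentials.get('zhipuai_api_key')

    assert resolve_api_key({'api_key': 'k1'}) == 'k1'
    assert resolve_api_key({'zhipuai_api_key': 'k2'}) == 'k2'
    assert resolve_api_key({}) is None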

+ 44 - 0
api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml

File diff suppressed because it is too large

+ 50 - 10
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -3,7 +3,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
+                                                          TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
@@ -108,10 +109,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 prompt_messages = prompt_messages[1:]
 
         # work around zhipuai models: system messages are not supported, and user/assistant messages must alternate
-        new_prompt_messages = []
+        new_prompt_messages: List[PromptMessage] = []
         for prompt_message in prompt_messages:
             copy_prompt_message = prompt_message.copy()
             if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
+                if isinstance(copy_prompt_message.content, list):
+                    # multimodal (list) content is only supported by glm-4v
+                    if model != 'glm-4v':
+                        # other models only accept plain-text content
+                        continue
+                    # and only user messages may carry image content
+                    if not isinstance(copy_prompt_message, UserPromptMessage):
+                        # drop system/tool messages with list content
+                        continue
+                    new_prompt_messages.append(copy_prompt_message)
+                    continue
+
                 if not isinstance(copy_prompt_message.content, str):
                     # skip unsupported non-text content
                     continue
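
Net effect of the filter above: list (multimodal) content survives only for user messages on glm-4v; everything else must be a plain string. A standalone illustration of the rule (the Msg stand-in class and sample values are assumptions, not the real PromptMessage types):

    from dataclasses import dataclass
    from typing import Union

    # Simplified stand-in for the real PromptMessage classes (hypothetical).
    @dataclass
    class Msg:
        role: str                  # 'user', 'system', 'assistant', ...
        content: Union[str, list]  # plain text or a list of multimodal parts

    def keep(model: str, msg: Msg) -> bool:
        # list content is only kept for user messages on glm-4v
        if isinstance(msg.content, list):
            return model == 'glm-4v' and msg.role == 'user'
        return isinstance(msg.content, str)

    assert keep('glm-4v', Msg('user', [{'type': 'text', 'text': 'hi'}]))
    assert not keep('glm-4', Msg('user', [{'type': 'text', 'text': 'hi'}]))
    assert not keep('glm-4v', Msg('system', [{'type': 'text', 'text': 'hi'}]))
    assert keep('glm-4', Msg('user', 'hi'))
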
@@ -130,14 +143,41 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 else:
                     new_prompt_messages.append(copy_prompt_message)
 
-        params = {
-            'model': model,
-            'prompt': [{
-                'role': prompt_message.role.value,
-                'content': prompt_message.content
-            } for prompt_message in new_prompt_messages],
-            **model_parameters
-        }
+        if model == 'glm-4v':
+            params = {
+                'model': model,
+                'prompt': [{
+                    'role': prompt_message.role.value,
+                    'content': 
+                        [
+                            {
+                                'type': 'text',
+                                'text': prompt_message.content
+                            }
+                        ] if isinstance(prompt_message.content, str) else 
+                        [
+                            {
+                                'type': 'image',
+                                'image_url': {
+                                    'url': content.data
+                                }
+                            } if content.type == PromptMessageContentType.IMAGE else {
+                                'type': 'text',
+                                'text': content.data
+                            } for content in prompt_message.content
+                        ],
+                } for prompt_message in new_prompt_messages],
+                **model_parameters
+            }
+        else:
+            params = {
+                'model': model,
+                'prompt': [{
+                    'role': prompt_message.role.value,
+                    'content': prompt_message.content,
+                } for prompt_message in new_prompt_messages],
+                **model_parameters
+            }
 
         if stream:
             response = client.sse_invoke(incremental=True, **params).events()
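
For reference, the glm-4v branch above builds a request payload shaped like this (a sketch with made-up sample values; the 'temperature' key stands in for whatever model_parameters get merged in):

    # Example shape of the params dict built by the glm-4v branch
    # (URL, text, and temperature are hypothetical sample values).
    params = {
        'model': 'glm-4v',
        'prompt': [{
            'role': 'user',
            'content': [
                {'type': 'text', 'text': 'What is in this picture?'},
                {'type': 'image', 'image_url': {'url': 'https://example.com/cat.png'}},
            ],
        }],
        'temperature': 0.7,
    }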