| 
					
				 | 
			
			
				@@ -1,27 +1,45 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import decimal 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import List, Optional, Any 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import openai 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from langchain.callbacks.manager import Callbacks 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from langchain.llms import ChatGLM 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from langchain.schema import LLMResult 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from langchain.schema import LLMResult, get_buffer_string 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.model_providers.error import LLMBadRequestError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    LLMAPIUnavailableError, LLMAPIConnectionError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.models.llm.base import BaseLLM 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.models.entity.message import PromptMessage, MessageType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class ChatGLMModel(BaseLLM): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    model_mode: ModelMode = ModelMode.COMPLETION 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    model_mode: ModelMode = ModelMode.CHAT 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _init_client(self) -> Any: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return ChatGLM( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        extra_model_kwargs = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'top_p': provider_model_kwargs.get('top_p') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if provider_model_kwargs.get('max_length') is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        client = EnhanceChatOpenAI( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            model_name=self.name, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            temperature=provider_model_kwargs.get('temperature'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            max_tokens=provider_model_kwargs.get('max_tokens'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            model_kwargs=extra_model_kwargs, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            streaming=self.streaming, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             callbacks=self.callbacks, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            endpoint_url=self.credentials.get('api_base'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            **provider_model_kwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            request_timeout=60, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            openai_api_key="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            openai_api_base=self.credentials['api_base'] + '/v1' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return client 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _run(self, messages: List[PromptMessage], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				              stop: Optional[List[str]] = None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				              callbacks: Callbacks = None, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         :return: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         prompts = self._get_prompt_from_messages(messages) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return max(self._client.get_num_tokens(prompts), 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def get_currency(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return 'RMB' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _set_model_kwargs(self, model_kwargs: ModelKwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for k, v in provider_model_kwargs.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if hasattr(self.client, k): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                setattr(self.client, k, v) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        extra_model_kwargs = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'top_p': provider_model_kwargs.get('top_p') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.client.temperature = provider_model_kwargs.get('temperature') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.client.max_tokens = provider_model_kwargs.get('max_tokens') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.client.model_kwargs = extra_model_kwargs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def handle_exceptions(self, ex: Exception) -> Exception: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if isinstance(ex, ValueError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return LLMBadRequestError(f"ChatGLM: {str(ex)}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if isinstance(ex, openai.error.InvalidRequestError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logging.warning("Invalid request to ChatGLM API.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return LLMBadRequestError(str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(ex, openai.error.APIConnectionError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logging.warning("Failed to connect to ChatGLM API.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logging.warning("ChatGLM service unavailable.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(ex, openai.error.RateLimitError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return LLMRateLimitError(str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(ex, openai.error.AuthenticationError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return LLMAuthorizationError(str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        elif isinstance(ex, openai.error.OpenAIError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return ex 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @classmethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def support_streaming(cls): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return True 
			 |