import time
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Text Embedding Inference text embedding model.
    """

    def _invoke(
        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        credentials should be like:
        {
            'server_url': 'server url',
        }

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]

        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        used_tokens = 0

        # get tokenized results from TEI
        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
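
        # Shape note (assumption, matching TEI's /tokenize response): each entry in
        # batched_tokenize_result is a list of token dicts, e.g.
        #   [{'id': 101, 'text': '[CLS]', 'special': True, 'start': None, 'stop': None},
        #    {'id': 7592, 'text': 'hello', 'special': False, 'start': 0, 'stop': 5},
        #    ...]
        # 'special' marks tokens the tokenizer injected itself; 'start'/'stop' are
        # character offsets into the original text, which the truncation below uses.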
        for text, tokenize_result in zip(texts, batched_tokenize_result):
            # Check whether the number of tokens exceeds the context size
            num_tokens = len(tokenize_result)

            if num_tokens >= context_size:
                # Count the special tokens the tokenizer prepends (e.g. [CLS])
                pre_special_token_count = 0
                for token in tokenize_result:
                    if token['special']:
                        pre_special_token_count += 1
                    else:
                        break
                # Special tokens the tokenizer appends (e.g. [SEP])
                rest_special_token_count = (
                    len([token for token in tokenize_result if token['special']]) - pre_special_token_count
                )

                # Calculate the cutoff point, leaving 20 tokens of headroom to avoid exceeding the limit
                token_cutoff = context_size - rest_special_token_count - 20

                # Map the cutoff token back to a character offset in the original text
                cutpoint_token = tokenize_result[token_cutoff]
                cutoff = cutpoint_token['start']

                inputs.append(text[:cutoff])
            else:
                inputs.append(text)
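
        # Worked example of the truncation above (illustrative values): with
        # context_size=512 and a BERT-style tokenizer that adds one leading [CLS]
        # and one trailing [SEP], rest_special_token_count == 1, so
        # token_cutoff == 512 - 1 - 20 == 491 and the text is cut at the character
        # offset where token 491 starts.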

        batched_embeddings = []

        try:
            # Request embeddings in batches of at most max_chunks inputs
            for i in range(0, len(inputs), max_chunks):
                iter_texts = inputs[i : i + max_chunks]
                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
                embeddings = [embedding['embedding'] for embedding in results['data']]
                batched_embeddings.extend(embeddings)

                usage = results['usage']
                used_tokens += usage['total_tokens']
        except RuntimeError as e:
            raise InvokeServerUnavailableError(str(e))
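
        # Response shape note (assumption: TeiHelper.invoke_embeddings returns an
        # OpenAI-compatible payload):
        #   {'data': [{'embedding': [...], 'index': 0}, ...],
        #    'usage': {'prompt_tokens': N, 'total_tokens': N}}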

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

        result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)

        return result

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get the number of tokens for the given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: total number of tokens across all texts
        """
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]

        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
        num_tokens = sum(len(tokens) for tokens in batch_tokens)
        return num_tokens

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'embedding':
                raise CredentialsValidateFailedError('Current model is not an embedding model')

            # Cache the server-reported limits so later invocations can read them
            # from the credentials instead of querying the server again
            credentials['context_size'] = extra_args.max_input_length
            credentials['max_chunks'] = extra_args.max_client_batch_size
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
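
    # Note: TeiHelper.get_tei_extra_parameter is expected to surface the TEI
    # server's /info fields, e.g. model_type='embedding', max_input_length=512,
    # max_client_batch_size=32 (illustrative values).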

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        # TeiHelper raises the unified invoke errors directly, so most entries
        # map each error type to itself; KeyError covers malformed responses.
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [KeyError],
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Define the customizable model schema built from the given credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
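

# Minimal usage sketch (assumption: a TEI embedding server is reachable at the
# placeholder URL below and this module is imported inside the Dify api runtime;
# the server URL and model id are hypothetical, not part of this file).
if __name__ == '__main__':
    embedding_model = HuggingfaceTeiTextEmbeddingModel()
    credentials = {
        'server_url': 'http://localhost:8080',  # hypothetical local TEI server
        'context_size': 512,
        'max_chunks': 32,
    }
    result = embedding_model.invoke(
        model='BAAI/bge-small-en-v1.5',  # hypothetical embedding model id
        credentials=credentials,
        texts=['hello', 'world'],
    )
    print(len(result.embeddings), result.usage.total_tokens)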