|
@@ -2,7 +2,6 @@ import json
|
|
|
from typing import Type
|
|
|
|
|
|
from huggingface_hub import HfApi
|
|
|
-from langchain.llms import HuggingFaceEndpoint
|
|
|
|
|
|
from core.helper import encrypter
|
|
|
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
|
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
|
|
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
|
|
|
|
|
from core.model_providers.models.base import BaseProviderModel
|
|
|
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
|
|
from models.provider import ProviderType
|
|
|
|
|
|
|
|
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
|
|
if 'huggingfacehub_endpoint_url' not in credentials:
|
|
|
raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
|
|
|
|
|
|
+ if 'task_type' not in credentials:
|
|
|
+ raise CredentialsValidateFailedError('Task Type must be provided.')
|
|
|
+
|
|
|
+ if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
|
|
+ raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
|
|
|
+
|
|
|
try:
|
|
|
- llm = HuggingFaceEndpoint(
|
|
|
+ llm = HuggingFaceEndpointLLM(
|
|
|
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
|
|
- task="text2text-generation",
|
|
|
+ task=credentials['task_type'],
|
|
|
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
|
|
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
|
|
)
|
|
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
|
|
}
|
|
|
|
|
|
credentials = json.loads(provider_model.encrypted_config)
|
|
|
+
|
|
|
+ if 'task_type' not in credentials:
|
|
|
+ credentials['task_type'] = 'text-generation'
|
|
|
+
|
|
|
if credentials['huggingfacehub_api_token']:
|
|
|
credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
|
|
|
self.provider.tenant_id,
|