model.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. from unittest.mock import MagicMock
  2. from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
  3. from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
  4. from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
  5. from core.model_manager import ModelInstance
  6. from core.model_runtime.entities.model_entities import ModelType
  7. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  8. from models.provider import ProviderType
  def get_mocked_fetch_model_config(
      provider: str,
      model: str,
      mode: str,
      credentials: dict,
  ):
      """Build a MagicMock whose call returns a ready-made (model instance, model config) pair.

      The pair is assembled for an LLM of ``provider``/``model`` whose provider
      configuration uses the CUSTOM provider type with the given ``credentials``.
      NOTE(review): judging by its name, the returned mock presumably stands in for a
      "fetch model config" function when patched in tests — confirm against callers.

      Args:
          provider: Provider name; used to obtain the LLM model-type instance, the
              provider schema and the model schema, and stored on the model config.
          model: Model name for both the ModelInstance and the model config.
          mode: Passed through unchanged as ``ModelConfigWithCredentialsEntity.mode``.
          credentials: Stored as the custom provider credentials, and also used for
              the model-schema lookup and the model config.

      Returns:
          A ``MagicMock`` whose ``return_value`` is the tuple
          ``(ModelInstance, ModelConfigWithCredentialsEntity)``.

      Raises:
          AssertionError: If the factory returns no model schema for ``model``
              (only while assertions are enabled).
      """
      factory = ModelProviderFactory(tenant_id="test_tenant")
      llm_instance = factory.get_model_type_instance(provider, ModelType.LLM)

      # NOTE(review): the factory above uses tenant "test_tenant" while the provider
      # configuration below uses tenant "1"; kept as-is — confirm the mismatch is intended.
      provider_schema = factory.get_provider_schema(provider)
      configuration = ProviderConfiguration(
          tenant_id="1",
          provider=provider_schema,
          # Both the preferred and the active provider type are CUSTOM, and the
          # system configuration is explicitly disabled.
          preferred_provider_type=ProviderType.CUSTOM,
          using_provider_type=ProviderType.CUSTOM,
          system_configuration=SystemConfiguration(enabled=False),
          custom_configuration=CustomConfiguration(
              provider=CustomProviderConfiguration(credentials=credentials),
          ),
          model_settings=[],
      )
      bundle = ProviderModelBundle(configuration=configuration, model_type_instance=llm_instance)
      instance = ModelInstance(provider_model_bundle=bundle, model=model)

      schema = factory.get_model_schema(
          provider=provider,
          model_type=llm_instance.model_type,
          model=model,
          credentials=credentials,
      )
      # Stop early (and narrow the Optional for type checkers) when no schema comes back.
      assert schema is not None

      config = ModelConfigWithCredentialsEntity(
          provider=provider,
          model=model,
          model_schema=schema,
          mode=mode,
          provider_model_bundle=bundle,
          credentials=credentials,
          parameters={},
      )
      return MagicMock(return_value=(instance, config))