test_embedding.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import (
  6. VolcengineMaaSTextEmbeddingModel,
  7. )
  8. def test_validate_credentials():
  9. model = VolcengineMaaSTextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="NOT IMPORTANT",
  13. credentials={
  14. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  15. "volc_region": "cn-beijing",
  16. "volc_access_key_id": "INVALID",
  17. "volc_secret_access_key": "INVALID",
  18. "endpoint_id": "INVALID",
  19. "base_model_name": "Doubao-embedding",
  20. },
  21. )
  22. model.validate_credentials(
  23. model="NOT IMPORTANT",
  24. credentials={
  25. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  26. "volc_region": "cn-beijing",
  27. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  28. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  29. "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
  30. "base_model_name": "Doubao-embedding",
  31. },
  32. )
  33. def test_invoke_model():
  34. model = VolcengineMaaSTextEmbeddingModel()
  35. result = model.invoke(
  36. model="NOT IMPORTANT",
  37. credentials={
  38. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  39. "volc_region": "cn-beijing",
  40. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  41. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  42. "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
  43. "base_model_name": "Doubao-embedding",
  44. },
  45. texts=["hello", "world"],
  46. user="abc-123",
  47. )
  48. assert isinstance(result, TextEmbeddingResult)
  49. assert len(result.embeddings) == 2
  50. assert result.usage.total_tokens > 0
  51. def test_get_num_tokens():
  52. model = VolcengineMaaSTextEmbeddingModel()
  53. num_tokens = model.get_num_tokens(
  54. model="NOT IMPORTANT",
  55. credentials={
  56. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  57. "volc_region": "cn-beijing",
  58. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  59. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  60. "endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
  61. "base_model_name": "Doubao-embedding",
  62. },
  63. texts=["hello", "world"],
  64. )
  65. assert num_tokens == 2