test_text_embedding.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel
  6. def test_validate_credentials():
  7. model = BaichuanTextEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})
  10. model.validate_credentials(
  11. model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
  12. )
  13. def test_invoke_model():
  14. model = BaichuanTextEmbeddingModel()
  15. result = model.invoke(
  16. model="baichuan-text-embedding",
  17. credentials={
  18. "api_key": os.environ.get("BAICHUAN_API_KEY"),
  19. },
  20. texts=["hello", "world"],
  21. user="abc-123",
  22. )
  23. assert isinstance(result, TextEmbeddingResult)
  24. assert len(result.embeddings) == 2
  25. assert result.usage.total_tokens == 6
  26. def test_get_num_tokens():
  27. model = BaichuanTextEmbeddingModel()
  28. num_tokens = model.get_num_tokens(
  29. model="baichuan-text-embedding",
  30. credentials={
  31. "api_key": os.environ.get("BAICHUAN_API_KEY"),
  32. },
  33. texts=["hello", "world"],
  34. )
  35. assert num_tokens == 2
  36. def test_max_chunks():
  37. model = BaichuanTextEmbeddingModel()
  38. result = model.invoke(
  39. model="baichuan-text-embedding",
  40. credentials={
  41. "api_key": os.environ.get("BAICHUAN_API_KEY"),
  42. },
  43. texts=[
  44. "hello",
  45. "world",
  46. "hello",
  47. "world",
  48. "hello",
  49. "world",
  50. "hello",
  51. "world",
  52. "hello",
  53. "world",
  54. "hello",
  55. "world",
  56. "hello",
  57. "world",
  58. "hello",
  59. "world",
  60. "hello",
  61. "world",
  62. "hello",
  63. "world",
  64. "hello",
  65. "world",
  66. ],
  67. )
  68. assert isinstance(result, TextEmbeddingResult)
  69. assert len(result.embeddings) == 22