test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.llm.llm import LocalAILanguageModel
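
# These are integration tests: they call a live LocalAI server whose base URL
# is read from the LOCALAI_SERVER_URL environment variable.


# Credential validation should reject an unreachable server_url and accept a real one.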
def test_validate_credentials_for_chat_model():
    model = LocalAILanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="chinese-llama-2-7b",
            credentials={
                "server_url": "hahahaha",
                "completion_type": "completion",
            },
        )

    model.validate_credentials(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
    )
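

# Non-streaming completion invocation returns a single LLMResult with usage populated.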
def test_invoke_completion_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
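

# Same invocation routed through the chat endpoint (completion_type="chat_completion").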
def test_invoke_chat_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
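

# With stream=True, invoke() returns a generator that yields LLMResultChunk objects.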
def test_invoke_stream_completion_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
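

# Streaming variant of the chat completion invocation.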
def test_invoke_stream_chat_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model="chinese-llama-2-7b",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
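

# Token counting over prompt messages, with and without a tool definition attached.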
def test_get_num_tokens():
    model = LocalAILanguageModel()

    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            )
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    num_tokens = model.get_num_tokens(
        model="????",
        credentials={
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 10