# Integration tests for the Fireworks large-language-model provider (test_llm.py).
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. UserPromptMessage,
  10. )
  11. from core.model_runtime.entities.model_entities import AIModelEntity
  12. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  13. from core.model_runtime.model_providers.fireworks.llm.llm import FireworksLargeLanguageModel
  14. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  15. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  16. def test_predefined_models():
  17. model = FireworksLargeLanguageModel()
  18. model_schemas = model.predefined_models()
  19. assert len(model_schemas) >= 1
  20. assert isinstance(model_schemas[0], AIModelEntity)
  21. @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
  22. def test_validate_credentials_for_chat_model(setup_openai_mock):
  23. model = FireworksLargeLanguageModel()
  24. with pytest.raises(CredentialsValidateFailedError):
  25. # model name to gpt-3.5-turbo because of mocking
  26. model.validate_credentials(model="gpt-3.5-turbo", credentials={"fireworks_api_key": "invalid_key"})
  27. model.validate_credentials(
  28. model="accounts/fireworks/models/llama-v3p1-8b-instruct",
  29. credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
  30. )
  31. @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
  32. def test_invoke_chat_model(setup_openai_mock):
  33. model = FireworksLargeLanguageModel()
  34. result = model.invoke(
  35. model="accounts/fireworks/models/llama-v3p1-8b-instruct",
  36. credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
  37. prompt_messages=[
  38. SystemPromptMessage(
  39. content="You are a helpful AI assistant.",
  40. ),
  41. UserPromptMessage(content="Hello World!"),
  42. ],
  43. model_parameters={
  44. "temperature": 0.0,
  45. "top_p": 1.0,
  46. "presence_penalty": 0.0,
  47. "frequency_penalty": 0.0,
  48. "max_tokens": 10,
  49. },
  50. stop=["How"],
  51. stream=False,
  52. user="foo",
  53. )
  54. assert isinstance(result, LLMResult)
  55. assert len(result.message.content) > 0
  56. @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
  57. def test_invoke_chat_model_with_tools(setup_openai_mock):
  58. model = FireworksLargeLanguageModel()
  59. result = model.invoke(
  60. model="accounts/fireworks/models/llama-v3p1-8b-instruct",
  61. credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
  62. prompt_messages=[
  63. SystemPromptMessage(
  64. content="You are a helpful AI assistant.",
  65. ),
  66. UserPromptMessage(
  67. content="what's the weather today in London?",
  68. ),
  69. ],
  70. model_parameters={"temperature": 0.0, "max_tokens": 100},
  71. tools=[
  72. PromptMessageTool(
  73. name="get_weather",
  74. description="Determine weather in my location",
  75. parameters={
  76. "type": "object",
  77. "properties": {
  78. "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
  79. "unit": {"type": "string", "enum": ["c", "f"]},
  80. },
  81. "required": ["location"],
  82. },
  83. ),
  84. PromptMessageTool(
  85. name="get_stock_price",
  86. description="Get the current stock price",
  87. parameters={
  88. "type": "object",
  89. "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
  90. "required": ["symbol"],
  91. },
  92. ),
  93. ],
  94. stream=False,
  95. user="foo",
  96. )
  97. assert isinstance(result, LLMResult)
  98. assert isinstance(result.message, AssistantPromptMessage)
  99. assert len(result.message.tool_calls) > 0
  100. @pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
  101. def test_invoke_stream_chat_model(setup_openai_mock):
  102. model = FireworksLargeLanguageModel()
  103. result = model.invoke(
  104. model="accounts/fireworks/models/llama-v3p1-8b-instruct",
  105. credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
  106. prompt_messages=[
  107. SystemPromptMessage(
  108. content="You are a helpful AI assistant.",
  109. ),
  110. UserPromptMessage(content="Hello World!"),
  111. ],
  112. model_parameters={"temperature": 0.0, "max_tokens": 100},
  113. stream=True,
  114. user="foo",
  115. )
  116. assert isinstance(result, Generator)
  117. for chunk in result:
  118. assert isinstance(chunk, LLMResultChunk)
  119. assert isinstance(chunk.delta, LLMResultChunkDelta)
  120. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  121. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  122. if chunk.delta.finish_reason is not None:
  123. assert chunk.delta.usage is not None
  124. assert chunk.delta.usage.completion_tokens > 0
  125. def test_get_num_tokens():
  126. model = FireworksLargeLanguageModel()
  127. num_tokens = model.get_num_tokens(
  128. model="accounts/fireworks/models/llama-v3p1-8b-instruct",
  129. credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
  130. prompt_messages=[UserPromptMessage(content="Hello World!")],
  131. )
  132. assert num_tokens == 10
  133. num_tokens = model.get_num_tokens(
  134. model="accounts/fireworks/models/llama-v3p1-8b-instruct",
  135. credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")},
  136. prompt_messages=[
  137. SystemPromptMessage(
  138. content="You are a helpful AI assistant.",
  139. ),
  140. UserPromptMessage(content="Hello World!"),
  141. ],
  142. tools=[
  143. PromptMessageTool(
  144. name="get_weather",
  145. description="Determine weather in my location",
  146. parameters={
  147. "type": "object",
  148. "properties": {
  149. "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
  150. "unit": {"type": "string", "enum": ["c", "f"]},
  151. },
  152. "required": ["location"],
  153. },
  154. ),
  155. ],
  156. )
  157. assert num_tokens == 77