1
0

testutil_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. package main
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "testing"
  9. "time"
  10. "github.com/bradleyjkemp/cupaloy/v2"
  11. "github.com/sashabaranov/go-openai"
  12. )
  13. var testSnapshotter = cupaloy.New(cupaloy.SnapshotSubdirectory("testdata/snapshots"))
// testSystemPrompt is the system-role message TestAgent places first in every
// conversation it starts (both ProcessEvent and Run). It is a raw string
// literal, so its line breaks are part of the prompt text sent to the LLM.
const testSystemPrompt = `You are an intelligent agent connected to an ARP (Agent-native ERP) platform via the Model Context Protocol (MCP).
You have access to the following tools:
- introspect: Discover the GraphQL schema structure
- query: Execute GraphQL queries (read operations)
- mutate: Execute GraphQL mutations (create/update/delete operations)
When you receive events (task created, task updated, message added), you should:
1. Understand the event context
2. Take appropriate action using the available tools
3. Respond concisely about what you did
You can query for more information or make changes as needed. Be helpful and proactive.`
// LLMInterface abstracts the single chat-completion call TestAgent makes, so
// tests can inject a scripted implementation such as MockLLM instead of a
// real model client.
type LLMInterface interface {
	// Chat sends the conversation so far, together with the tools the model
	// may call, and returns the model's next message. TestAgent inspects the
	// returned message's ToolCalls to decide whether to run tools and call
	// Chat again.
	Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error)
}
  29. // MockLLM is a mock implementation of the LLM for testing
  30. type MockLLM struct {
  31. responses []*openai.ChatCompletionMessage
  32. callCount int
  33. }
  34. // NewMockLLM creates a new mock LLM with predefined responses
  35. func NewMockLLM(responses []*openai.ChatCompletionMessage) *MockLLM {
  36. return &MockLLM{responses: responses}
  37. }
  38. // Chat implements the LLMInterface
  39. func (m *MockLLM) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error) {
  40. if m.callCount >= len(m.responses) {
  41. return &openai.ChatCompletionMessage{
  42. Role: openai.ChatMessageRoleAssistant,
  43. Content: "No more responses available",
  44. }, nil
  45. }
  46. response := m.responses[m.callCount]
  47. m.callCount++
  48. return response, nil
  49. }
  50. // MockMCPClient is a mock implementation of the MCPClient for testing
  51. type MockMCPClient struct {
  52. tools []Tool
  53. toolResults map[string]*CallToolResult
  54. }
  55. // NewMockMCPClient creates a new mock MCP client
  56. func NewMockMCPClient(tools []Tool) *MockMCPClient {
  57. return &MockMCPClient{
  58. tools: tools,
  59. toolResults: make(map[string]*CallToolResult),
  60. }
  61. }
  62. // SetToolResult sets the result for a specific tool call
  63. func (m *MockMCPClient) SetToolResult(name string, result *CallToolResult) {
  64. m.toolResults[name] = result
  65. }
  66. // ListTools implements MCPClientInterface
  67. func (m *MockMCPClient) ListTools() ([]Tool, error) {
  68. return m.tools, nil
  69. }
  70. // CallTool implements MCPClientInterface
  71. func (m *MockMCPClient) CallTool(name string, args map[string]interface{}) (*CallToolResult, error) {
  72. if result, ok := m.toolResults[name]; ok {
  73. return result, nil
  74. }
  75. return &CallToolResult{
  76. Content: []ContentBlock{{Type: "text", Text: "mock result"}},
  77. }, nil
  78. }
// TestAgent is an agent built on interfaces, so tests can inject MockLLM and
// MockMCPClient. It can process an event directly (ProcessEvent), run one
// interactive exchange (Run), or queue events for a background worker
// (SetupQueues, QueueEvent, Start, Stop).
//
// It contains a sync.WaitGroup, so it must not be copied; always use *TestAgent.
type TestAgent struct {
	llm       LLMInterface
	mcpClient MCPClientInterface
	tools     []openai.Tool // set by Initialize from the MCP client's tool list
	// Event queues, created by SetupQueues. QueueEvent routes URIs containing
	// "messageAdded" (and no task keyword) to messageQueue, and everything else
	// to taskQueue.
	taskQueue    *EventQueue
	messageQueue *EventQueue
	// Queue control: Start sets ctx/cancel and registers processQueues in wg;
	// Stop calls cancel and waits on wg.
	// NOTE(review): Go convention discourages storing a context in a struct;
	// here it lets the worker goroutine observe cancellation from Stop.
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup
}
  92. // NewTestAgent creates a new test agent with interfaces
  93. func NewTestAgent(llm LLMInterface, mcpClient MCPClientInterface) *TestAgent {
  94. return &TestAgent{
  95. llm: llm,
  96. mcpClient: mcpClient,
  97. }
  98. }
  99. // SetupQueues initializes the event queues with the given capacity
  100. func (a *TestAgent) SetupQueues(maxQueueSize int) {
  101. a.taskQueue = NewEventQueue("task", maxQueueSize)
  102. a.messageQueue = NewEventQueue("message", maxQueueSize)
  103. }
  104. // QueueEvent adds an event to the appropriate queue based on its URI
  105. func (a *TestAgent) QueueEvent(uri string, data json.RawMessage) {
  106. event := &QueuedEvent{
  107. URI: uri,
  108. Data: data,
  109. Timestamp: time.Now(),
  110. }
  111. var queue *EventQueue
  112. if strings.Contains(uri, "taskCreated") || strings.Contains(uri, "taskUpdated") || strings.Contains(uri, "taskDeleted") {
  113. queue = a.taskQueue
  114. } else if strings.Contains(uri, "messageAdded") {
  115. queue = a.messageQueue
  116. } else {
  117. queue = a.taskQueue
  118. }
  119. queue.TryEnqueue(event)
  120. }
  121. // Start begins processing events from the queues
  122. func (a *TestAgent) Start(ctx context.Context) {
  123. a.ctx, a.cancel = context.WithCancel(ctx)
  124. a.wg.Add(1)
  125. go a.processQueues()
  126. }
  127. // Stop gracefully stops the queue processor
  128. func (a *TestAgent) Stop() {
  129. if a.cancel != nil {
  130. a.cancel()
  131. }
  132. a.wg.Wait()
  133. }
  134. // processQueues is the main worker that processes events from both queues
  135. func (a *TestAgent) processQueues() {
  136. defer a.wg.Done()
  137. for {
  138. select {
  139. case <-a.ctx.Done():
  140. return
  141. case event := <-a.taskQueue.Channel():
  142. a.ProcessEvent(a.ctx, event.URI, event.Data)
  143. case event := <-a.messageQueue.Channel():
  144. a.ProcessEvent(a.ctx, event.URI, event.Data)
  145. }
  146. }
  147. }
  148. // GetQueueStats returns current queue statistics
  149. func (a *TestAgent) GetQueueStats() QueueStats {
  150. return QueueStats{
  151. TaskQueueSize: a.taskQueue.Len(),
  152. MessageQueueSize: a.messageQueue.Len(),
  153. }
  154. }
  155. // Initialize initializes the test agent
  156. func (a *TestAgent) Initialize() error {
  157. mcpTools, err := a.mcpClient.ListTools()
  158. if err != nil {
  159. return err
  160. }
  161. a.tools = ConvertMCPToolsToOpenAI(mcpTools)
  162. return nil
  163. }
  164. // ProcessEvent processes an event
  165. func (a *TestAgent) ProcessEvent(ctx context.Context, uri string, eventData json.RawMessage) error {
  166. prompt := buildTestEventPrompt(uri, eventData)
  167. messages := []openai.ChatCompletionMessage{
  168. {Role: openai.ChatMessageRoleSystem, Content: testSystemPrompt},
  169. {Role: openai.ChatMessageRoleUser, Content: prompt},
  170. }
  171. return a.processWithTools(ctx, messages)
  172. }
  173. // Run runs the agent interactively
  174. func (a *TestAgent) Run(ctx context.Context, userMessage string) (string, error) {
  175. messages := []openai.ChatCompletionMessage{
  176. {Role: openai.ChatMessageRoleSystem, Content: testSystemPrompt},
  177. {Role: openai.ChatMessageRoleUser, Content: userMessage},
  178. }
  179. response, err := a.llm.Chat(ctx, messages, a.tools)
  180. if err != nil {
  181. return "", err
  182. }
  183. if len(response.ToolCalls) == 0 {
  184. return response.Content, nil
  185. }
  186. // Handle tool calls
  187. messages = append(messages, *response)
  188. for _, toolCall := range response.ToolCalls {
  189. name, args, err := ParseToolCall(toolCall)
  190. if err != nil {
  191. continue
  192. }
  193. result, err := a.mcpClient.CallTool(name, args)
  194. if err != nil {
  195. continue
  196. }
  197. messages = append(messages, openai.ChatCompletionMessage{
  198. Role: openai.ChatMessageRoleTool,
  199. ToolCallID: toolCall.ID,
  200. Content: extractTextFromResult(result),
  201. })
  202. }
  203. finalResponse, err := a.llm.Chat(ctx, messages, a.tools)
  204. if err != nil {
  205. return "", err
  206. }
  207. return finalResponse.Content, nil
  208. }
  209. func (a *TestAgent) processWithTools(ctx context.Context, messages []openai.ChatCompletionMessage) error {
  210. response, err := a.llm.Chat(ctx, messages, a.tools)
  211. if err != nil {
  212. return err
  213. }
  214. if len(response.ToolCalls) == 0 {
  215. return nil
  216. }
  217. messages = append(messages, *response)
  218. for _, toolCall := range response.ToolCalls {
  219. name, args, err := ParseToolCall(toolCall)
  220. if err != nil {
  221. continue
  222. }
  223. result, err := a.mcpClient.CallTool(name, args)
  224. if err != nil {
  225. continue
  226. }
  227. messages = append(messages, openai.ChatCompletionMessage{
  228. Role: openai.ChatMessageRoleTool,
  229. ToolCallID: toolCall.ID,
  230. Content: extractTextFromResult(result),
  231. })
  232. }
  233. return a.processWithTools(ctx, messages)
  234. }
  235. // buildTestEventPrompt builds a prompt from the event data
  236. func buildTestEventPrompt(uri string, eventData json.RawMessage) string {
  237. eventType := "unknown"
  238. if strings.Contains(uri, "taskCreated") {
  239. eventType = "task created"
  240. } else if strings.Contains(uri, "taskUpdated") {
  241. eventType = "task updated"
  242. } else if strings.Contains(uri, "taskDeleted") {
  243. eventType = "task deleted"
  244. } else if strings.Contains(uri, "messageAdded") {
  245. eventType = "message added"
  246. }
  247. eventStr := "{}"
  248. if len(eventData) > 0 {
  249. eventStr = string(eventData)
  250. }
  251. return fmt.Sprintf(`A %s event was received.
  252. Event URI: %s
  253. Event Data: %s
  254. Please process this event appropriately.`, eventType, uri, eventStr)
  255. }
  256. // testNormalizeJSON normalizes JSON for snapshot comparison
  257. func testNormalizeJSON(jsonStr string) string {
  258. var data interface{}
  259. if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
  260. return jsonStr
  261. }
  262. testNormalizeData(data)
  263. bytes, _ := json.MarshalIndent(data, "", " ")
  264. return string(bytes)
  265. }
  266. // testNormalizeData recursively normalizes data structures
  267. func testNormalizeData(data interface{}) {
  268. switch v := data.(type) {
  269. case map[string]interface{}:
  270. delete(v, "id")
  271. delete(v, "ID")
  272. delete(v, "createdAt")
  273. delete(v, "updatedAt")
  274. delete(v, "sentAt")
  275. delete(v, "createdByID")
  276. delete(v, "userId")
  277. delete(v, "serviceId")
  278. delete(v, "statusId")
  279. delete(v, "assigneeId")
  280. delete(v, "conversationId")
  281. delete(v, "senderId")
  282. delete(v, "password")
  283. for _, val := range v {
  284. testNormalizeData(val)
  285. }
  286. case []interface{}:
  287. for _, item := range v {
  288. testNormalizeData(item)
  289. }
  290. }
  291. }
  292. // testSnapshotResult captures a snapshot of the result
  293. func testSnapshotResult(t *testing.T, name string, response interface{}) {
  294. jsonBytes, _ := json.MarshalIndent(response, "", " ")
  295. normalized := testNormalizeJSON(string(jsonBytes))
  296. testSnapshotter.SnapshotT(t, name, normalized)
  297. }