testutil_test.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. package main
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "sync"
  8. "testing"
  9. "time"
  10. "github.com/bradleyjkemp/cupaloy/v2"
  11. "github.com/sashabaranov/go-openai"
  12. )
  13. var testSnapshotter = cupaloy.New(cupaloy.SnapshotSubdirectory("testdata/snapshots"))
// testSystemPrompt is the system message placed first in every conversation a
// TestAgent starts (see ProcessEvent and Run). It tells the model which MCP
// tools exist (introspect, query, mutate) and how to react to task/message
// events. It is runtime text sent to the LLM: edit it deliberately.
const testSystemPrompt = `You are an intelligent agent connected to an ARP (Agent-native ERP) platform via the Model Context Protocol (MCP).
You have access to the following tools:
- introspect: Discover the GraphQL schema structure
- query: Execute GraphQL queries (read operations)
- mutate: Execute GraphQL mutations (create/update/delete operations)
When you receive events (task created, task updated, message added), you should:
1. Understand the event context
2. Take appropriate action using the available tools
3. Respond concisely about what you did
You can query for more information or make changes as needed. Be helpful and proactive.`
  25. // LLMInterface defines the interface for LLM operations
  26. type LLMInterface interface {
  27. Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error)
  28. }
  29. // MCPClientInterface defines the interface for MCP client operations
  30. type MCPClientInterface interface {
  31. ListTools() ([]Tool, error)
  32. CallTool(name string, args map[string]interface{}) (*CallToolResult, error)
  33. }
  34. // MockLLM is a mock implementation of the LLM for testing
  35. type MockLLM struct {
  36. responses []*openai.ChatCompletionMessage
  37. callCount int
  38. }
  39. // NewMockLLM creates a new mock LLM with predefined responses
  40. func NewMockLLM(responses []*openai.ChatCompletionMessage) *MockLLM {
  41. return &MockLLM{responses: responses}
  42. }
  43. // Chat implements the LLMInterface
  44. func (m *MockLLM) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error) {
  45. if m.callCount >= len(m.responses) {
  46. return &openai.ChatCompletionMessage{
  47. Role: openai.ChatMessageRoleAssistant,
  48. Content: "No more responses available",
  49. }, nil
  50. }
  51. response := m.responses[m.callCount]
  52. m.callCount++
  53. return response, nil
  54. }
  55. // MockMCPClient is a mock implementation of the MCPClient for testing
  56. type MockMCPClient struct {
  57. tools []Tool
  58. toolResults map[string]*CallToolResult
  59. }
  60. // NewMockMCPClient creates a new mock MCP client
  61. func NewMockMCPClient(tools []Tool) *MockMCPClient {
  62. return &MockMCPClient{
  63. tools: tools,
  64. toolResults: make(map[string]*CallToolResult),
  65. }
  66. }
  67. // SetToolResult sets the result for a specific tool call
  68. func (m *MockMCPClient) SetToolResult(name string, result *CallToolResult) {
  69. m.toolResults[name] = result
  70. }
  71. // ListTools implements MCPClientInterface
  72. func (m *MockMCPClient) ListTools() ([]Tool, error) {
  73. return m.tools, nil
  74. }
  75. // CallTool implements MCPClientInterface
  76. func (m *MockMCPClient) CallTool(name string, args map[string]interface{}) (*CallToolResult, error) {
  77. if result, ok := m.toolResults[name]; ok {
  78. return result, nil
  79. }
  80. return &CallToolResult{
  81. Content: []ContentBlock{{Type: "text", Text: "mock result"}},
  82. }, nil
  83. }
  84. // TestAgent is an agent that uses interfaces for testing
  85. type TestAgent struct {
  86. llm LLMInterface
  87. mcpClient MCPClientInterface
  88. tools []openai.Tool
  89. // Event queues
  90. taskQueue *EventQueue
  91. messageQueue *EventQueue
  92. // Queue control
  93. ctx context.Context
  94. cancel context.CancelFunc
  95. wg sync.WaitGroup
  96. }
  97. // NewTestAgent creates a new test agent with interfaces
  98. func NewTestAgent(llm LLMInterface, mcpClient MCPClientInterface) *TestAgent {
  99. return &TestAgent{
  100. llm: llm,
  101. mcpClient: mcpClient,
  102. }
  103. }
  104. // SetupQueues initializes the event queues with the given capacity
  105. func (a *TestAgent) SetupQueues(maxQueueSize int) {
  106. a.taskQueue = NewEventQueue("task", maxQueueSize)
  107. a.messageQueue = NewEventQueue("message", maxQueueSize)
  108. }
  109. // QueueEvent adds an event to the appropriate queue based on its URI
  110. func (a *TestAgent) QueueEvent(uri string, data json.RawMessage) {
  111. event := &QueuedEvent{
  112. URI: uri,
  113. Data: data,
  114. Timestamp: time.Now(),
  115. }
  116. var queue *EventQueue
  117. if strings.Contains(uri, "taskCreated") || strings.Contains(uri, "taskUpdated") || strings.Contains(uri, "taskDeleted") {
  118. queue = a.taskQueue
  119. } else if strings.Contains(uri, "messageAdded") {
  120. queue = a.messageQueue
  121. } else {
  122. queue = a.taskQueue
  123. }
  124. queue.TryEnqueue(event)
  125. }
  126. // Start begins processing events from the queues
  127. func (a *TestAgent) Start(ctx context.Context) {
  128. a.ctx, a.cancel = context.WithCancel(ctx)
  129. a.wg.Add(1)
  130. go a.processQueues()
  131. }
  132. // Stop gracefully stops the queue processor
  133. func (a *TestAgent) Stop() {
  134. if a.cancel != nil {
  135. a.cancel()
  136. }
  137. a.wg.Wait()
  138. }
  139. // processQueues is the main worker that processes events from both queues
  140. func (a *TestAgent) processQueues() {
  141. defer a.wg.Done()
  142. for {
  143. select {
  144. case <-a.ctx.Done():
  145. return
  146. case event := <-a.taskQueue.Channel():
  147. a.ProcessEvent(a.ctx, event.URI, event.Data)
  148. case event := <-a.messageQueue.Channel():
  149. a.ProcessEvent(a.ctx, event.URI, event.Data)
  150. }
  151. }
  152. }
  153. // GetQueueStats returns current queue statistics
  154. func (a *TestAgent) GetQueueStats() QueueStats {
  155. return QueueStats{
  156. TaskQueueSize: a.taskQueue.Len(),
  157. MessageQueueSize: a.messageQueue.Len(),
  158. }
  159. }
  160. // Initialize initializes the test agent
  161. func (a *TestAgent) Initialize() error {
  162. mcpTools, err := a.mcpClient.ListTools()
  163. if err != nil {
  164. return err
  165. }
  166. a.tools = ConvertMCPToolsToOpenAI(mcpTools)
  167. return nil
  168. }
  169. // ProcessEvent processes an event
  170. func (a *TestAgent) ProcessEvent(ctx context.Context, uri string, eventData json.RawMessage) error {
  171. prompt := buildTestEventPrompt(uri, eventData)
  172. messages := []openai.ChatCompletionMessage{
  173. {Role: openai.ChatMessageRoleSystem, Content: testSystemPrompt},
  174. {Role: openai.ChatMessageRoleUser, Content: prompt},
  175. }
  176. return a.processWithTools(ctx, messages)
  177. }
  178. // Run runs the agent interactively
  179. func (a *TestAgent) Run(ctx context.Context, userMessage string) (string, error) {
  180. messages := []openai.ChatCompletionMessage{
  181. {Role: openai.ChatMessageRoleSystem, Content: testSystemPrompt},
  182. {Role: openai.ChatMessageRoleUser, Content: userMessage},
  183. }
  184. response, err := a.llm.Chat(ctx, messages, a.tools)
  185. if err != nil {
  186. return "", err
  187. }
  188. if len(response.ToolCalls) == 0 {
  189. return response.Content, nil
  190. }
  191. // Handle tool calls
  192. messages = append(messages, *response)
  193. for _, toolCall := range response.ToolCalls {
  194. name, args, err := ParseToolCall(toolCall)
  195. if err != nil {
  196. continue
  197. }
  198. result, err := a.mcpClient.CallTool(name, args)
  199. if err != nil {
  200. continue
  201. }
  202. messages = append(messages, openai.ChatCompletionMessage{
  203. Role: openai.ChatMessageRoleTool,
  204. ToolCallID: toolCall.ID,
  205. Content: extractTextFromResult(result),
  206. })
  207. }
  208. finalResponse, err := a.llm.Chat(ctx, messages, a.tools)
  209. if err != nil {
  210. return "", err
  211. }
  212. return finalResponse.Content, nil
  213. }
  214. func (a *TestAgent) processWithTools(ctx context.Context, messages []openai.ChatCompletionMessage) error {
  215. response, err := a.llm.Chat(ctx, messages, a.tools)
  216. if err != nil {
  217. return err
  218. }
  219. if len(response.ToolCalls) == 0 {
  220. return nil
  221. }
  222. messages = append(messages, *response)
  223. for _, toolCall := range response.ToolCalls {
  224. name, args, err := ParseToolCall(toolCall)
  225. if err != nil {
  226. continue
  227. }
  228. result, err := a.mcpClient.CallTool(name, args)
  229. if err != nil {
  230. continue
  231. }
  232. messages = append(messages, openai.ChatCompletionMessage{
  233. Role: openai.ChatMessageRoleTool,
  234. ToolCallID: toolCall.ID,
  235. Content: extractTextFromResult(result),
  236. })
  237. }
  238. return a.processWithTools(ctx, messages)
  239. }
  240. // buildTestEventPrompt builds a prompt from the event data
  241. func buildTestEventPrompt(uri string, eventData json.RawMessage) string {
  242. eventType := "unknown"
  243. if strings.Contains(uri, "taskCreated") {
  244. eventType = "task created"
  245. } else if strings.Contains(uri, "taskUpdated") {
  246. eventType = "task updated"
  247. } else if strings.Contains(uri, "taskDeleted") {
  248. eventType = "task deleted"
  249. } else if strings.Contains(uri, "messageAdded") {
  250. eventType = "message added"
  251. }
  252. eventStr := "{}"
  253. if len(eventData) > 0 {
  254. eventStr = string(eventData)
  255. }
  256. return fmt.Sprintf(`A %s event was received.
  257. Event URI: %s
  258. Event Data: %s
  259. Please process this event appropriately.`, eventType, uri, eventStr)
  260. }
  261. // testNormalizeJSON normalizes JSON for snapshot comparison
  262. func testNormalizeJSON(jsonStr string) string {
  263. var data interface{}
  264. if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
  265. return jsonStr
  266. }
  267. testNormalizeData(data)
  268. bytes, _ := json.MarshalIndent(data, "", " ")
  269. return string(bytes)
  270. }
  271. // testNormalizeData recursively normalizes data structures
  272. func testNormalizeData(data interface{}) {
  273. switch v := data.(type) {
  274. case map[string]interface{}:
  275. delete(v, "id")
  276. delete(v, "ID")
  277. delete(v, "createdAt")
  278. delete(v, "updatedAt")
  279. delete(v, "sentAt")
  280. delete(v, "createdByID")
  281. delete(v, "userId")
  282. delete(v, "serviceId")
  283. delete(v, "statusId")
  284. delete(v, "assigneeId")
  285. delete(v, "conversationId")
  286. delete(v, "senderId")
  287. delete(v, "password")
  288. for _, val := range v {
  289. testNormalizeData(val)
  290. }
  291. case []interface{}:
  292. for _, item := range v {
  293. testNormalizeData(item)
  294. }
  295. }
  296. }
  297. // testSnapshotResult captures a snapshot of the result
  298. func testSnapshotResult(t *testing.T, name string, response interface{}) {
  299. jsonBytes, _ := json.MarshalIndent(response, "", " ")
  300. normalized := testNormalizeJSON(string(jsonBytes))
  301. testSnapshotter.SnapshotT(t, name, normalized)
  302. }