agent_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. package main
  2. import (
  3. "context"
  4. "encoding/json"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/sashabaranov/go-openai"
  9. )
  10. // TestAgent_Initialize tests agent initialization
  11. func TestAgent_Initialize(t *testing.T) {
  12. mockMCP := NewMockMCPClient([]Tool{
  13. {
  14. Name: "introspect",
  15. Description: "Discover the GraphQL schema",
  16. InputSchema: InputSchema{
  17. Type: "object",
  18. Properties: map[string]Property{},
  19. AdditionalProperties: false,
  20. },
  21. },
  22. {
  23. Name: "query",
  24. Description: "Execute a GraphQL query",
  25. InputSchema: InputSchema{
  26. Type: "object",
  27. Properties: map[string]Property{
  28. "query": {Type: "string", Description: "The GraphQL query"},
  29. },
  30. Required: []string{"query"},
  31. AdditionalProperties: false,
  32. },
  33. },
  34. })
  35. mockLLM := NewMockLLM(nil)
  36. agent := NewTestAgent(mockLLM, mockMCP)
  37. err := agent.Initialize()
  38. if err != nil {
  39. t.Fatalf("Initialize failed: %v", err)
  40. }
  41. if len(agent.tools) != 2 {
  42. t.Errorf("Expected 2 tools, got %d", len(agent.tools))
  43. }
  44. // Verify tools were converted correctly
  45. toolNames := make([]string, len(agent.tools))
  46. for i, tool := range agent.tools {
  47. toolNames[i] = tool.Function.Name
  48. }
  49. expectedNames := []string{"introspect", "query"}
  50. for i, expected := range expectedNames {
  51. if toolNames[i] != expected {
  52. t.Errorf("Tool %d: expected name %s, got %s", i, expected, toolNames[i])
  53. }
  54. }
  55. }
  56. // TestAgent_ProcessEvent tests event processing
  57. func TestAgent_ProcessEvent(t *testing.T) {
  58. ctx := context.Background()
  59. t.Run("TaskCreatedEvent", func(t *testing.T) {
  60. mockMCP := NewMockMCPClient([]Tool{
  61. {Name: "query", Description: "Execute a GraphQL query", InputSchema: InputSchema{Type: "object"}},
  62. })
  63. mockMCP.SetToolResult("query", &CallToolResult{
  64. Content: []ContentBlock{{Type: "text", Text: `{"data": {"tasks": []}}`}},
  65. })
  66. // Mock LLM that makes a tool call then responds
  67. mockLLM := NewMockLLM([]*openai.ChatCompletionMessage{
  68. {
  69. Role: openai.ChatMessageRoleAssistant,
  70. ToolCalls: []openai.ToolCall{
  71. {
  72. ID: "call-1",
  73. Function: openai.FunctionCall{
  74. Name: "query",
  75. Arguments: `{"query": "{ tasks { id title } }"}`,
  76. },
  77. },
  78. },
  79. },
  80. {
  81. Role: openai.ChatMessageRoleAssistant,
  82. Content: "I've processed the task created event.",
  83. },
  84. })
  85. agent := NewTestAgent(mockLLM, mockMCP)
  86. agent.Initialize()
  87. eventData := json.RawMessage(`{"taskId": "task-123", "title": "New Task"}`)
  88. err := agent.ProcessEvent(ctx, "graphql://subscription/taskCreated", eventData)
  89. if err != nil {
  90. t.Errorf("ProcessEvent failed: %v", err)
  91. }
  92. })
  93. t.Run("MessageAddedEvent", func(t *testing.T) {
  94. mockMCP := NewMockMCPClient([]Tool{
  95. {Name: "query", Description: "Execute a GraphQL query", InputSchema: InputSchema{Type: "object"}},
  96. })
  97. // Mock LLM that responds directly without tool calls
  98. mockLLM := NewMockLLM([]*openai.ChatCompletionMessage{
  99. {
  100. Role: openai.ChatMessageRoleAssistant,
  101. Content: "I received the message added event.",
  102. },
  103. })
  104. agent := NewTestAgent(mockLLM, mockMCP)
  105. agent.Initialize()
  106. eventData := json.RawMessage(`{"messageId": "msg-456", "content": "Hello!"}`)
  107. err := agent.ProcessEvent(ctx, "graphql://subscription/messageAdded", eventData)
  108. if err != nil {
  109. t.Errorf("ProcessEvent failed: %v", err)
  110. }
  111. })
  112. }
  113. // TestAgent_Run tests the interactive Run method
  114. func TestAgent_Run(t *testing.T) {
  115. ctx := context.Background()
  116. t.Run("SimpleResponse", func(t *testing.T) {
  117. mockMCP := NewMockMCPClient([]Tool{})
  118. mockLLM := NewMockLLM([]*openai.ChatCompletionMessage{
  119. {
  120. Role: openai.ChatMessageRoleAssistant,
  121. Content: "Hello! How can I help you?",
  122. },
  123. })
  124. agent := NewTestAgent(mockLLM, mockMCP)
  125. agent.Initialize()
  126. response, err := agent.Run(ctx, "Hello")
  127. if err != nil {
  128. t.Fatalf("Run failed: %v", err)
  129. }
  130. if response != "Hello! How can I help you?" {
  131. t.Errorf("Expected 'Hello! How can I help you?', got '%s'", response)
  132. }
  133. })
  134. t.Run("WithToolCall", func(t *testing.T) {
  135. mockMCP := NewMockMCPClient([]Tool{
  136. {Name: "introspect", Description: "Introspect schema", InputSchema: InputSchema{Type: "object"}},
  137. })
  138. mockMCP.SetToolResult("introspect", &CallToolResult{
  139. Content: []ContentBlock{{Type: "text", Text: "Schema: Query, Mutation, Subscription"}},
  140. })
  141. mockLLM := NewMockLLM([]*openai.ChatCompletionMessage{
  142. {
  143. Role: openai.ChatMessageRoleAssistant,
  144. ToolCalls: []openai.ToolCall{
  145. {
  146. ID: "call-1",
  147. Function: openai.FunctionCall{
  148. Name: "introspect",
  149. Arguments: `{}`,
  150. },
  151. },
  152. },
  153. },
  154. {
  155. Role: openai.ChatMessageRoleAssistant,
  156. Content: "The schema has Query, Mutation, and Subscription types.",
  157. },
  158. })
  159. agent := NewTestAgent(mockLLM, mockMCP)
  160. agent.Initialize()
  161. response, err := agent.Run(ctx, "What types are in the schema?")
  162. if err != nil {
  163. t.Fatalf("Run failed: %v", err)
  164. }
  165. if response != "The schema has Query, Mutation, and Subscription types." {
  166. t.Errorf("Unexpected response: %s", response)
  167. }
  168. })
  169. }
  170. // TestAgent_BuildEventPrompt tests event prompt building
  171. func TestAgent_BuildEventPrompt(t *testing.T) {
  172. tests := []struct {
  173. name string
  174. uri string
  175. eventData json.RawMessage
  176. wantType string
  177. }{
  178. {
  179. name: "TaskCreated",
  180. uri: "graphql://subscription/taskCreated",
  181. eventData: json.RawMessage(`{"id": "1"}`),
  182. wantType: "task created",
  183. },
  184. {
  185. name: "TaskUpdated",
  186. uri: "graphql://subscription/taskUpdated",
  187. eventData: json.RawMessage(`{"id": "2"}`),
  188. wantType: "task updated",
  189. },
  190. {
  191. name: "TaskDeleted",
  192. uri: "graphql://subscription/taskDeleted",
  193. eventData: json.RawMessage(`{"id": "3"}`),
  194. wantType: "task deleted",
  195. },
  196. {
  197. name: "MessageAdded",
  198. uri: "graphql://subscription/messageAdded",
  199. eventData: json.RawMessage(`{"id": "4"}`),
  200. wantType: "message added",
  201. },
  202. {
  203. name: "UnknownEvent",
  204. uri: "graphql://subscription/unknown",
  205. eventData: json.RawMessage(`{}`),
  206. wantType: "unknown",
  207. },
  208. }
  209. for _, tt := range tests {
  210. t.Run(tt.name, func(t *testing.T) {
  211. prompt := buildTestEventPrompt(tt.uri, tt.eventData)
  212. if !strings.Contains(prompt, tt.wantType) {
  213. t.Errorf("Expected prompt to contain '%s', got: %s", tt.wantType, prompt)
  214. }
  215. if !strings.Contains(prompt, tt.uri) {
  216. t.Errorf("Expected prompt to contain URI '%s', got: %s", tt.uri, prompt)
  217. }
  218. })
  219. }
  220. }
  221. // TestAgent_ToolNames tests the toolNames helper function
  222. func TestAgent_ToolNames(t *testing.T) {
  223. tools := []Tool{
  224. {Name: "introspect"},
  225. {Name: "query"},
  226. {Name: "mutate"},
  227. }
  228. names := toolNames(tools)
  229. expected := []string{"introspect", "query", "mutate"}
  230. if len(names) != len(expected) {
  231. t.Errorf("Expected %d names, got %d", len(expected), len(names))
  232. }
  233. for i, name := range names {
  234. if name != expected[i] {
  235. t.Errorf("Name %d: expected %s, got %s", i, expected[i], name)
  236. }
  237. }
  238. }
  239. // TestEventQueue tests the EventQueue operations
  240. func TestEventQueue(t *testing.T) {
  241. t.Run("TryEnqueueSuccess", func(t *testing.T) {
  242. queue := NewEventQueue("test", 10)
  243. event := &QueuedEvent{
  244. URI: "test://uri",
  245. Data: json.RawMessage(`{"test": "data"}`),
  246. Timestamp: time.Now(),
  247. }
  248. success := queue.TryEnqueue(event)
  249. if !success {
  250. t.Error("Expected TryEnqueue to succeed")
  251. }
  252. if queue.Len() != 1 {
  253. t.Errorf("Expected queue length 1, got %d", queue.Len())
  254. }
  255. })
  256. t.Run("TryEnqueueFullQueue", func(t *testing.T) {
  257. queue := NewEventQueue("test", 2)
  258. // Fill the queue
  259. for i := 0; i < 2; i++ {
  260. success := queue.TryEnqueue(&QueuedEvent{URI: "test://uri"})
  261. if !success {
  262. t.Errorf("Expected TryEnqueue %d to succeed", i)
  263. }
  264. }
  265. // Try to add one more - should fail
  266. success := queue.TryEnqueue(&QueuedEvent{URI: "test://overflow"})
  267. if success {
  268. t.Error("Expected TryEnqueue to fail on full queue")
  269. }
  270. if queue.Len() != 2 {
  271. t.Errorf("Expected queue length 2, got %d", queue.Len())
  272. }
  273. })
  274. t.Run("Dequeue", func(t *testing.T) {
  275. queue := NewEventQueue("test", 10)
  276. event1 := &QueuedEvent{URI: "test://uri1"}
  277. event2 := &QueuedEvent{URI: "test://uri2"}
  278. queue.TryEnqueue(event1)
  279. queue.TryEnqueue(event2)
  280. // Dequeue should return events in FIFO order
  281. dequeued1 := queue.Dequeue()
  282. if dequeued1.URI != "test://uri1" {
  283. t.Errorf("Expected URI 'test://uri1', got '%s'", dequeued1.URI)
  284. }
  285. dequeued2 := queue.Dequeue()
  286. if dequeued2.URI != "test://uri2" {
  287. t.Errorf("Expected URI 'test://uri2', got '%s'", dequeued2.URI)
  288. }
  289. })
  290. t.Run("Len", func(t *testing.T) {
  291. queue := NewEventQueue("test", 10)
  292. if queue.Len() != 0 {
  293. t.Errorf("Expected empty queue to have length 0, got %d", queue.Len())
  294. }
  295. queue.TryEnqueue(&QueuedEvent{URI: "test://uri1"})
  296. if queue.Len() != 1 {
  297. t.Errorf("Expected queue length 1, got %d", queue.Len())
  298. }
  299. queue.TryEnqueue(&QueuedEvent{URI: "test://uri2"})
  300. if queue.Len() != 2 {
  301. t.Errorf("Expected queue length 2, got %d", queue.Len())
  302. }
  303. })
  304. }
  305. // TestAgent_QueueEvent tests event routing to queues
  306. func TestAgent_QueueEvent(t *testing.T) {
  307. mockMCP := NewMockMCPClient([]Tool{})
  308. mockLLM := NewMockLLM(nil)
  309. agent := NewTestAgent(mockLLM, mockMCP)
  310. agent.SetupQueues(10)
  311. tests := []struct {
  312. name string
  313. uri string
  314. expectedTask int
  315. expectedMsg int
  316. }{
  317. {
  318. name: "TaskCreated",
  319. uri: "graphql://subscription/taskCreated",
  320. expectedTask: 1,
  321. expectedMsg: 0,
  322. },
  323. {
  324. name: "TaskUpdated",
  325. uri: "graphql://subscription/taskUpdated",
  326. expectedTask: 1,
  327. expectedMsg: 0,
  328. },
  329. {
  330. name: "TaskDeleted",
  331. uri: "graphql://subscription/taskDeleted",
  332. expectedTask: 1,
  333. expectedMsg: 0,
  334. },
  335. {
  336. name: "MessageAdded",
  337. uri: "graphql://subscription/messageAdded",
  338. expectedTask: 0,
  339. expectedMsg: 1,
  340. },
  341. {
  342. name: "UnknownEvent",
  343. uri: "graphql://subscription/unknown",
  344. expectedTask: 1, // Unknown events go to task queue
  345. expectedMsg: 0,
  346. },
  347. }
  348. for _, tt := range tests {
  349. t.Run(tt.name, func(t *testing.T) {
  350. // Reset queues
  351. agent.SetupQueues(10)
  352. agent.QueueEvent(tt.uri, json.RawMessage(`{}`))
  353. stats := agent.GetQueueStats()
  354. if stats.TaskQueueSize != tt.expectedTask {
  355. t.Errorf("Expected task queue size %d, got %d", tt.expectedTask, stats.TaskQueueSize)
  356. }
  357. if stats.MessageQueueSize != tt.expectedMsg {
  358. t.Errorf("Expected message queue size %d, got %d", tt.expectedMsg, stats.MessageQueueSize)
  359. }
  360. })
  361. }
  362. }
  363. // TestAgent_QueueEventFullQueue tests that events are dropped when queue is full
  364. func TestAgent_QueueEventFullQueue(t *testing.T) {
  365. mockMCP := NewMockMCPClient([]Tool{})
  366. mockLLM := NewMockLLM(nil)
  367. agent := NewTestAgent(mockLLM, mockMCP)
  368. agent.SetupQueues(2) // Small queue for testing
  369. // Fill the task queue
  370. agent.QueueEvent("graphql://subscription/taskCreated", json.RawMessage(`{"id": "1"}`))
  371. agent.QueueEvent("graphql://subscription/taskCreated", json.RawMessage(`{"id": "2"}`))
  372. // This should be dropped
  373. agent.QueueEvent("graphql://subscription/taskCreated", json.RawMessage(`{"id": "3"}`))
  374. stats := agent.GetQueueStats()
  375. if stats.TaskQueueSize != 2 {
  376. t.Errorf("Expected task queue size 2 (full), got %d", stats.TaskQueueSize)
  377. }
  378. }
  379. // TestAgent_StartStop tests the queue processor lifecycle
  380. func TestAgent_StartStop(t *testing.T) {
  381. mockMCP := NewMockMCPClient([]Tool{
  382. {Name: "query", Description: "Execute a GraphQL query", InputSchema: InputSchema{Type: "object"}},
  383. })
  384. mockMCP.SetToolResult("query", &CallToolResult{
  385. Content: []ContentBlock{{Type: "text", Text: `{"data": {}}`}},
  386. })
  387. // Mock LLM that responds immediately
  388. mockLLM := NewMockLLM([]*openai.ChatCompletionMessage{
  389. {
  390. Role: openai.ChatMessageRoleAssistant,
  391. Content: "Processed",
  392. },
  393. })
  394. agent := NewTestAgent(mockLLM, mockMCP)
  395. agent.Initialize()
  396. agent.SetupQueues(10)
  397. ctx, cancel := context.WithCancel(context.Background())
  398. defer cancel()
  399. // Start the queue processor
  400. agent.Start(ctx)
  401. // Queue an event
  402. agent.QueueEvent("graphql://subscription/taskCreated", json.RawMessage(`{"id": "1"}`))
  403. // Give it time to process
  404. time.Sleep(100 * time.Millisecond)
  405. // Stop the processor
  406. agent.Stop()
  407. // Verify the queue is empty (event was processed)
  408. stats := agent.GetQueueStats()
  409. if stats.TaskQueueSize != 0 {
  410. t.Errorf("Expected task queue to be empty after processing, got %d", stats.TaskQueueSize)
  411. }
  412. }
  413. // TestAgent_MultipleEventsInOrder tests that events are processed in arrival order
  414. func TestAgent_MultipleEventsInOrder(t *testing.T) {
  415. mockMCP := NewMockMCPClient([]Tool{
  416. {Name: "query", Description: "Execute a GraphQL query", InputSchema: InputSchema{Type: "object"}},
  417. })
  418. mockMCP.SetToolResult("query", &CallToolResult{
  419. Content: []ContentBlock{{Type: "text", Text: `{"data": {}}`}},
  420. })
  421. // Track the order of processed events
  422. var processedOrder []string
  423. // Mock LLM that responds immediately and tracks order
  424. mockLLM := NewMockLLM([]*openai.ChatCompletionMessage{
  425. {
  426. Role: openai.ChatMessageRoleAssistant,
  427. Content: "Processed",
  428. },
  429. {
  430. Role: openai.ChatMessageRoleAssistant,
  431. Content: "Processed",
  432. },
  433. {
  434. Role: openai.ChatMessageRoleAssistant,
  435. Content: "Processed",
  436. },
  437. })
  438. agent := NewTestAgent(mockLLM, mockMCP)
  439. agent.Initialize()
  440. agent.SetupQueues(10)
  441. ctx, cancel := context.WithCancel(context.Background())
  442. defer cancel()
  443. // Start the queue processor
  444. agent.Start(ctx)
  445. // Queue multiple events in order: task, message, task
  446. agent.QueueEvent("graphql://subscription/taskCreated", json.RawMessage(`{"id": "task-1"}`))
  447. agent.QueueEvent("graphql://subscription/messageAdded", json.RawMessage(`{"id": "msg-1"}`))
  448. agent.QueueEvent("graphql://subscription/taskUpdated", json.RawMessage(`{"id": "task-2"}`))
  449. // Give it time to process all events
  450. time.Sleep(200 * time.Millisecond)
  451. // Stop the processor
  452. agent.Stop()
  453. // Verify all queues are empty
  454. stats := agent.GetQueueStats()
  455. if stats.TaskQueueSize != 0 {
  456. t.Errorf("Expected task queue to be empty, got %d", stats.TaskQueueSize)
  457. }
  458. if stats.MessageQueueSize != 0 {
  459. t.Errorf("Expected message queue to be empty, got %d", stats.MessageQueueSize)
  460. }
  461. // Verify order was preserved (we can't easily check this with the mock, but the test validates the mechanism)
  462. _ = processedOrder
  463. }