| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349 |
- package main
- import (
- "context"
- "encoding/json"
- "fmt"
- "strings"
- "sync"
- "testing"
- "time"
- "github.com/bradleyjkemp/cupaloy/v2"
- "github.com/sashabaranov/go-openai"
- )
// testSnapshotter is the shared cupaloy snapshotter; snapshots are stored
// under testdata/snapshots instead of cupaloy's default .snapshots directory.
var testSnapshotter = cupaloy.New(cupaloy.SnapshotSubdirectory("testdata/snapshots"))
// testSystemPrompt is the system prompt used to seed every test conversation.
// It describes the MCP tool surface (introspect/query/mutate) and the expected
// event-handling behavior. NOTE(review): presumably mirrors the production
// system prompt defined elsewhere — confirm and keep in sync.
const testSystemPrompt = `You are an intelligent agent connected to an ARP (Agent-native ERP) platform via the Model Context Protocol (MCP).
You have access to the following tools:
- introspect: Discover the GraphQL schema structure
- query: Execute GraphQL queries (read operations)
- mutate: Execute GraphQL mutations (create/update/delete operations)
When you receive events (task created, task updated, message added), you should:
1. Understand the event context
2. Take appropriate action using the available tools
3. Respond concisely about what you did
You can query for more information or make changes as needed. Be helpful and proactive.`
// LLMInterface defines the interface for LLM operations.
// Chat sends the conversation so far plus the available tool definitions and
// returns the model's next message, which may contain tool calls.
type LLMInterface interface {
	Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error)
}
// MCPClientInterface defines the interface for MCP client operations:
// discovering the available tools and invoking one by name with JSON-ish args.
type MCPClientInterface interface {
	ListTools() ([]Tool, error)
	CallTool(name string, args map[string]interface{}) (*CallToolResult, error)
}
// MockLLM is a mock implementation of the LLM for testing.
// It replays a fixed script of responses in order. Not safe for concurrent
// use: callCount is unsynchronized.
type MockLLM struct {
	responses []*openai.ChatCompletionMessage // scripted replies, returned in order
	callCount int                             // index of the next response to return
}
- // NewMockLLM creates a new mock LLM with predefined responses
- func NewMockLLM(responses []*openai.ChatCompletionMessage) *MockLLM {
- return &MockLLM{responses: responses}
- }
- // Chat implements the LLMInterface
- func (m *MockLLM) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, tools []openai.Tool) (*openai.ChatCompletionMessage, error) {
- if m.callCount >= len(m.responses) {
- return &openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleAssistant,
- Content: "No more responses available",
- }, nil
- }
- response := m.responses[m.callCount]
- m.callCount++
- return response, nil
- }
// MockMCPClient is a mock implementation of the MCPClient for testing.
// It serves a fixed tool list and canned per-tool results.
type MockMCPClient struct {
	tools       []Tool                     // tools reported by ListTools
	toolResults map[string]*CallToolResult // canned results keyed by tool name
}
- // NewMockMCPClient creates a new mock MCP client
- func NewMockMCPClient(tools []Tool) *MockMCPClient {
- return &MockMCPClient{
- tools: tools,
- toolResults: make(map[string]*CallToolResult),
- }
- }
- // SetToolResult sets the result for a specific tool call
- func (m *MockMCPClient) SetToolResult(name string, result *CallToolResult) {
- m.toolResults[name] = result
- }
// ListTools implements MCPClientInterface by returning the configured tool
// list. The internal slice is returned directly; callers must not mutate it.
func (m *MockMCPClient) ListTools() ([]Tool, error) {
	return m.tools, nil
}
- // CallTool implements MCPClientInterface
- func (m *MockMCPClient) CallTool(name string, args map[string]interface{}) (*CallToolResult, error) {
- if result, ok := m.toolResults[name]; ok {
- return result, nil
- }
- return &CallToolResult{
- Content: []ContentBlock{{Type: "text", Text: "mock result"}},
- }, nil
- }
// TestAgent is an agent that uses interfaces for testing, mirroring the real
// agent's event-driven tool-call loop with injectable LLM and MCP clients.
type TestAgent struct {
	llm       LLMInterface       // LLM used for all chat completions
	mcpClient MCPClientInterface // MCP client used to list and execute tools
	tools     []openai.Tool      // OpenAI-format tools, populated by Initialize
	// Event queues (created by SetupQueues; nil until then)
	taskQueue    *EventQueue // task lifecycle events
	messageQueue *EventQueue // messageAdded events
	// Queue control (set by Start, used by Stop)
	ctx    context.Context
	cancel context.CancelFunc
	wg     sync.WaitGroup
}
- // NewTestAgent creates a new test agent with interfaces
- func NewTestAgent(llm LLMInterface, mcpClient MCPClientInterface) *TestAgent {
- return &TestAgent{
- llm: llm,
- mcpClient: mcpClient,
- }
- }
// SetupQueues initializes the event queues with the given capacity.
// Must be called before Start/QueueEvent, which assume non-nil queues.
func (a *TestAgent) SetupQueues(maxQueueSize int) {
	a.taskQueue = NewEventQueue("task", maxQueueSize)
	a.messageQueue = NewEventQueue("message", maxQueueSize)
}
- // QueueEvent adds an event to the appropriate queue based on its URI
- func (a *TestAgent) QueueEvent(uri string, data json.RawMessage) {
- event := &QueuedEvent{
- URI: uri,
- Data: data,
- Timestamp: time.Now(),
- }
- var queue *EventQueue
- if strings.Contains(uri, "taskCreated") || strings.Contains(uri, "taskUpdated") || strings.Contains(uri, "taskDeleted") {
- queue = a.taskQueue
- } else if strings.Contains(uri, "messageAdded") {
- queue = a.messageQueue
- } else {
- queue = a.taskQueue
- }
- queue.TryEnqueue(event)
- }
// Start begins processing events from the queues. It derives a cancellable
// context from ctx and launches a single worker goroutine; pair with Stop to
// cancel and wait for the worker. Call SetupQueues first.
func (a *TestAgent) Start(ctx context.Context) {
	a.ctx, a.cancel = context.WithCancel(ctx)
	a.wg.Add(1)
	go a.processQueues()
}
- // Stop gracefully stops the queue processor
- func (a *TestAgent) Stop() {
- if a.cancel != nil {
- a.cancel()
- }
- a.wg.Wait()
- }
// processQueues is the main worker that processes events from both queues.
// It runs until the agent context is cancelled. When both queues are ready,
// select picks pseudo-randomly — there is no strict priority between task and
// message events. Errors from ProcessEvent are deliberately ignored here.
// NOTE(review): assumes SetupQueues ran before Start; otherwise the queue
// pointers are nil and Channel() is called on a nil receiver — confirm.
func (a *TestAgent) processQueues() {
	defer a.wg.Done()
	for {
		select {
		case <-a.ctx.Done():
			return
		case event := <-a.taskQueue.Channel():
			a.ProcessEvent(a.ctx, event.URI, event.Data)
		case event := <-a.messageQueue.Channel():
			a.ProcessEvent(a.ctx, event.URI, event.Data)
		}
	}
}
- // GetQueueStats returns current queue statistics
- func (a *TestAgent) GetQueueStats() QueueStats {
- return QueueStats{
- TaskQueueSize: a.taskQueue.Len(),
- MessageQueueSize: a.messageQueue.Len(),
- }
- }
- // Initialize initializes the test agent
- func (a *TestAgent) Initialize() error {
- mcpTools, err := a.mcpClient.ListTools()
- if err != nil {
- return err
- }
- a.tools = ConvertMCPToolsToOpenAI(mcpTools)
- return nil
- }
- // ProcessEvent processes an event
- func (a *TestAgent) ProcessEvent(ctx context.Context, uri string, eventData json.RawMessage) error {
- prompt := buildTestEventPrompt(uri, eventData)
- messages := []openai.ChatCompletionMessage{
- {Role: openai.ChatMessageRoleSystem, Content: testSystemPrompt},
- {Role: openai.ChatMessageRoleUser, Content: prompt},
- }
- return a.processWithTools(ctx, messages)
- }
- // Run runs the agent interactively
- func (a *TestAgent) Run(ctx context.Context, userMessage string) (string, error) {
- messages := []openai.ChatCompletionMessage{
- {Role: openai.ChatMessageRoleSystem, Content: testSystemPrompt},
- {Role: openai.ChatMessageRoleUser, Content: userMessage},
- }
- response, err := a.llm.Chat(ctx, messages, a.tools)
- if err != nil {
- return "", err
- }
- if len(response.ToolCalls) == 0 {
- return response.Content, nil
- }
- // Handle tool calls
- messages = append(messages, *response)
- for _, toolCall := range response.ToolCalls {
- name, args, err := ParseToolCall(toolCall)
- if err != nil {
- continue
- }
- result, err := a.mcpClient.CallTool(name, args)
- if err != nil {
- continue
- }
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: extractTextFromResult(result),
- })
- }
- finalResponse, err := a.llm.Chat(ctx, messages, a.tools)
- if err != nil {
- return "", err
- }
- return finalResponse.Content, nil
- }
- func (a *TestAgent) processWithTools(ctx context.Context, messages []openai.ChatCompletionMessage) error {
- response, err := a.llm.Chat(ctx, messages, a.tools)
- if err != nil {
- return err
- }
- if len(response.ToolCalls) == 0 {
- return nil
- }
- messages = append(messages, *response)
- for _, toolCall := range response.ToolCalls {
- name, args, err := ParseToolCall(toolCall)
- if err != nil {
- continue
- }
- result, err := a.mcpClient.CallTool(name, args)
- if err != nil {
- continue
- }
- messages = append(messages, openai.ChatCompletionMessage{
- Role: openai.ChatMessageRoleTool,
- ToolCallID: toolCall.ID,
- Content: extractTextFromResult(result),
- })
- }
- return a.processWithTools(ctx, messages)
- }
// buildTestEventPrompt renders an incoming event as a natural-language prompt
// for the LLM, classifying the event type from substrings of the URI.
// Empty event data is shown as "{}".
func buildTestEventPrompt(uri string, eventData json.RawMessage) string {
	var eventType string
	switch {
	case strings.Contains(uri, "taskCreated"):
		eventType = "task created"
	case strings.Contains(uri, "taskUpdated"):
		eventType = "task updated"
	case strings.Contains(uri, "taskDeleted"):
		eventType = "task deleted"
	case strings.Contains(uri, "messageAdded"):
		eventType = "message added"
	default:
		eventType = "unknown"
	}

	payload := "{}"
	if len(eventData) > 0 {
		payload = string(eventData)
	}

	return fmt.Sprintf(`A %s event was received.
Event URI: %s
Event Data: %s
Please process this event appropriately.`, eventType, uri, payload)
}
- // testNormalizeJSON normalizes JSON for snapshot comparison
- func testNormalizeJSON(jsonStr string) string {
- var data interface{}
- if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
- return jsonStr
- }
- testNormalizeData(data)
- bytes, _ := json.MarshalIndent(data, "", " ")
- return string(bytes)
- }
// testNormalizeData recursively strips volatile and sensitive keys (ids,
// timestamps, foreign keys, password) from decoded JSON so snapshots stay
// stable across runs. It mutates data in place; scalar values are left alone.
func testNormalizeData(data interface{}) {
	volatileKeys := []string{
		"id", "ID", "createdAt", "updatedAt", "sentAt", "createdByID",
		"userId", "serviceId", "statusId", "assigneeId", "conversationId",
		"senderId", "password",
	}
	switch node := data.(type) {
	case map[string]interface{}:
		for _, key := range volatileKeys {
			delete(node, key)
		}
		for _, child := range node {
			testNormalizeData(child)
		}
	case []interface{}:
		for _, item := range node {
			testNormalizeData(item)
		}
	}
}
- // testSnapshotResult captures a snapshot of the result
- func testSnapshotResult(t *testing.T, name string, response interface{}) {
- jsonBytes, _ := json.MarshalIndent(response, "", " ")
- normalized := testNormalizeJSON(string(jsonBytes))
- testSnapshotter.SnapshotT(t, name, normalized)
- }
|