| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638 |
- package mcp
- import (
- "context"
- "encoding/json"
- "fmt"
- "strings"
- "testing"
- "time"
- "github.com/bradleyjkemp/cupaloy/v2"
- "github.com/vektah/gqlparser/v2/ast"
- "gogs.dmsc.dev/arp/auth"
- "gogs.dmsc.dev/arp/graph"
- "gogs.dmsc.dev/arp/graph/testutil"
- "gorm.io/gorm"
- )
- var snapshotter = cupaloy.New(cupaloy.SnapshotSubdirectory("testdata/snapshots"))
- // MCPTestClient wraps the MCP server for testing
- type MCPTestClient struct {
- server *Server
- db *gorm.DB
- schema *ast.Schema
- session *Session
- user *auth.UserContext
- }
// IDTracker records database IDs (as strings) of entities a test creates or
// relies on, keyed by a human-readable handle, so later requests can refer to
// an entity by handle rather than by a hard-coded ID.
type IDTracker struct {
	Permissions  map[string]string // permission code -> ID
	Roles        map[string]string // role name -> ID
	Users        map[string]string // user email -> ID
	TaskStatuses map[string]string // task status code -> ID
	Services     map[string]string
	Tasks        map[string]string
	Notes        map[string]string
	Messages     []string
}

// NewIDTracker returns an IDTracker whose maps and message list are all
// allocated and empty (never nil), so callers can write to them immediately.
func NewIDTracker() *IDTracker {
	tr := new(IDTracker)
	tr.Permissions = map[string]string{}
	tr.Roles = map[string]string{}
	tr.Users = map[string]string{}
	tr.TaskStatuses = map[string]string{}
	tr.Services = map[string]string{}
	tr.Tasks = map[string]string{}
	tr.Notes = map[string]string{}
	tr.Messages = []string{}
	return tr
}
- // setupMCPTestClient creates a test client with bootstrapped database
- func setupMCPTestClient(t *testing.T) (*MCPTestClient, *IDTracker) {
- db, err := testutil.SetupAndBootstrapTestDB()
- if err != nil {
- t.Fatalf("Failed to setup test database: %v", err)
- }
- resolver := graph.NewResolver(db)
- schema := graph.NewExecutableSchema(graph.Config{Resolvers: resolver})
- astSchema := schema.Schema()
- server := NewServer(resolver, astSchema)
- adminUser := &auth.UserContext{
- ID: 1,
- Email: "admin@example.com",
- Roles: []auth.RoleClaim{{ID: 1, Name: "admin"}},
- Permissions: []string{
- "user:read", "user:write",
- "task:read", "task:write",
- "service:read", "service:write",
- "note:read", "note:write",
- },
- }
- session := &Session{
- ID: "test-session-id",
- User: adminUser,
- Events: make(chan []byte, 100),
- Done: make(chan struct{}),
- Subscriptions: make(map[string]context.CancelFunc),
- }
- tracker := NewIDTracker()
- tracker.Users["admin@example.com"] = "1"
- var perms []struct {
- ID uint
- Code string
- }
- db.Table("permissions").Find(&perms)
- for _, perm := range perms {
- tracker.Permissions[perm.Code] = fmt.Sprintf("%d", perm.ID)
- }
- var roles []struct {
- ID uint
- Name string
- }
- db.Table("roles").Find(&roles)
- for _, role := range roles {
- tracker.Roles[role.Name] = fmt.Sprintf("%d", role.ID)
- }
- var statuses []struct {
- ID uint
- Code string
- }
- db.Table("task_statuses").Find(&statuses)
- for _, status := range statuses {
- tracker.TaskStatuses[status.Code] = fmt.Sprintf("%d", status.ID)
- }
- return &MCPTestClient{
- server: server,
- db: db,
- schema: astSchema,
- session: session,
- user: adminUser,
- }, tracker
- }
- func (tc *MCPTestClient) callTool(ctx context.Context, name string, args map[string]interface{}) *JSONRPCResponse {
- argsJSON, _ := json.Marshal(args)
- params := json.RawMessage(fmt.Sprintf(`{"name": "%s", "arguments": %s}`, name, string(argsJSON)))
- req := &JSONRPCRequest{
- JSONRPC: "2.0",
- ID: fmt.Sprintf("test-%d", time.Now().UnixNano()),
- Method: "tools/call",
- Params: params,
- }
- // Inject user into context
- ctxWithUser := auth.WithUser(ctx, tc.user)
- return tc.server.handleRequest(ctxWithUser, req, tc.session)
- }
- func (tc *MCPTestClient) callMethod(ctx context.Context, method string, params json.RawMessage) *JSONRPCResponse {
- req := &JSONRPCRequest{
- JSONRPC: "2.0",
- ID: fmt.Sprintf("test-%d", time.Now().UnixNano()),
- Method: method,
- Params: params,
- }
- // Inject user into context
- ctxWithUser := auth.WithUser(ctx, tc.user)
- return tc.server.handleRequest(ctxWithUser, req, tc.session)
- }
- func (tc *MCPTestClient) subscribeToResource(ctx context.Context, uri string) *JSONRPCResponse {
- params := json.RawMessage(fmt.Sprintf(`{"uri": "%s"}`, uri))
- return tc.callMethod(ctx, "resources/subscribe", params)
- }
- func (tc *MCPTestClient) unsubscribeFromResource(ctx context.Context, uri string) *JSONRPCResponse {
- params := json.RawMessage(fmt.Sprintf(`{"uri": "%s"}`, uri))
- return tc.callMethod(ctx, "resources/unsubscribe", params)
- }
// normalizeJSON parses jsonStr, strips volatile fields via normalizeData, and
// re-encodes it with MarshalIndent (sorted object keys) so snapshots are
// stable across runs. Input that is not valid JSON is returned unchanged.
func normalizeJSON(jsonStr string) string {
	var parsed interface{}
	if json.Unmarshal([]byte(jsonStr), &parsed) != nil {
		return jsonStr
	}
	normalizeData(parsed)
	// Re-encoding a value that was just decoded from JSON cannot fail.
	pretty, _ := json.MarshalIndent(parsed, "", " ")
	return string(pretty)
}

// volatileSnapshotKeys lists the object keys normalizeData removes from every
// JSON object: identifiers, timestamps, foreign-key references and password
// hashes.
var volatileSnapshotKeys = []string{
	"id",
	"ID",
	"createdAt",
	"updatedAt",
	"sentAt",
	"createdByID",
	"userId",
	"serviceId",
	"statusId",
	"assigneeId",
	"conversationId",
	"senderId",
	"password", // Remove password hashes
}

// normalizeData walks a decoded JSON value in place. It deletes
// volatileSnapshotKeys from every object and recurses into arrays and object
// values. A "text" member whose string value is itself valid JSON is
// normalized recursively and replaced by its compact re-encoding.
func normalizeData(data interface{}) {
	if list, isList := data.([]interface{}); isList {
		for i := range list {
			normalizeData(list[i])
		}
		return
	}

	obj, isObj := data.(map[string]interface{})
	if !isObj {
		return
	}
	for _, key := range volatileSnapshotKeys {
		delete(obj, key)
	}
	for key, val := range obj {
		if key == "text" {
			if compact, replaced := normalizeEmbeddedJSON(val); replaced {
				// Overwriting an existing key during range is well-defined.
				obj[key] = compact
				continue
			}
		}
		normalizeData(val)
	}
}

// normalizeEmbeddedJSON reports whether val is a string containing valid
// JSON. If it is, the embedded value is normalized and returned in compact
// form.
func normalizeEmbeddedJSON(val interface{}) (string, bool) {
	raw, isString := val.(string)
	if !isString {
		return "", false
	}
	var embedded interface{}
	if json.Unmarshal([]byte(raw), &embedded) != nil {
		return "", false
	}
	normalizeData(embedded)
	compact, err := json.Marshal(embedded)
	if err != nil {
		return "", false
	}
	return string(compact), true
}
- func snapshotResult(t *testing.T, name string, response *JSONRPCResponse) {
- jsonBytes, _ := json.MarshalIndent(response, "", " ")
- normalized := normalizeJSON(string(jsonBytes))
- snapshotter.SnapshotT(t, name, normalized)
- }
- // TestMCP_Initialize tests the initialize method
- func TestMCP_Initialize(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- params := json.RawMessage(`{"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test-client", "version": "1.0.0"}}`)
- response := tc.callMethod(ctx, "initialize", params)
- if response.Error != nil {
- t.Fatalf("Initialize failed: %v", response.Error)
- }
- result, ok := response.Result.(InitializeResult)
- if !ok {
- t.Fatalf("Expected InitializeResult, got %T", response.Result)
- }
- if result.ProtocolVersion != ProtocolVersion {
- t.Errorf("Expected protocol version %s, got %s", ProtocolVersion, result.ProtocolVersion)
- }
- if result.Capabilities.Tools == nil {
- t.Error("Expected tools capability to be present")
- }
- if result.Capabilities.Resources == nil {
- t.Error("Expected resources capability to be present")
- }
- if !result.Capabilities.Resources.Subscribe {
- t.Error("Expected resources.subscribe to be true")
- }
- snapshotResult(t, "initialize", response)
- }
- // TestMCP_ToolsList tests the tools/list method
- func TestMCP_ToolsList(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- response := tc.callMethod(ctx, "tools/list", nil)
- if response.Error != nil {
- t.Fatalf("tools/list failed: %v", response.Error)
- }
- snapshotResult(t, "tools_list", response)
- }
- // TestMCP_ResourcesList tests the resources/list method
- func TestMCP_ResourcesList(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- response := tc.callMethod(ctx, "resources/list", nil)
- if response.Error != nil {
- t.Fatalf("resources/list failed: %v", response.Error)
- }
- snapshotResult(t, "resources_list", response)
- }
- // TestMCP_Introspect tests the introspect tool
- func TestMCP_Introspect(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- t.Run("FullSchema", func(t *testing.T) {
- response := tc.callTool(ctx, "introspect", map[string]interface{}{})
- if response.Error != nil {
- t.Fatalf("introspect failed: %v", response.Error)
- }
- // Verify the response contains expected content (skip snapshot due to non-deterministic ordering)
- jsonBytes, _ := json.Marshal(response.Result)
- var result map[string]interface{}
- if err := json.Unmarshal(jsonBytes, &result); err != nil {
- t.Fatalf("Failed to unmarshal result: %v", err)
- }
- content, ok := result["content"].([]interface{})
- if !ok || len(content) == 0 {
- t.Fatal("Expected content array with at least one item")
- }
- text, ok := content[0].(map[string]interface{})["text"].(string)
- if !ok {
- t.Fatal("Expected text field in content")
- }
- // Verify key sections are present
- expectedSections := []string{"Query Type", "Mutation Type", "Subscription Type", "Object Types", "Input Types"}
- for _, section := range expectedSections {
- if !strings.Contains(text, section) {
- t.Errorf("Expected section '%s' in introspection result", section)
- }
- }
- })
- t.Run("QueryType", func(t *testing.T) {
- response := tc.callTool(ctx, "introspect", map[string]interface{}{
- "typeName": "Query",
- })
- if response.Error != nil {
- t.Fatalf("introspect Query failed: %v", response.Error)
- }
- snapshotResult(t, "introspect_query", response)
- })
- t.Run("UserType", func(t *testing.T) {
- response := tc.callTool(ctx, "introspect", map[string]interface{}{
- "typeName": "User",
- })
- if response.Error != nil {
- t.Fatalf("introspect User failed: %v", response.Error)
- }
- snapshotResult(t, "introspect_user", response)
- })
- }
- // TestMCP_Query tests the query tool
- func TestMCP_Query(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- t.Run("Users", func(t *testing.T) {
- response := tc.callTool(ctx, "query", map[string]interface{}{
- "query": "query { users { email roles { name } } }",
- })
- if response.Error != nil {
- t.Fatalf("query users failed: %v", response.Error)
- }
- snapshotResult(t, "query_users", response)
- })
- t.Run("Tasks", func(t *testing.T) {
- response := tc.callTool(ctx, "query", map[string]interface{}{
- "query": "query { tasks { title content priority } }",
- })
- if response.Error != nil {
- t.Fatalf("query tasks failed: %v", response.Error)
- }
- snapshotResult(t, "query_tasks", response)
- })
- t.Run("Services", func(t *testing.T) {
- response := tc.callTool(ctx, "query", map[string]interface{}{
- "query": "query { services { name description } }",
- })
- if response.Error != nil {
- t.Fatalf("query services failed: %v", response.Error)
- }
- snapshotResult(t, "query_services", response)
- })
- t.Run("Roles", func(t *testing.T) {
- response := tc.callTool(ctx, "query", map[string]interface{}{
- "query": "query { roles { name description permissions { code } } }",
- })
- if response.Error != nil {
- t.Fatalf("query roles failed: %v", response.Error)
- }
- snapshotResult(t, "query_roles", response)
- })
- t.Run("InvalidQuery", func(t *testing.T) {
- response := tc.callTool(ctx, "query", map[string]interface{}{
- "query": "query { nonexistent { id } }",
- })
- if response.Error != nil {
- t.Fatalf("query failed: %v", response.Error)
- }
- // The result is a CallToolResult with isError=true
- jsonBytes, _ := json.Marshal(response.Result)
- var result map[string]interface{}
- if err := json.Unmarshal(jsonBytes, &result); err != nil {
- t.Fatalf("Failed to unmarshal result: %v", err)
- }
- if isError, ok := result["isError"].(bool); !ok || !isError {
- t.Error("Expected isError to be true for invalid query")
- }
- snapshotResult(t, "query_invalid", response)
- })
- }
- // TestMCP_Mutate tests the mutate tool
- func TestMCP_Mutate(t *testing.T) {
- tc, tracker := setupMCPTestClient(t)
- ctx := context.Background()
- t.Run("CreateUser", func(t *testing.T) {
- response := tc.callTool(ctx, "mutate", map[string]interface{}{
- "mutation": fmt.Sprintf(`mutation { createUser(input: {email: "newuser@example.com", password: "password123", roles: ["%s"]}) { email } }`, tracker.Roles["admin"]),
- })
- if response.Error != nil {
- t.Fatalf("createUser failed: %v", response.Error)
- }
- snapshotResult(t, "mutate_create_user", response)
- })
- t.Run("CreateTask", func(t *testing.T) {
- response := tc.callTool(ctx, "mutate", map[string]interface{}{
- "mutation": fmt.Sprintf(`mutation { createTask(input: {title: "Test Task", content: "Test content", createdById: "%s", statusId: "%s", priority: "high"}) { title content priority } }`, tracker.Users["admin@example.com"], tracker.TaskStatuses["open"]),
- })
- if response.Error != nil {
- t.Fatalf("createTask failed: %v", response.Error)
- }
- snapshotResult(t, "mutate_create_task", response)
- })
- t.Run("CreateNote", func(t *testing.T) {
- response := tc.callTool(ctx, "mutate", map[string]interface{}{
- "mutation": fmt.Sprintf(`mutation { createNote(input: {title: "Test Note", content: "Note content", userId: "%s"}) { title content } }`, tracker.Users["admin@example.com"]),
- })
- if response.Error != nil {
- t.Fatalf("createNote failed: %v", response.Error)
- }
- snapshotResult(t, "mutate_create_note", response)
- })
- }
- // TestMCP_Resources tests resource operations
- func TestMCP_Resources(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- t.Run("Read", func(t *testing.T) {
- params := json.RawMessage(`{"uri": "graphql://subscription/taskCreated"}`)
- response := tc.callMethod(ctx, "resources/read", params)
- if response.Error != nil {
- t.Fatalf("resources/read failed: %v", response.Error)
- }
- snapshotResult(t, "resources_read", response)
- })
- t.Run("Subscribe", func(t *testing.T) {
- response := tc.subscribeToResource(ctx, "graphql://subscription/taskCreated")
- if response.Error != nil {
- t.Fatalf("resources/subscribe failed: %v", response.Error)
- }
- // Verify subscription was registered
- tc.session.SubsMu.RLock()
- _, ok := tc.session.Subscriptions["graphql://subscription/taskCreated"]
- tc.session.SubsMu.RUnlock()
- if !ok {
- t.Error("Expected subscription to be registered in session")
- }
- snapshotResult(t, "resources_subscribe", response)
- })
- t.Run("Unsubscribe", func(t *testing.T) {
- // First subscribe
- tc.subscribeToResource(ctx, "graphql://subscription/taskUpdated")
- // Then unsubscribe
- response := tc.unsubscribeFromResource(ctx, "graphql://subscription/taskUpdated")
- if response.Error != nil {
- t.Fatalf("resources/unsubscribe failed: %v", response.Error)
- }
- // Verify subscription was removed
- tc.session.SubsMu.RLock()
- _, ok := tc.session.Subscriptions["graphql://subscription/taskUpdated"]
- tc.session.SubsMu.RUnlock()
- if ok {
- t.Error("Expected subscription to be removed from session")
- }
- snapshotResult(t, "resources_unsubscribe", response)
- })
- }
- // TestMCP_SubscriptionNotifications tests that subscription notifications are sent
- func TestMCP_SubscriptionNotifications(t *testing.T) {
- tc, tracker := setupMCPTestClient(t)
- ctx := context.Background()
- // Subscribe to taskCreated
- _ = tc.subscribeToResource(ctx, "graphql://subscription/taskCreated")
- // Create a task assigned to admin user (ID 1)
- _ = tc.callTool(ctx, "mutate", map[string]interface{}{
- "mutation": fmt.Sprintf(`mutation { createTask(input: {title: "Notification Test Task", content: "Testing notifications", createdById: "%s", assigneeId: "%s", statusId: "%s", priority: "medium"}) { title } }`, tracker.Users["admin@example.com"], tracker.Users["admin@example.com"], tracker.TaskStatuses["open"]),
- })
- // Wait for notification
- select {
- case event := <-tc.session.Events:
- var notification JSONRPCNotification
- if err := json.Unmarshal(event, ¬ification); err != nil {
- t.Fatalf("Failed to unmarshal notification: %v", err)
- }
- // Verify it's a resource update notification
- if notification.Method != "notifications/resources/updated" {
- t.Errorf("Expected method 'notifications/resources/updated', got '%s'", notification.Method)
- }
- t.Logf("Received notification: %s", string(event))
- case <-time.After(2 * time.Second):
- t.Error("Timeout waiting for taskCreated notification")
- }
- }
- // TestMCP_Ping tests the ping method
- func TestMCP_Ping(t *testing.T) {
- tc, _ := setupMCPTestClient(t)
- ctx := context.Background()
- response := tc.callMethod(ctx, "ping", nil)
- if response.Error != nil {
- t.Fatalf("ping failed: %v", response.Error)
- }
- snapshotResult(t, "ping", response)
- }
- // TestMCP_Unauthenticated tests operations without authentication
- func TestMCP_Unauthenticated(t *testing.T) {
- db, err := testutil.SetupAndBootstrapTestDB()
- if err != nil {
- t.Fatalf("Failed to setup test database: %v", err)
- }
- resolver := graph.NewResolver(db)
- schema := graph.NewExecutableSchema(graph.Config{Resolvers: resolver})
- astSchema := schema.Schema()
- server := NewServer(resolver, astSchema)
- // Create session without user
- session := &Session{
- ID: "unauth-session",
- User: nil, // No user
- Events: make(chan []byte, 100),
- Done: make(chan struct{}),
- Subscriptions: make(map[string]context.CancelFunc),
- }
- ctx := context.Background()
- t.Run("SubscribeRequiresAuth", func(t *testing.T) {
- params := json.RawMessage(`{"uri": "graphql://subscription/taskCreated"}`)
- req := &JSONRPCRequest{
- JSONRPC: "2.0",
- ID: "test-unauth",
- Method: "resources/subscribe",
- Params: params,
- }
- response := server.handleRequest(ctx, req, session)
- if response.Error == nil {
- t.Error("Expected error for unauthenticated subscribe")
- }
- if !strings.Contains(response.Error.Message, "Authentication required") {
- t.Errorf("Expected authentication error, got: %s", response.Error.Message)
- }
- })
- t.Run("QueryWithoutAuth", func(t *testing.T) {
- // Query should work without auth (no permission checks in current implementation)
- argsJSON, _ := json.Marshal(map[string]interface{}{
- "query": "query { users { email } }",
- })
- params := json.RawMessage(fmt.Sprintf(`{"name": "query", "arguments": %s}`, string(argsJSON)))
- req := &JSONRPCRequest{
- JSONRPC: "2.0",
- ID: "test-unauth-query",
- Method: "tools/call",
- Params: params,
- }
- response := server.handleRequest(ctx, req, session)
- // Query should succeed (no auth required for basic queries)
- if response.Error != nil {
- t.Logf("Query returned error (may be expected): %v", response.Error)
- }
- })
- }
|