diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 9c08b7ad..a0f12dbc 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -109,6 +110,13 @@ type MCPProxyServer struct { // Hooks shared across all routing mode servers hooks *mcpserver.Hooks + // directToolPerms maps direct-mode tool names (server__tool) to the + // operation permission required to call them. It is populated with the + // direct-mode registry and used only to filter tools/list for scoped agent + // tokens; execution-time authorization remains authoritative. + directToolPermsMu sync.RWMutex + directToolPerms map[string]string + // Spec 049: in-memory only counter of retrieve_tools calls that opted into // include_disabled. Never persisted (privacy, consistent with Spec 042). includeDisabledCalls atomic.Int64 diff --git a/internal/server/mcp_direct_scope.go b/internal/server/mcp_direct_scope.go new file mode 100644 index 00000000..ee24131c --- /dev/null +++ b/internal/server/mcp_direct_scope.go @@ -0,0 +1,90 @@ +package server + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/auth" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/contracts" +) + +func requiredPermissionForDirectTool(annotations *config.ToolAnnotations) string { + switch contracts.DeriveCallWith(annotations) { + case contracts.ToolVariantWrite: + return auth.PermWrite + case contracts.ToolVariantDestructive: + return auth.PermDestructive + default: + return auth.PermRead + } +} + +func (p *MCPProxyServer) setDirectToolPermissions(perms map[string]string) { + p.directToolPermsMu.Lock() + defer p.directToolPermsMu.Unlock() + + if len(perms) == 0 { + p.directToolPerms = nil + return + } + + copied := make(map[string]string, len(perms)) + for name, perm := range perms { + copied[name] = perm + } + p.directToolPerms = copied +} + +func (p *MCPProxyServer) lookupDirectToolPermission(directName string) (string, bool) { + p.directToolPermsMu.RLock() + defer p.directToolPermsMu.RUnlock() + + perm, ok := p.directToolPerms[directName] + return perm, ok +} + +// filterDirectModeToolsForAuth filters tools/list for scoped agent tokens. +// +// Direct mode registers upstream tools globally as server__tool. Without this +// filter, scoped agent tokens prevent execution but still disclose tool names, +// descriptions, and schemas for servers outside their scope. Call-time auth is +// still authoritative; this filter only removes tools that the current token +// could not call from discovery responses. +func (p *MCPProxyServer) filterDirectModeToolsForAuth(ctx context.Context, tools []mcp.Tool) []mcp.Tool { + if len(tools) == 0 { + return tools + } + + authCtx := auth.AuthContextFromContext(ctx) + if authCtx == nil || authCtx.Type != auth.AuthTypeAgent { + return tools + } + + filtered := make([]mcp.Tool, 0, len(tools)) + for _, tool := range tools { + serverName, _, ok := ParseDirectToolName(tool.Name) + if !ok { + filtered = append(filtered, tool) + continue + } + + if !authCtx.CanAccessServer(serverName) { + continue + } + + requiredPerm, ok := p.lookupDirectToolPermission(tool.Name) + if !ok { + continue + } + + if requiredPerm != "" && !authCtx.HasPermission(requiredPerm) { + continue + } + + filtered = append(filtered, tool) + } + + return filtered +} diff --git a/internal/server/mcp_routing.go b/internal/server/mcp_routing.go index 2ea99ef7..f078c2bf 100644 --- a/internal/server/mcp_routing.go +++ b/internal/server/mcp_routing.go @@ -48,13 +48,16 @@ func (p *MCPProxyServer) buildDirectModeTools() []mcpserver.ServerTool { // Use DiscoverTools which already filters for connected, enabled, non-quarantined servers tools, err := p.upstreamManager.DiscoverTools(ctx) if err != nil { + p.setDirectToolPermissions(nil) p.logger.Error("failed to discover tools for direct mode", zap.Error(err)) return nil } serverTools := make([]mcpserver.ServerTool, 0, len(tools)) + directToolPerms := make(map[string]string, len(tools)) for _, tool := range tools { directName := FormatDirectToolName(tool.ServerName, tool.Name) + directToolPerms[directName] = requiredPermissionForDirectTool(tool.Annotations) // Build MCP tool options opts := []mcp.ToolOption{ @@ -110,6 +113,8 @@ func (p *MCPProxyServer) buildDirectModeTools() []mcpserver.ServerTool { }) } + p.setDirectToolPermissions(directToolPerms) + p.logger.Info("built direct mode tools", zap.Int("tool_count", len(serverTools))) @@ -483,6 +488,7 @@ func (p *MCPProxyServer) initRoutingModeServers() { opts := []mcpserver.ServerOption{ mcpserver.WithToolCapabilities(true), mcpserver.WithRecovery(), + mcpserver.WithToolFilter(p.filterDirectModeToolsForAuth), } if p.hooks != nil { opts = append(opts, mcpserver.WithHooks(p.hooks)) diff --git a/internal/server/mcp_routing_test.go b/internal/server/mcp_routing_test.go index cb6cf5ce..69d7594d 100644 --- a/internal/server/mcp_routing_test.go +++ b/internal/server/mcp_routing_test.go @@ -366,6 +366,204 @@ func TestDirectModeHandler_DestructiveToolNeedsDestructivePermission(t *testing. assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "destructive") } +func TestRequiredPermissionForDirectTool_MapsAnnotationsToAuthPermissions(t *testing.T) { + readOnly := true + write := false + destructive := true + + tests := []struct { + name string + annotations *config.ToolAnnotations + want string + }{ + { + name: "nil annotations default to read", + want: auth.PermRead, + }, + { + name: "read only hint maps to read", + annotations: &config.ToolAnnotations{ + ReadOnlyHint: &readOnly, + }, + want: auth.PermRead, + }, + { + name: "read only false maps to write", + annotations: &config.ToolAnnotations{ + ReadOnlyHint: &write, + }, + want: auth.PermWrite, + }, + { + name: "destructive hint maps to destructive", + annotations: &config.ToolAnnotations{ + DestructiveHint: &destructive, + }, + want: auth.PermDestructive, + }, + { + name: "destructive hint takes precedence over read only hint", + annotations: &config.ToolAnnotations{ + ReadOnlyHint: &readOnly, + DestructiveHint: &destructive, + }, + want: auth.PermDestructive, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, requiredPermissionForDirectTool(tt.annotations)) + }) + } +} + +func TestSetDirectToolPermissions_DefensivelyCopiesMap(t *testing.T) { + proxy := &MCPProxyServer{} + toolName := FormatDirectToolName("github", "get_issue") + perms := map[string]string{ + toolName: auth.PermRead, + } + + proxy.setDirectToolPermissions(perms) + perms[toolName] = auth.PermDestructive + + got, ok := proxy.lookupDirectToolPermission(toolName) + require.True(t, ok) + assert.Equal(t, auth.PermRead, got) +} + +func TestFilterDirectModeToolsForAuth_DoesNotMutateInputSlice(t *testing.T) { + proxy := &MCPProxyServer{} + allowed := FormatDirectToolName("github", "get_issue") + denied := FormatDirectToolName("gitlab", "get_issue") + tools := []mcp.Tool{ + {Name: allowed}, + {Name: denied}, + } + original := append([]mcp.Tool(nil), tools...) + + proxy.setDirectToolPermissions(map[string]string{ + allowed: auth.PermRead, + denied: auth.PermRead, + }) + + ctx := auth.WithAuthContext(context.Background(), &auth.AuthContext{ + Type: auth.AuthTypeAgent, + AgentName: "test-agent", + AllowedServers: []string{"github"}, + Permissions: []string{auth.PermRead}, + }) + + filtered := proxy.filterDirectModeToolsForAuth(ctx, tools) + + assert.Equal(t, []string{allowed}, directToolNamesForTest(filtered)) + assert.Equal(t, original, tools) +} + +func TestFilterDirectModeToolsForAuth_AgentServerAndPermissionScope(t *testing.T) { + proxy := &MCPProxyServer{} + + githubRead := FormatDirectToolName("github", "get_issue") + githubWrite := FormatDirectToolName("github", "create_issue") + githubDestroy := FormatDirectToolName("github", "delete_repo") + gitlabRead := FormatDirectToolName("gitlab", "get_issue") + + proxy.setDirectToolPermissions(map[string]string{ + githubRead: auth.PermRead, + githubWrite: auth.PermWrite, + githubDestroy: auth.PermDestructive, + gitlabRead: auth.PermRead, + }) + + tools := []mcp.Tool{ + {Name: githubRead}, + {Name: githubWrite}, + {Name: githubDestroy}, + {Name: gitlabRead}, + } + + agentCtx := auth.WithAuthContext(context.Background(), &auth.AuthContext{ + Type: auth.AuthTypeAgent, + AgentName: "test-agent", + AllowedServers: []string{"github"}, + Permissions: []string{auth.PermRead, auth.PermWrite}, + }) + + filtered := proxy.filterDirectModeToolsForAuth(agentCtx, tools) + + assert.Equal(t, []string{githubRead, githubWrite}, directToolNamesForTest(filtered)) +} + +func TestFilterDirectModeToolsForAuth_NonAgentUnchanged(t *testing.T) { + proxy := &MCPProxyServer{} + tools := []mcp.Tool{ + {Name: FormatDirectToolName("github", "get_issue")}, + {Name: FormatDirectToolName("gitlab", "get_issue")}, + } + + assert.Equal(t, tools, proxy.filterDirectModeToolsForAuth(context.Background(), tools)) + + adminCtx := auth.WithAuthContext(context.Background(), auth.AdminContext()) + assert.Equal(t, tools, proxy.filterDirectModeToolsForAuth(adminCtx, tools)) +} + +func TestFilterDirectModeToolsForAuth_FailsClosedOnMissingPermissionMetadata(t *testing.T) { + proxy := &MCPProxyServer{} + + visible := FormatDirectToolName("github", "get_issue") + missing := FormatDirectToolName("github", "unknown") + proxy.setDirectToolPermissions(map[string]string{ + visible: auth.PermRead, + }) + + ctx := auth.WithAuthContext(context.Background(), &auth.AuthContext{ + Type: auth.AuthTypeAgent, + AgentName: "test-agent", + AllowedServers: []string{"github"}, + Permissions: []string{auth.PermRead}, + }) + + filtered := proxy.filterDirectModeToolsForAuth(ctx, []mcp.Tool{ + {Name: visible}, + {Name: missing}, + }) + + assert.Equal(t, []string{visible}, directToolNamesForTest(filtered)) +} + +func TestFilterDirectModeToolsForAuth_KeepsNonDirectTools(t *testing.T) { + proxy := &MCPProxyServer{} + + direct := FormatDirectToolName("github", "get_issue") + nonDirect := "retrieve_tools" + proxy.setDirectToolPermissions(map[string]string{ + direct: auth.PermRead, + }) + + ctx := auth.WithAuthContext(context.Background(), &auth.AuthContext{ + Type: auth.AuthTypeAgent, + AgentName: "test-agent", + AllowedServers: []string{"github"}, + Permissions: []string{auth.PermRead}, + }) + + filtered := proxy.filterDirectModeToolsForAuth(ctx, []mcp.Tool{ + {Name: direct}, + {Name: nonDirect}, + }) + + assert.Equal(t, []string{direct, nonDirect}, directToolNamesForTest(filtered)) +} + +func directToolNamesForTest(tools []mcp.Tool) []string { + names := make([]string, 0, len(tools)) + for _, tool := range tools { + names = append(names, tool.Name) + } + return names +} + func TestDirectModeHandler_NoAuthContext(t *testing.T) { logger, _ := zap.NewDevelopment() proxy := &MCPProxyServer{