mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-29 03:24:07 +00:00
299 lines
8.9 KiB
Go
299 lines
8.9 KiB
Go
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0

// Example mcp-server-auth demonstrates how an MCP (Model Context Protocol)
// server verifies OAuth 2.1 access tokens from MCP clients.
//
// MCP uses OAuth 2.1 for authorization per the spec. This example shows:
//   - Bearer token extraction from Authorization headers
//   - JWT signature verification and claims validation
//   - Scope-based access control for MCP operations
//   - A JSON-RPC handler that lists tools or executes them based on granted scopes
//
// Run the server, then use the printed curl commands to test each scope level.
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/therootcompany/golib/auth/jwt"
|
|
)
|
|
|
|
// MCP scope values for access control. Each entry in toolRegistry names one
// of these as the scope a token must carry to list or call that tool.
const (
	scopeRead  = "mcp:read"
	scopeWrite = "mcp:write"
	scopeAdmin = "mcp:admin"
)
|
|
|
|
// toolDef is the single source of truth for tool name, description, and
// required scope. Both tools/list and tools/call use this registry.
//
// Name and Description are serialized to clients; Scope is server-side only.
type toolDef struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Scope       string `json:"-"` // scope a token must carry to see/call this tool; never serialized
}
|
|
|
|
// toolRegistry is the canonical list of tools this server exposes.
|
|
var toolRegistry = []toolDef{
|
|
{"search", "Search the knowledge base", scopeRead},
|
|
{"summarize", "Summarize a document", scopeRead},
|
|
{"create_document", "Create a new document", scopeWrite},
|
|
{"manage_users", "Add or remove users from the workspace", scopeAdmin},
|
|
}
|
|
|
|
// --- Context accessors (two lines of code -- no library support required) ---
|
|
|
|
// contextKey is an unexported type for context keys, so values stored by this
// program cannot collide with keys set by other packages.
type contextKey string

// claimsKey is the context key under which WithClaims stores token claims and
// from which ClaimsFromContext reads them.
const claimsKey contextKey = "claims"
|
|
|
|
// WithClaims returns a new context carrying the given token claims.
|
|
func WithClaims(ctx context.Context, c *jwt.TokenClaims) context.Context {
|
|
return context.WithValue(ctx, claimsKey, c)
|
|
}
|
|
|
|
// ClaimsFromContext extracts claims stashed by the auth middleware.
|
|
func ClaimsFromContext(ctx context.Context) (*jwt.TokenClaims, bool) {
|
|
c, ok := ctx.Value(claimsKey).(*jwt.TokenClaims)
|
|
return c, ok
|
|
}
|
|
|
|
// --- JSON-RPC types (minimal subset of the MCP protocol) ---
|
|
|
|
// JSONRPCRequest is a minimal JSON-RPC 2.0 request envelope.
//
// ID and Params are kept as raw JSON: ID is echoed back verbatim in the
// response, and Params is decoded later by the handler for the given Method.
type JSONRPCRequest struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      json.RawMessage `json:"id"`
	Method  string          `json:"method"`
	Params  json.RawMessage `json:"params,omitempty"`
}
|
|
|
|
// JSONRPCResponse is a minimal JSON-RPC 2.0 response envelope.
|
|
type JSONRPCResponse struct {
|
|
JSONRPC string `json:"jsonrpc"`
|
|
ID json.RawMessage `json:"id"`
|
|
Result any `json:"result,omitempty"`
|
|
Error *RPCError `json:"error,omitempty"`
|
|
}
|
|
|
|
// RPCError represents a JSON-RPC 2.0 error object.
// Code is either a JSON-RPC error code (such as -32601) or an
// application-level code like errCodeForbidden; Message is a short
// human-readable description.
type RPCError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
|
|
|
|
// Application-level JSON-RPC error codes (outside the -32768..-32000 reserved range).
const (
	// errCodeForbidden is returned when the token is valid but lacks the
	// scope required by the requested tool.
	errCodeForbidden = -31403 // insufficient scope
)
|
|
|
|
func main() {
|
|
// --- Setup: self-signed key pair for demonstration ---
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
verifier := signer.Verifier()
|
|
|
|
// The validator checks standard access token claims per RFC 9068.
|
|
// In production, iss and aud would match your authorization server
|
|
// and MCP server resource identifier.
|
|
validator := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"}, // expected issuers
|
|
[]string{"https://mcp.example.com/jsonrpc"}, // expected audiences
|
|
0, // grace period (0 = default 2s)
|
|
)
|
|
|
|
// --- Auth middleware: verify signature + validate claims ---
|
|
authMiddleware := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.Header.Get("Authorization")
|
|
if !strings.HasPrefix(auth, "Bearer ") {
|
|
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
tokenStr := strings.TrimPrefix(auth, "Bearer ")
|
|
|
|
jws, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
http.Error(w, "bad token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if err := verifier.Verify(jws); err != nil {
|
|
http.Error(w, "invalid signature", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
var claims jwt.TokenClaims
|
|
if err := jws.UnmarshalClaims(&claims); err != nil {
|
|
http.Error(w, "bad claims", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if err := validator.Validate(nil, &claims, time.Now()); err != nil {
|
|
http.Error(w, "invalid claims: "+err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r.WithContext(WithClaims(r.Context(), &claims)))
|
|
})
|
|
}
|
|
|
|
// --- MCP JSON-RPC handler ---
|
|
mcpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := ClaimsFromContext(r.Context())
|
|
if !ok {
|
|
http.Error(w, "no claims in context", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var req JSONRPCRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeRPCError(w, []byte("null"), -32700, "parse error")
|
|
return
|
|
}
|
|
if req.JSONRPC != "2.0" {
|
|
writeRPCError(w, req.ID, -32600, "invalid request: expected jsonrpc 2.0")
|
|
return
|
|
}
|
|
|
|
switch req.Method {
|
|
case "tools/list":
|
|
handleToolsList(w, req, claims)
|
|
case "tools/call":
|
|
handleToolsCall(w, req, claims)
|
|
default:
|
|
writeRPCError(w, req.ID, -32601, fmt.Sprintf("method not found: %s", req.Method))
|
|
}
|
|
})
|
|
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/mcp", authMiddleware(mcpHandler))
|
|
|
|
// --- Mint demo tokens at each scope level ---
|
|
now := time.Now()
|
|
scopes := []struct {
|
|
label string
|
|
scope jwt.SpaceDelimited
|
|
}{
|
|
{"read-only", jwt.SpaceDelimited{scopeRead}},
|
|
{"read-write", jwt.SpaceDelimited{scopeRead, scopeWrite}},
|
|
{"admin", jwt.SpaceDelimited{scopeRead, scopeWrite, scopeAdmin}},
|
|
}
|
|
|
|
fmt.Println("MCP server listening on :8080")
|
|
fmt.Println()
|
|
for _, s := range scopes {
|
|
token, err := signer.SignToString(&jwt.TokenClaims{
|
|
Iss: "https://auth.example.com",
|
|
Sub: "client-agent-1",
|
|
Aud: jwt.Listish{"https://mcp.example.com/jsonrpc"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
JTI: fmt.Sprintf("tok-%s", s.label),
|
|
ClientID: "mcp-client-demo",
|
|
Scope: s.scope,
|
|
})
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
fmt.Printf("--- %s token ---\n", s.label)
|
|
fmt.Printf("curl -s -X POST http://localhost:8080/mcp \\\n")
|
|
fmt.Printf(" -H 'Authorization: Bearer %s' \\\n", token)
|
|
fmt.Printf(" -H 'Content-Type: application/json' \\\n")
|
|
fmt.Printf(" -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/list\"}'\n\n")
|
|
}
|
|
|
|
log.Fatal(http.ListenAndServe(":8080", mux))
|
|
}
|
|
|
|
// handleToolsList returns the tools visible to the caller based on scopes.
|
|
func handleToolsList(w http.ResponseWriter, req JSONRPCRequest, claims *jwt.TokenClaims) {
|
|
var visible []toolDef
|
|
for _, td := range toolRegistry {
|
|
if hasScope(claims, td.Scope) {
|
|
visible = append(visible, td)
|
|
}
|
|
}
|
|
|
|
writeRPCResult(w, req.ID, map[string]any{
|
|
"tools": visible,
|
|
})
|
|
}
|
|
|
|
// CallParams holds the parameters for a tools/call request.
// Only the tool name is decoded; any other fields in params are ignored.
type CallParams struct {
	Name string `json:"name"`
}
|
|
|
|
// handleToolsCall executes a tool if the caller has the required scope.
|
|
func handleToolsCall(w http.ResponseWriter, req JSONRPCRequest, claims *jwt.TokenClaims) {
|
|
var params CallParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
writeRPCError(w, req.ID, -32602, "invalid params")
|
|
return
|
|
}
|
|
|
|
// Look up the tool in the single registry.
|
|
var found *toolDef
|
|
for i := range toolRegistry {
|
|
if toolRegistry[i].Name == params.Name {
|
|
found = &toolRegistry[i]
|
|
break
|
|
}
|
|
}
|
|
if found == nil {
|
|
writeRPCError(w, req.ID, -32602, fmt.Sprintf("unknown tool: %s", params.Name))
|
|
return
|
|
}
|
|
if !hasScope(claims, found.Scope) {
|
|
writeRPCError(w, req.ID, errCodeForbidden, fmt.Sprintf("insufficient scope: %s required", found.Scope))
|
|
return
|
|
}
|
|
|
|
writeRPCResult(w, req.ID, map[string]any{
|
|
"content": []map[string]string{
|
|
{"type": "text", "text": fmt.Sprintf("executed %s for %s", params.Name, claims.Sub)},
|
|
},
|
|
})
|
|
}
|
|
|
|
// hasScope checks whether the token's scope claim contains the given value.
|
|
func hasScope(claims *jwt.TokenClaims, scope string) bool {
|
|
return slices.Contains(claims.Scope, scope)
|
|
}
|
|
|
|
func writeRPCResult(w http.ResponseWriter, id json.RawMessage, result any) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: id,
|
|
Result: result,
|
|
})
|
|
}
|
|
|
|
func writeRPCError(w http.ResponseWriter, id json.RawMessage, code int, message string) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(JSONRPCResponse{
|
|
JSONRPC: "2.0",
|
|
ID: id,
|
|
Error: &RPCError{Code: code, Message: message},
|
|
})
|
|
}
|