mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-29 03:24:07 +00:00
96 lines
2.6 KiB
Go
96 lines
2.6 KiB
Go
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
//
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
// Example cached-keys demonstrates persisting JWKS keys to disk so that a
|
|
// service can start verifying tokens immediately on restart without blocking
|
|
// on a network fetch.
|
|
//
|
|
// On startup, keys are loaded from a local file (if it exists) and passed
|
|
// as InitialKeys. After each Verifier() call, RefreshedAt is checked to
|
|
// detect updates, and keys are saved only when the sorted KIDs differ from
// those already written to disk.
|
|
package main
|
|
|
|
import (
	"errors"
	"fmt"
	"log"
	"os"
	"slices"
	"time"

	"github.com/therootcompany/golib/auth/jwt"
	"github.com/therootcompany/golib/auth/jwt/keyfetch"
	"github.com/therootcompany/golib/auth/jwt/keyfile"
)
|
|
|
|
// Configuration for this example: cacheFile is the local path the JWKS is
// persisted to between runs, and jwksURL is the remote endpoint keys are
// fetched from.
const (
	cacheFile = "jwks-cache.json"
	jwksURL   = "https://accounts.example.com/.well-known/jwks.json"
)
|
|
|
|
func main() {
|
|
// Load cached keys from disk (if any).
|
|
initialKeys, err := loadCachedKeys(cacheFile)
|
|
if err != nil {
|
|
log.Printf("no cached keys: %v", err)
|
|
}
|
|
|
|
fetcher := &keyfetch.KeyFetcher{
|
|
URL: jwksURL,
|
|
RefreshTimeout: 10 * time.Second,
|
|
InitialKeys: initialKeys,
|
|
}
|
|
|
|
// Track when we last saved so we can detect refreshes.
|
|
cachedKIDs := sortedKIDs(initialKeys)
|
|
lastSaved := time.Time{}
|
|
|
|
verifier, err := fetcher.Verifier()
|
|
if err != nil {
|
|
log.Fatalf("failed to get verifier: %v", err)
|
|
}
|
|
|
|
// Save if keys were refreshed and KIDs changed.
|
|
if fetcher.RefreshedAt().After(lastSaved) {
|
|
lastSaved = fetcher.RefreshedAt()
|
|
kids := sortedKIDs(verifier.PublicKeys())
|
|
if !slices.Equal(kids, cachedKIDs) {
|
|
if err := keyfile.SavePublicJWKs(cacheFile, verifier.PublicKeys()); err != nil {
|
|
log.Printf("save cached keys: %v", err)
|
|
} else {
|
|
cachedKIDs = kids
|
|
log.Printf("saved %d keys to %s", len(verifier.PublicKeys()), cacheFile)
|
|
}
|
|
}
|
|
}
|
|
|
|
fmt.Printf("verifier ready with %d keys\n", len(verifier.PublicKeys()))
|
|
}
|
|
|
|
// loadCachedKeys reads the JWKS file at path and returns its keys.
//
// A missing file is not an error: it returns (nil, nil) so a first run simply
// starts without initial keys. Any other failure from keyfile is returned
// unchanged.
func loadCachedKeys(path string) ([]jwt.PublicKey, error) {
	jwks, err := keyfile.LoadWellKnownJWKs(path)
	switch {
	// errors.Is (unlike os.IsNotExist) unwraps, so a not-exist error is still
	// recognized if keyfile wraps the underlying open error with %w.
	case errors.Is(err, os.ErrNotExist):
		return nil, nil
	case err != nil:
		return nil, err
	}
	return jwks.Keys, nil
}
|
|
|
|
// sortedKIDs collects the KID of every key and returns them in ascending
// order. The result is always a non-nil slice, even when keys is empty.
func sortedKIDs(keys []jwt.PublicKey) []string {
	out := make([]string, 0, len(keys))
	for i := range keys {
		out = append(out, keys[i].KID)
	}
	slices.Sort(out)
	return out
}
|