mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-29 13:13:57 +00:00
336 lines
11 KiB
Go
336 lines
11 KiB
Go
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
|
//
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package keyfetch
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/therootcompany/golib/auth/jwt"
|
|
)
|
|
|
|
// ErrFetchFailed reports a network, HTTP-status, or parsing failure while
// fetching a JWKS or discovery document. The wrapping error's message carries
// the specifics (status code, network error, parse error, and so on).
var ErrFetchFailed = errors.New("fetch failed")

// ErrKeysExpired reports that the cached keys are past their hard expiry.
// It is returned together with those expired keys when RefreshTimeout fires
// before the refresh completes. Check for it with [errors.Is].
var ErrKeysExpired = errors.New("cached keys expired")

// ErrEmptyKeySet reports a JWKS document whose key list is empty.
var ErrEmptyKeySet = errors.New("empty key set")
|
|
|
|
// maxResponseBody caps the number of response-body bytes accepted from a
// remote endpoint: 1 MiB. Even a JWKS carrying dozens of keys is normally far
// below 100 KiB, so this leaves ample headroom.
const maxResponseBody = 1024 * 1024
|
|
|
|
// Fallback cache-policy values. Each one applies when the matching
// [KeyFetcher] field is left at its zero value.
const (
	// defaultMinTTL is the floor; server lifetimes shorter than this are raised.
	defaultMinTTL = time.Minute
	// defaultMaxTTL is the ceiling; server lifetimes longer than this are clamped.
	defaultMaxTTL = 24 * time.Hour
	// defaultTTL applies when the response carries no cache headers.
	defaultTTL = 15 * time.Minute
)
|
|
|
|
// defaultTimeout is the overall request timeout given to the fallback HTTP
// client that is created when the caller supplies a nil *http.Client.
const defaultTimeout = time.Second * 30
|
|
|
|
// asset captures the result of a single fetch: the raw body plus the cache
// deadlines and validators derived from the response headers.
type asset struct {
	// data is the response body exactly as received.
	data []byte

	// expiry is the hard deadline after which data must not be used;
	// stale is the earlier instant at which a background refresh should start.
	expiry, stale time.Time

	// etag (opaque) and lastModified (date-based) are validators sent back on
	// the next request so the server can answer with a conditional 304.
	etag, lastModified string
}
|
|
|
|
// fetchRaw retrieves raw bytes from a URL using the given HTTP client.
|
|
// If prev is non-nil, conditional request headers (If-None-Match,
|
|
// If-Modified-Since) are sent; a 304 response refreshes the cache
|
|
// timing on prev and returns it without re-downloading the body.
|
|
// If client is nil, a default client with a 30s timeout is used.
|
|
//
|
|
// The returned *http.Response has its Body consumed and closed; headers
|
|
// remain accessible.
|
|
func fetchRaw(ctx context.Context, url string, client *http.Client, p cachePolicy, prev *asset) (*asset, *http.Response, error) {
|
|
resp, err := doGET(ctx, url, client, prev)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("fetch %q: %w", url, err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
now := time.Now()
|
|
|
|
// 304 Not Modified - reuse previous body with refreshed cache timing.
|
|
if resp.StatusCode == http.StatusNotModified && prev != nil {
|
|
expiry, stale := cacheTimings(now, resp, p)
|
|
etag := resp.Header.Get("ETag")
|
|
if etag == "" {
|
|
etag = prev.etag
|
|
}
|
|
lastMod := resp.Header.Get("Last-Modified")
|
|
if lastMod == "" {
|
|
lastMod = prev.lastModified
|
|
}
|
|
return &asset{
|
|
data: prev.data,
|
|
expiry: expiry,
|
|
stale: stale,
|
|
etag: etag,
|
|
lastModified: lastMod,
|
|
}, resp, nil
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBody+1))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("fetch %q: read body: %w: %w", url, ErrFetchFailed, err)
|
|
}
|
|
if len(body) > maxResponseBody {
|
|
return nil, nil, fmt.Errorf("fetch %q: response exceeds %d byte limit: %w", url, maxResponseBody, ErrFetchFailed)
|
|
}
|
|
|
|
expiry, stale := cacheTimings(now, resp, p)
|
|
|
|
return &asset{
|
|
data: body,
|
|
expiry: expiry,
|
|
stale: stale,
|
|
etag: resp.Header.Get("ETag"),
|
|
lastModified: resp.Header.Get("Last-Modified"),
|
|
}, resp, nil
|
|
}
|
|
|
|
// cachePolicy is the resolved set of knobs used to turn server cache headers
// into local cache lifetimes.
type cachePolicy struct {
	minTTL, maxTTL time.Duration // floor and ceiling applied to server-supplied lifetimes
	defaultTTL     time.Duration // lifetime used when the server gives no usable max-age
}
|
|
|
|
// defaultPolicy returns a cachePolicy using the package defaults.
|
|
func defaultPolicy() cachePolicy {
|
|
return cachePolicy{
|
|
minTTL: defaultMinTTL,
|
|
maxTTL: defaultMaxTTL,
|
|
defaultTTL: defaultTTL,
|
|
}
|
|
}
|
|
|
|
// cacheTimings computes expiry and stale times from the response headers.
|
|
// Stale time is always 3/4 of the TTL.
|
|
//
|
|
// Policy:
|
|
// - No usable max-age => defaultTTL (15m), stale at 3/4
|
|
// - max-age < minTTL (1m) => minTTL*2 expiry, minTTL stale
|
|
// - max-age > maxTTL (24h) => clamped to maxTTL, stale at 3/4
|
|
// - Otherwise => server value, stale at 3/4
|
|
func cacheTimings(now time.Time, resp *http.Response, p cachePolicy) (expiry, stale time.Time) {
|
|
serverTTL := parseCacheControlMaxAge(resp.Header.Get("Cache-Control"))
|
|
if age := parseAge(resp.Header.Get("Age")); age > 0 {
|
|
serverTTL -= age
|
|
}
|
|
|
|
var ttl time.Duration
|
|
switch {
|
|
case serverTTL <= 0:
|
|
// No cache headers or max-age=0 or Age consumed it all
|
|
ttl = p.defaultTTL
|
|
case serverTTL < p.minTTL:
|
|
// Server says cache briefly - use floor
|
|
return now.Add(p.minTTL * 2), now.Add(p.minTTL)
|
|
case serverTTL > p.maxTTL:
|
|
ttl = p.maxTTL
|
|
default:
|
|
ttl = serverTTL
|
|
}
|
|
|
|
return now.Add(ttl), now.Add(ttl * 3 / 4)
|
|
}
|
|
|
|
// FetchURL retrieves and parses a JWKS document from the given JWKS endpoint URL.
|
|
//
|
|
// The response body is limited to 1 MiB. If client is nil, a default client
|
|
// with a 30s timeout is used.
|
|
//
|
|
// The returned [*http.Response] has its Body already consumed and closed.
|
|
// Headers such as ETag, Last-Modified, and Cache-Control remain accessible
|
|
// and are used internally by [KeyFetcher] for cache management.
|
|
func FetchURL(ctx context.Context, jwksURL string, client *http.Client) ([]jwt.PublicKey, *http.Response, error) {
|
|
a, resp, err := fetchRaw(ctx, jwksURL, client, defaultPolicy(), nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
keys, err := parseJWKS(a.data)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return keys, resp, nil
|
|
}
|
|
|
|
// fetchJWKS fetches and parses a JWKS document, returning the asset for
|
|
// cache timing and the parsed keys. prev is passed through to fetchRaw
|
|
// for conditional requests.
|
|
func fetchJWKS(ctx context.Context, jwksURL string, client *http.Client, p cachePolicy, prev *asset) (*asset, []jwt.PublicKey, error) {
|
|
a, _, err := fetchRaw(ctx, jwksURL, client, p, prev)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
keys, err := parseJWKS(a.data)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return a, keys, nil
|
|
}
|
|
|
|
// parseJWKS unmarshals a JWKS document into public keys.
|
|
// Returns [ErrEmptyKeySet] if the key set is empty.
|
|
func parseJWKS(data []byte) ([]jwt.PublicKey, error) {
|
|
var jwks jwt.WellKnownJWKs
|
|
if err := json.Unmarshal(data, &jwks); err != nil {
|
|
return nil, fmt.Errorf("parse JWKS: %w: %w", ErrFetchFailed, err)
|
|
}
|
|
if len(jwks.Keys) == 0 {
|
|
return nil, fmt.Errorf("parse JWKS: %w", ErrEmptyKeySet)
|
|
}
|
|
return jwks.Keys, nil
|
|
}
|
|
|
|
// parseCacheControlMaxAge extracts the max-age value from a Cache-Control header.
|
|
// Returns 0 if the header is absent or does not contain a valid max-age directive.
|
|
func parseCacheControlMaxAge(header string) time.Duration {
|
|
for part := range strings.SplitSeq(header, ",") {
|
|
part = strings.TrimSpace(part)
|
|
if val, ok := strings.CutPrefix(part, "max-age="); ok {
|
|
n, err := strconv.Atoi(val)
|
|
if err == nil && n > 0 {
|
|
return time.Duration(n) * time.Second
|
|
}
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// parseAge returns the value of an Age response header as a Duration.
//
// It returns 0 when the header is absent, malformed, or zero. Values too large
// to represent are treated as 2^31 seconds, as RFC 9111 §1.2.2 requires. This
// also keeps the Duration multiplication from overflowing and wrapping to an
// arbitrary amount.
func parseAge(header string) time.Duration {
	// RFC 9111 §1.2.2: overflowing delta-seconds MUST be treated as 2^31.
	const maxDeltaSeconds = 1 << 31

	header = strings.TrimSpace(header)
	if header == "" {
		return 0
	}
	// Age is delta-seconds: unsigned digits only; ParseUint rejects any sign.
	n, err := strconv.ParseUint(header, 10, 64)
	switch {
	case errors.Is(err, strconv.ErrRange):
		n = maxDeltaSeconds
	case err != nil:
		return 0
	}
	return time.Duration(min(n, maxDeltaSeconds)) * time.Second
}
|
|
|
|
// FetchOIDC fetches JWKS via OIDC discovery from the given base URL.
|
|
//
|
|
// It fetches {baseURL}/.well-known/openid-configuration, reads the jwks_uri
|
|
// field, then fetches and parses the JWKS from that URI.
|
|
//
|
|
// client is used for all HTTP requests; if nil, a default 30s-timeout client is used.
|
|
func FetchOIDC(ctx context.Context, baseURL string, client *http.Client) ([]jwt.PublicKey, *http.Response, error) {
|
|
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
|
|
jwksURI, err := fetchDiscoveryURI(ctx, discoveryURL, client)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return FetchURL(ctx, jwksURI, client)
|
|
}
|
|
|
|
// FetchOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata from the
|
|
// given base URL.
|
|
//
|
|
// https://www.rfc-editor.org/rfc/rfc8414.html
|
|
//
|
|
// It fetches {baseURL}/.well-known/oauth-authorization-server, reads the
|
|
// jwks_uri field, then fetches and parses the JWKS from that URI.
|
|
//
|
|
// client is used for all HTTP requests; if nil, a default 30s-timeout client is used.
|
|
func FetchOAuth2(ctx context.Context, baseURL string, client *http.Client) ([]jwt.PublicKey, *http.Response, error) {
|
|
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
|
|
jwksURI, err := fetchDiscoveryURI(ctx, discoveryURL, client)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return FetchURL(ctx, jwksURI, client)
|
|
}
|
|
|
|
// fetchDiscoveryURI fetches a discovery document and returns the validated
|
|
// jwks_uri from it. The URI is required to use HTTPS to prevent SSRF via a
|
|
// malicious discovery document pointing at an internal endpoint.
|
|
func fetchDiscoveryURI(ctx context.Context, discoveryURL string, client *http.Client) (string, error) {
|
|
resp, err := doGET(ctx, discoveryURL, client, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("fetch discovery: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
var doc struct {
|
|
JWKsURI string `json:"jwks_uri"`
|
|
}
|
|
if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBody)).Decode(&doc); err != nil {
|
|
return "", fmt.Errorf("parse discovery doc: %w: %w", ErrFetchFailed, err)
|
|
}
|
|
if doc.JWKsURI == "" {
|
|
return "", fmt.Errorf("discovery doc missing jwks_uri: %w", ErrFetchFailed)
|
|
}
|
|
if !strings.HasPrefix(doc.JWKsURI, "https://") {
|
|
return "", fmt.Errorf("jwks_uri must be https, got %q: %w", doc.JWKsURI, ErrFetchFailed)
|
|
}
|
|
return doc.JWKsURI, nil
|
|
}
|
|
|
|
// doGET performs an HTTP GET request and returns the response. It follows
|
|
// redirects (Go's default of up to 10), handles nil client defaults, and
|
|
// checks the final status code. If prev is non-nil, conditional request
|
|
// headers are sent and a 304 response is allowed. Callers must close
|
|
// resp.Body.
|
|
func doGET(ctx context.Context, url string, client *http.Client, prev *asset) (*http.Response, error) {
|
|
if client == nil {
|
|
client = &http.Client{Timeout: defaultTimeout}
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %w", ErrFetchFailed, err)
|
|
}
|
|
if prev != nil {
|
|
if prev.etag != "" {
|
|
req.Header.Set("If-None-Match", prev.etag)
|
|
}
|
|
if prev.lastModified != "" {
|
|
req.Header.Set("If-Modified-Since", prev.lastModified)
|
|
}
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %w", ErrFetchFailed, err)
|
|
}
|
|
if resp.StatusCode == http.StatusNotModified && prev != nil {
|
|
return resp, nil
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
_ = resp.Body.Close()
|
|
return nil, fmt.Errorf("status %d: %w", resp.StatusCode, ErrFetchFailed)
|
|
}
|
|
return resp, nil
|
|
}
|