ref!(auth/jwt): full modern rewrite

This commit is contained in:
AJ ONeal 2026-03-17 07:14:17 -06:00
parent 117ed8cc9b
commit 26bdc0a3db
No known key found for this signature in database
19 changed files with 7867 additions and 710 deletions

98
auth/jwt/claims.go Normal file
View File

@ -0,0 +1,98 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
// TokenClaims holds the standard JWT and OIDC claims: the RFC 7519
// registered claim names (iss, sub, aud, exp, nbf, iat, jti), the OIDC-specific
// authentication event fields (auth_time, nonce, amr, azp), and OAuth 2.1
// access token fields (client_id, scope).
//
// For OIDC UserInfo profile fields (name, email, phone, locale, etc.),
// use [StandardClaims] instead - it embeds TokenClaims and adds §5.1 fields.
//
// https://www.rfc-editor.org/rfc/rfc7519.html
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
// https://www.rfc-editor.org/rfc/rfc9068.html#section-2.2
//
// Embed TokenClaims or StandardClaims in your own claims struct to
// satisfy [Claims] for free via Go's method promotion - zero boilerplate:
//
// type AppClaims struct {
// jwt.TokenClaims // promotes GetTokenClaims()
// RoleList string `json:"roles"`
// }
// // AppClaims now satisfies Claims automatically.
type TokenClaims struct {
Iss string `json:"iss"` // Issuer (a.k.a. Provider ID) - the auth provider's identifier
Sub string `json:"sub"` // Subject (a.k.a. Account ID) - pairwise id between provider and account
Aud Listish `json:"aud,omitzero"` // Audience (a.k.a. Service Provider) - the intended token recipient
Exp int64 `json:"exp"` // Expiration - the token is not valid after this Unix time
NBf int64 `json:"nbf,omitempty"` // Not Before - the token is not valid until this Unix time
IAt int64 `json:"iat"` // Issued At - when the token was signed
JTI string `json:"jti,omitempty"` // JSON Web Token ID - unique identifier for replay/revocation
AuthTime int64 `json:"auth_time,omitempty"` // Authentication Time - when the end-user last authenticated
Nonce string `json:"nonce,omitempty"` // Nonce - ties an ID Token to a specific auth request
AMR []string `json:"amr,omitempty"` // Authentication Method Reference - how the account was signed in
AzP string `json:"azp,omitempty"` // Authorized Party (a.k.a. Relying Party) - the intended token consumer
ClientID string `json:"client_id,omitempty"` // Client ID - the OAuth client that requested the token
Scope SpaceDelimited `json:"scope,omitzero"` // Scope - granted OAuth 2.1 scopes
}
// GetTokenClaims implements [Claims].
// Any struct embedding TokenClaims gets this method for free via promotion.
func (tc *TokenClaims) GetTokenClaims() *TokenClaims { return tc }
// Claims is implemented for free by any struct that embeds [TokenClaims].
//
// type AppClaims struct {
// jwt.TokenClaims // promotes GetTokenClaims() - zero boilerplate
// RoleList string `json:"roles"`
// }
type Claims interface {
GetTokenClaims() *TokenClaims
}
// StandardClaims embeds [TokenClaims] and adds the OIDC Core §5.1
// UserInfo standard profile claims. Embed StandardClaims in your own type to
// get all fields with zero boilerplate:
//
// type AppClaims struct {
// jwt.StandardClaims // promotes GetTokenClaims()
// Roles []string `json:"roles"`
// }
//
// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
type StandardClaims struct {
TokenClaims // promotes GetTokenClaims() - satisfies Claims automatically
// Profile fields (OIDC Core §5.1)
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Profile string `json:"profile,omitempty"` // URL of end-user's profile page
Picture string `json:"picture,omitempty"` // URL of end-user's profile picture
Website string `json:"website,omitempty"` // URL of end-user's web page
// Contact fields
Email string `json:"email,omitempty"`
EmailVerified NullBool `json:"email_verified,omitzero"`
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified NullBool `json:"phone_number_verified,omitzero"`
// Locale / time fields
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"` // YYYY, YYYY-MM, or YYYY-MM-DD (§5.1)
Zoneinfo string `json:"zoneinfo,omitempty"` // IANA tz, e.g. "Europe/Paris"
Locale string `json:"locale,omitempty"` // BCP 47, e.g. "en-US"
UpdatedAt int64 `json:"updated_at,omitempty"` // seconds since Unix epoch
}

2851
auth/jwt/coverage_test.go Normal file

File diff suppressed because it is too large Load Diff

158
auth/jwt/doc.go Normal file
View File

@ -0,0 +1,158 @@
// Package jwt is a lightweight JWT/JWS/JWK library for JOSE, OIDC, and
// OAuth 2.1, designed from first principles for modern Go (1.26+)
// and current standards (OIDC Core 1.0 errata set 2, MCP).
//
// High convenience. Low boilerplate. Easy to customize. Focused:
//
// - You're either building an Issuer (sign JWTs) or Relying Party (verifies and validates JWTs)
// - You're implementing part of JOSE, OIDC or OAuth2 and may have a /jwks.json endpoint
// - You probably do a little of all sides
// - You want type-safe keys (but you don't want to have to type-switch on them)
// - You almost always need custom Claims (token Payload)
// - You almost never need a custom header (but [Header] / [RFCHeader] make it easy)
// - You may also be implementing MCP support for Ai / Agents
//
// Rather than implementing to the spec article by article, this library implements by flow.
//
// This was created with Ai assistance to be able to iterate quickly over different design choices, but every line of the code has been manually reviewed for correctness, as well as many of the tests.
//
// # Design choices
//
// Convenience is not convenient if it gets in your way. This is a library, not
// a framework: it gives you composable pieces you call and control, not
// scaffolding you must conform to.
//
// - Sane defaults for everything, without hiding anything you may need to inspect.
// - There should be one obvious right way to do it.
// - Claims are the most important builder-facing detail.
// - Use simple embedding for maximum convenience without sacrificing optionality.
// - [TokenClaims] for minimal auth info, [StandardClaims] for typical user info.
// (both satisfy [Claims] for free via Go method promotion)
// - [RawJWT.UnmarshalClaims] to get your custom type-safe claims effortlessly.
// - [Validator] for typical auth validation - strict by default, permissive when configured
// (or bring your own, or ignore it and do it how you like)
// Use [NewIDTokenValidator] or [NewAccessTokenValidator] for sensible defaults.
// A zero-value Validator returns [ErrMisconfigured] - always use a constructor.
// - [RFCHeader] is always used in the standard way, and tightly coupled to signing and
// verification - it stays fully customizable as part of the JWT interfaces
// (embedding [RawJWT] and [RFCHeader] make it easy to satisfy [VerifiableJWT] or [SignableJWT])
// - Accessible error details (so that you don't have to round trip just to get the next one)
//
// Key takeaway: Your claims are your own. You can take what you get for free, or add what you need at no cost to you.
//
// # Use case: Issuer (& Relying Party)
//
// You're building the thing that has the Private Keys, signs the tokens + verifies tokens and validates claims.
// - create a [NewSigner] with the private keys
// - use json.Marshal(&signer.WellKnownJWKs) to publish a /jwks.json endpoint
// - use [Signer.SignToString] + [TokenClaims] or [StandardClaims] to create a token string
// (or [Signer.Sign] + [Encode] for the signed JWT object)
// - use [Signer.Verifier] to verify the JWT (bearer token)
// - use [RawJWT.UnmarshalClaims] to get your user info
// - use [Validator.Validate] to validate the claims (user info payload)
// - use custom validation for your own Claims type, or by hand - dealer's choice
//
// # Use case: Relying Party
//
// You're building a thing that uses Public Keys to verify and validate tokens.
// - you may already know the public keys (and redeploy when they change)
// - or you fetch them at runtime from a /jwks.json endpoint (and cache and update periodically)
// - Relying party, known keys: use [NewVerifier] with a []PublicKey slice.
// - Relying party, remote keys: use keyfetch.KeyFetcher to cache and lazy-refresh keys.
// - use [Verifier.VerifyJWT] to decode and verify in one call (or [Decode] + [Verifier.Verify] for two-step)
// - use [RawJWT.UnmarshalClaims] to get your user info
// - use [Validator.Validate] to validate the claims (user info payload)
// - use custom validation for your own Claims type, or by hand - dealer's choice
//
// # Use case: MCP / Agents
//
// An MCP Host (the AI application) is a Relying Party to the MCP Server.
// The MCP Server may be an Issuer - minting tokens specifically for Agents
// to call your API - or it may be a Relying Party to your main auth system,
// forwarding tokens it received from an upstream Issuer.
//
// In either case the same building blocks apply: the Host verifies and
// validates tokens from the Server, and the Server either signs its own
// tokens ([NewSigner]) or verifies tokens from your auth provider
// ([NewVerifier] or keyfetch.KeyFetcher).
//
// # OAuth 2.1 Access Tokens
//
// For APIs that accept OAuth 2.1 access tokens (typ: "at+jwt", RFC 9068),
// use [NewAccessTokenValidator] with [TokenClaims] (which includes the
// client_id and scope fields):
//
// v := jwt.NewAccessTokenValidator(issuers, audiences, relyingParties)
// if err := v.Validate(nil, &claims, time.Now()); err != nil { /* ... */ }
//
// - [NewAccessToken] creates a JWS with the correct "at+jwt" typ header
// - [NewIDTokenValidator] creates a validator for OIDC ID tokens
// - [SpaceDelimited] is a slice that marshals as a space-separated string in JSON,
// with trinary semantics: nil (absent/omitzero), empty non-nil (present as ""),
// or populated ("openid profile")
//
// # Loading keys from files
//
// The keyfile package loads cryptographic keys from local files in JWK,
// PEM, or DER format. All functions auto-compute KID from the RFC 7638
// thumbprint when not already set:
//
// - keyfile.LoadPrivatePEM / keyfile.LoadPublicPEM for PEM files
// - keyfile.LoadPrivateDER / keyfile.LoadPublicDER for DER files
// - keyfile.LoadPublicJWK / keyfile.LoadPrivateJWK / keyfile.LoadWellKnownJWKs for JWK/JWKS files
//
// For fetching keys from remote URLs, use keyfetch.FetchURL (JWKS endpoints)
// or keyfetch.FetchOIDC (OIDC discovery).
//
// # Security
//
// You don't need to be a crypto expert to use this library - but if you are, hopefully
// you find it to be the best you've ever used.
//
// 1. YAGNI: Don't implement what you don't need = less surface area = greater security.
//
// The researchers who write specifications are notorious for imagining every
// hypothetical - which has resulted in numerous security flaws over the years.
// There's nothing in here that I haven't seen in the wild and found useful.
// And I'm happy to extend if needed.
//
// 2. Verify AND Validate
//
// As an Issuer (owner) you [Signer.Sign] and then [Encode].
//
// As a Relying Party (client) you [Decode], [Verifier.Verify] and [Validator.Validate].
//
// Why not a single step? Because Claims (sometimes called "User" in other libs) is the thing
// you actually care about, and actually want type safety for. After trying various approaches
// with embedding and generics, what I landed on is that the most ergonomic type-safe way
// to Verify a JWT and Validate Claims is to have the two be separate operations.
//
// It's why you get to use this library as a library and how you get to have all of the
// convenience without sacrificing control and customization of the thing you're most likely
// to want to be able to customize (and debug).
//
// 3. Algorithms: The fewer the merrier.
//
// Only asymmetric (public-key) algorithms are implemented.
//
// You should use Ed25519. It's the end-game algorithm - all upside, no known
// downsides, and it's supported ubiquitously - Go, JavaScript, Web Browsers, Node, Rust,
// etc.
//
// Ed25519 is the recommended algorithm.
// ECDSA is provided for backwards compatibility with existing systems.
// RSA is provided only for backwards compatibility - it's larger, slower, with no real benefit.
//
// - EC P-256 => ES256 (ECDSA + SHA-256, RFC 7518 §3.4)
// - EC P-384 => ES384 (ECDSA + SHA-384)
// - EC P-521 => ES512 (ECDSA + SHA-512)
// - RSA => RS256 (PKCS#1 v1.5 + SHA-256, RFC 7518 §3.3)
// - Ed25519 => EdDSA (RFC 8037)
//
// Supported algorithms are derived automatically from the key type - you never
// configure alg directly.
//
// The verification process selects a key by matching the "kid" (KeyID) of token
// and the key and then checking "alg" before any cryptographic operation is attempted.
// An alg/key-type mismatch is a hard error.
package jwt

87
auth/jwt/errors.go Normal file
View File

@ -0,0 +1,87 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"errors"
"fmt"
)
// Sentinel errors for decode, signature, key, signing, and verification.
//
// These are the operational errors returned by [Decode], [Verifier.Verify],
// [Signer.SignJWT], [NewSigner], and related functions. For claim validation
// errors, see the Err* variables below.
var (
// Decode errors - returned by [Decode] and [RawJWT.UnmarshalClaims]
// when the compact token or its components are malformed.
ErrMalformedToken = errors.New("malformed token")
ErrInvalidHeader = errors.New("invalid header")
ErrInvalidPayload = errors.New("invalid payload")
// Signature and algorithm errors - returned during signing and
// verification when the signature, algorithm, or key type is wrong.
ErrSignatureInvalid = errors.New("signature invalid")
ErrUnsupportedAlg = errors.New("unsupported algorithm")
ErrAlgConflict = errors.New("algorithm conflict")
ErrUnsupportedKeyType = errors.New("unsupported key type")
ErrUnsupportedCurve = errors.New("unsupported curve")
// Key errors - returned when key material is invalid or insufficient.
ErrInvalidKey = errors.New("invalid key")
ErrKeyTooSmall = fmt.Errorf("%w: key too small", ErrInvalidKey)
ErrMissingKeyData = fmt.Errorf("%w: missing key data", ErrInvalidKey)
ErrUnsupportedFormat = errors.New("unsupported format")
// Verification errors - returned by [Verifier.Verify] and
// [Signer.SignJWT] when no key matches the token's kid.
ErrUnknownKID = errors.New("unknown kid")
ErrNoVerificationKey = errors.New("no verification keys")
// Signing errors - returned by [NewSigner] and [Signer.SignJWT].
ErrNoSigningKey = errors.New("no signing key")
// Sanity errors - internal invariant violations that should never
// happen given the library's own validation.
ErrSanityFail = errors.New("something impossible happened")
)
// Sentinel errors for claim validation.
//
// [Validator.Validate] returns all failures at once via [errors.Join].
// Check for specific issues with [errors.Is]:
//
// err := v.Validate(nil, &claims, time.Now())
// if errors.Is(err, jwt.ErrAfterExp) { /* token expired */ }
// if errors.Is(err, jwt.ErrInvalidClaim) { /* any value error */ }
//
// The time-based sentinels (ErrAfterExp, ErrBeforeNBf, etc.) wrap
// ErrInvalidClaim, so a single errors.Is(err, ErrInvalidClaim) check
// catches all value errors.
//
// Use [ValidationErrors] to extract structured [*ValidationError] values
// for API responses, or [GetOAuth2Error] for OAuth 2.0 error responses.
var (
// Claim-level errors.
ErrMissingClaim = errors.New("missing required claim")
ErrInvalidClaim = errors.New("invalid claim value")
ErrInvalidTyp = errors.New("invalid typ header")
ErrInsufficientScope = errors.New("insufficient scope")
// Time-based claim errors - each wraps ErrInvalidClaim.
ErrAfterExp = fmt.Errorf("%w: exp: token expired", ErrInvalidClaim)
ErrBeforeNBf = fmt.Errorf("%w: nbf: token not yet valid", ErrInvalidClaim)
ErrBeforeIAt = fmt.Errorf("%w: iat: issued in the future", ErrInvalidClaim)
ErrBeforeAuthTime = fmt.Errorf("%w: auth_time: in the future", ErrInvalidClaim)
ErrAfterAuthMaxAge = fmt.Errorf("%w: auth_time: exceeds max age", ErrInvalidClaim)
// Server-side misconfiguration - the validator itself is invalid.
// Callers should treat this as a 500 (server error), not 401 (unauthorized).
ErrMisconfigured = errors.New("validator misconfigured")
)

View File

@ -1,3 +1,3 @@
module github.com/therootcompany/golib/auth/jwt
go 1.24.6
go 1.26.1

151
auth/jwt/jwa.go Normal file
View File

@ -0,0 +1,151 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
_ "crypto/sha256" // register SHA-256 with crypto.Hash
_ "crypto/sha512" // register SHA-384 and SHA-512 with crypto.Hash
"encoding/asn1"
"fmt"
"math/big"
)
// curveInfo holds the JWK/JWS identifiers and parameters for an EC curve.
type curveInfo struct {
Curve elliptic.Curve // Go curve object
Crv string // JWK "crv" value: "P-256", "P-384", "P-521"
Alg string // JWS algorithm: "ES256", "ES384", "ES512"
Hash crypto.Hash // signing hash: SHA-256, SHA-384, SHA-512
KeySize int // coordinate byte length: (BitSize+7)/8
}
// Canonical curveInfo values - one var per supported curve.
var (
p256 = curveInfo{elliptic.P256(), "P-256", "ES256", crypto.SHA256, 32}
p384 = curveInfo{elliptic.P384(), "P-384", "ES384", crypto.SHA384, 48}
p521 = curveInfo{elliptic.P521(), "P-521", "ES512", crypto.SHA512, 66}
)
// ecInfoForAlg returns the curveInfo for the given elliptic curve and validates
// that the curve's algorithm matches expectedAlg. This is the verification-side
// check: the key's curve must produce the algorithm the token claims.
func ecInfoForAlg(curve elliptic.Curve, expectedAlg string) (curveInfo, error) {
ci, err := ecInfo(curve)
if err != nil {
return ci, err
}
if ci.Alg != expectedAlg {
return curveInfo{}, fmt.Errorf("key curve %s vs token alg %s: %w", ci.Alg, expectedAlg, ErrAlgConflict)
}
return ci, nil
}
// ecInfo returns the curveInfo for the given elliptic curve.
func ecInfo(curve elliptic.Curve) (curveInfo, error) {
switch curve {
case elliptic.P256():
return p256, nil
case elliptic.P384():
return p384, nil
case elliptic.P521():
return p521, nil
default:
return curveInfo{}, fmt.Errorf("EC curve %s: %w", curve.Params().Name, ErrUnsupportedCurve)
}
}
// ecInfoByCrv returns the curveInfo for a JWK "crv" string.
func ecInfoByCrv(crv string) (curveInfo, error) {
switch crv {
case "P-256":
return p256, nil
case "P-384":
return p384, nil
case "P-521":
return p521, nil
default:
return curveInfo{}, fmt.Errorf("EC crv %q: %w", crv, ErrUnsupportedCurve)
}
}
// signingParams determines the JWS signing parameters for a crypto.Signer.
//
// It type-switches on s.Public() (not on s directly) so that non-standard
// crypto.Signer implementations (KMS, HSM) work as long as they expose a
// standard public key type.
//
// Returns:
// - alg: JWS algorithm string (ES256, ES384, ES512, RS256, EdDSA)
// - hash: crypto.Hash for pre-hashing; 0 for Ed25519 (sign raw message)
// - ecKeySize: ECDSA coordinate byte length; >0 signals that the
// signature needs ASN.1 DER to IEEE P1363 conversion
func signingParams(s crypto.Signer) (alg string, hash crypto.Hash, ecKeySize int, err error) {
switch pub := s.Public().(type) {
case *ecdsa.PublicKey:
ci, err := ecInfo(pub.Curve)
if err != nil {
return "", 0, 0, err
}
return ci.Alg, ci.Hash, ci.KeySize, nil
case *rsa.PublicKey:
return "RS256", crypto.SHA256, 0, nil
case ed25519.PublicKey:
return "EdDSA", 0, 0, nil
default:
return "", 0, 0, fmt.Errorf("%T: %w", pub, ErrUnsupportedKeyType)
}
}
// signingInputBytes builds the protected.payload byte slice used as the signing input.
func signingInputBytes(protected, payload []byte) []byte {
out := make([]byte, 0, len(protected)+1+len(payload))
out = append(out, protected...)
out = append(out, '.')
out = append(out, payload...)
return out
}
// digestFor hashes data with the given crypto.Hash.
func digestFor(h crypto.Hash, data []byte) ([]byte, error) {
if !h.Available() {
return nil, fmt.Errorf("hash %v: %w", h, ErrUnsupportedAlg)
}
hh := h.New()
hh.Write(data)
return hh.Sum(nil), nil
}
// ecdsaDERToP1363 converts an ASN.1 DER-encoded ECDSA signature to
// the fixed-width IEEE P1363 format used by JWS.
func ecdsaDERToP1363(der []byte, keySize int) ([]byte, error) {
var sig struct{ R, S *big.Int }
rest, err := asn1.Unmarshal(der, &sig)
if err != nil {
return nil, err
}
if len(rest) > 0 {
return nil, fmt.Errorf("%d trailing ASN.1 bytes: %w", len(rest), ErrSignatureInvalid)
}
// Validate that R and S fit in keySize bytes before FillBytes.
rLen := (sig.R.BitLen() + 7) / 8
sLen := (sig.S.BitLen() + 7) / 8
if rLen > keySize || sLen > keySize {
return nil, fmt.Errorf("R (%d bytes) or S (%d bytes) exceeds key size %d: %w",
rLen, sLen, keySize, ErrSignatureInvalid)
}
out := make([]byte, 2*keySize)
sig.R.FillBytes(out[:keySize])
sig.S.FillBytes(out[keySize:])
return out, nil
}

716
auth/jwt/jwk.go Normal file
View File

@ -0,0 +1,716 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
)
// CryptoPublicKey is the constraint for public key types stored in [PublicKey].
//
// All standard Go public key types (*ecdsa.PublicKey, *rsa.PublicKey,
// ed25519.PublicKey) implement this interface per the Go standard library
// recommendation.
type CryptoPublicKey interface {
Equal(x crypto.PublicKey) bool
}
// PublicKey wraps a parsed public key with its JWKS metadata.
//
// PublicKey is the in-memory representation of a JWK.
// [PublicKey.KeyType] returns the JWK kty string ("EC", "RSA", or "OKP").
// To access the raw Go key, type-switch on Key:
//
// switch key := pk.Key.(type) {
// case *ecdsa.PublicKey: // ...
// case *rsa.PublicKey: // ...
// case ed25519.PublicKey: // ...
// }
//
// For signing keys, use [PrivateKey] instead - it holds the [crypto.Signer]
// and derives a PublicKey on demand.
type PublicKey struct {
Key CryptoPublicKey
KID string
Use string
Alg string
KeyOps []string
}
// KeyType returns the JWK "kty" string for the key: "EC", "RSA", or "OKP".
// Returns "" if the key type is unrecognized.
//
// To access the underlying Go key, use a type switch on Key:
//
// switch key := k.Key.(type) {
// case *ecdsa.PublicKey: // kty "EC"
// // key is *ecdsa.PublicKey
// case *rsa.PublicKey: // kty "RSA"
// // key is *rsa.PublicKey
// case ed25519.PublicKey: // kty "OKP"
// // key is ed25519.PublicKey
// default:
// // unrecognized key type
// }
func (k PublicKey) KeyType() string {
switch k.Key.(type) {
case *ecdsa.PublicKey:
return "EC"
case *rsa.PublicKey:
return "RSA"
case ed25519.PublicKey:
return "OKP"
default:
return ""
}
}
// MarshalJSON implements [json.Marshaler], encoding the key as a JWK JSON object.
// Private key fields are never included.
func (k PublicKey) MarshalJSON() ([]byte, error) {
pk, err := encode(k)
if err != nil {
return nil, err
}
return json.Marshal(pk)
}
// UnmarshalJSON implements [json.Unmarshaler], parsing a JWK JSON object.
// Private key fields (d, p, q, etc.) are silently ignored.
// If the JWK has no "kid" field, the KID is auto-computed via [PublicKey.Thumbprint].
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var kj rawKey
if err := json.Unmarshal(data, &kj); err != nil {
return fmt.Errorf("parse JWK: %w", err)
}
decoded, err := decodeOne(kj)
if err != nil {
return err
}
*k = *decoded
return nil
}
// Thumbprint computes the RFC 7638 JWK Thumbprint (SHA-256 of the canonical
// key JSON with fields in lexicographic order). The result is base64url-encoded.
//
// https://www.rfc-editor.org/rfc/rfc7638.html
//
// Canonical forms per RFC 7638:
// - EC: {"crv":..., "kty":"EC", "x":..., "y":...}
// - RSA: {"e":..., "kty":"RSA", "n":...}
// - OKP: {"crv":"Ed25519", "kty":"OKP", "x":...}
func (k PublicKey) Thumbprint() (string, error) {
rk, err := encode(k)
if err != nil {
return "", err
}
// Build canonical JSON with fields in lexicographic order per RFC 7638.
var canonical []byte
switch rk.Kty {
case "EC":
canonical, err = json.Marshal(struct {
Crv string `json:"crv"`
Kty string `json:"kty"`
X string `json:"x"`
Y string `json:"y"`
}{Crv: rk.Crv, Kty: rk.Kty, X: rk.X, Y: rk.Y})
case "RSA":
canonical, err = json.Marshal(struct {
E string `json:"e"`
Kty string `json:"kty"`
N string `json:"n"`
}{E: rk.E, Kty: rk.Kty, N: rk.N})
case "OKP":
canonical, err = json.Marshal(struct {
Crv string `json:"crv"`
Kty string `json:"kty"`
X string `json:"x"`
}{Crv: rk.Crv, Kty: rk.Kty, X: rk.X})
default:
return "", fmt.Errorf("thumbprint: kty %q: %w", rk.Kty, ErrUnsupportedKeyType)
}
if err != nil {
return "", fmt.Errorf("thumbprint: marshal canonical JSON: %w", err)
}
sum := sha256.Sum256(canonical)
return base64.RawURLEncoding.EncodeToString(sum[:]), nil
}
// PrivateKey wraps a [crypto.Signer] (private key) with its JWKS metadata.
//
// PrivateKey satisfies [json.Marshaler] and [json.Unmarshaler]:
// marshaling includes the private key material (the "d" field and RSA primes);
// unmarshaling reconstructs a fully operational signing key from a JWK with
// private fields present. Never publish the marshaled output - it contains
// private key material.
//
// Use [FromPrivateKey] to construct.
type PrivateKey struct {
privKey crypto.Signer
KID string
Use string
Alg string
KeyOps []string
}
// PublicKey derives the [PublicKey] for this signing key.
// KID, Use, and Alg are copied directly. KeyOps are translated to their
// public-key equivalents: "sign"=>"verify", "decrypt"=>"encrypt",
// "unwrapKey"=>"wrapKey". Any op with no public equivalent is omitted.
//
// Returns an error if the Signer's Public() method does not return a
// known CryptoPublicKey type - this should never happen for keys created
// through this library.
func (k *PrivateKey) PublicKey() (*PublicKey, error) {
pub, ok := k.privKey.Public().(CryptoPublicKey)
if !ok {
return nil, fmt.Errorf("%w: private key type %T did not produce a known public key type", ErrSanityFail, k.privKey)
}
return &PublicKey{
Key: pub,
KID: k.KID,
Use: k.Use,
Alg: k.Alg,
KeyOps: toPublicKeyOps(k.KeyOps),
}, nil
}
// toPublicKeyOps translates private-key key_ops values to their public-key
// counterparts per RFC 7517 §4.3. Operations with no public-key equivalent
// (e.g. "deriveKey", "deriveBits") are dropped.
func toPublicKeyOps(ops []string) []string {
if len(ops) == 0 {
return ops
}
out := make([]string, 0, len(ops))
for _, op := range ops {
switch op {
case "sign":
out = append(out, "verify")
case "decrypt":
out = append(out, "encrypt")
case "unwrapKey":
out = append(out, "wrapKey")
case "verify", "encrypt", "wrapKey":
// Already a public-key op - pass through unchanged.
out = append(out, op)
}
}
if len(out) == 0 {
return nil
}
return out
}
// Thumbprint computes the RFC 7638 thumbprint for this key's public side.
// It delegates to [PublicKey.Thumbprint] on the result of [PrivateKey.PublicKey].
func (k *PrivateKey) Thumbprint() (string, error) {
pub, err := k.PublicKey()
if err != nil {
return "", err
}
return pub.Thumbprint()
}
// NewPrivateKey generates a new private key using the best universally
// available algorithm, currently Ed25519. The algorithm may change in
// future versions; use [FromPrivateKey] to wrap a specific key type.
//
// The KID is auto-computed from the RFC 7638 thumbprint of the public key.
//
// Ed25519 is the recommended default: fast, compact 64-byte signatures, and
// deterministic signing (no per-signature random nonce, unlike ECDSA).
func NewPrivateKey() (*PrivateKey, error) {
_, priv, err := ed25519.GenerateKey(nil)
if err != nil {
return nil, fmt.Errorf("NewPrivateKey: generate Ed25519 key: %w", err)
}
pk := &PrivateKey{privKey: priv}
kid, err := pk.Thumbprint()
if err != nil {
return nil, fmt.Errorf("NewPrivateKey: compute thumbprint: %w", err)
}
pk.KID = kid
return pk, nil
}
// MarshalJSON implements [json.Marshaler], encoding the key as a JWK JSON object
// that includes private key material (the "d" field and RSA CRT components).
// Never publish the result - it contains the private key.
func (k PrivateKey) MarshalJSON() ([]byte, error) {
pk, err := encodePrivate(k)
if err != nil {
return nil, err
}
return json.Marshal(pk)
}
// UnmarshalJSON implements [json.Unmarshaler], parsing a JWK JSON object that
// contains private key material. The "d" field (and RSA primes) must be present;
// public-key-only JWKs return an error. If the JWK has no "kid" field, the KID
// is auto-computed via [PublicKey.Thumbprint].
func (k *PrivateKey) UnmarshalJSON(data []byte) error {
var kj rawKey
if err := json.Unmarshal(data, &kj); err != nil {
return fmt.Errorf("parse JWK: %w", err)
}
decoded, err := decodePrivate(kj)
if err != nil {
return err
}
*k = *decoded
return nil
}
// rawKey is the unexported JSON wire representation of a JWK object.
// It is used internally by [PublicKey] and [PrivateKey] JSON methods.
type rawKey struct {
Kty string `json:"kty"`
KID string `json:"kid,omitempty"`
Crv string `json:"crv,omitempty"`
X string `json:"x,omitempty"`
Y string `json:"y,omitempty"`
D string `json:"d,omitempty"` // EC/OKP: private scalar; RSA: private exponent
N string `json:"n,omitempty"`
E string `json:"e,omitempty"`
P string `json:"p,omitempty"` // RSA: first prime factor
Q string `json:"q,omitempty"` // RSA: second prime factor
DP string `json:"dp,omitempty"` // RSA: d mod (p-1)
DQ string `json:"dq,omitempty"` // RSA: d mod (q-1)
QI string `json:"qi,omitempty"` // RSA: q^-1 mod p
Use string `json:"use,omitempty"`
Alg string `json:"alg,omitempty"`
KeyOps []string `json:"key_ops,omitempty"`
}
// WellKnownJWKs is a JSON Web Key Set as served by a /.well-known/jwks.json
// endpoint. It contains only public keys - private material is stripped
// during unmarshalling. Use json.Marshal and json.Unmarshal directly - each
// [PublicKey] in Keys handles its own encoding via MarshalJSON / UnmarshalJSON.
type WellKnownJWKs struct {
Keys []PublicKey `json:"keys"`
}
// encode converts a [PublicKey] to its [rawKey] wire representation.
// Used by [PublicKey.MarshalJSON] and [PublicKey.Thumbprint].
func encode(k PublicKey) (rawKey, error) {
rk := rawKey{KID: k.KID, Use: k.Use, Alg: k.Alg, KeyOps: k.KeyOps}
switch key := k.Key.(type) {
case *ecdsa.PublicKey:
ci, err := ecInfo(key.Curve)
if err != nil {
return rawKey{}, err
}
b, err := key.Bytes() // uncompressed: 0x04 || X || Y
if err != nil {
return rawKey{}, fmt.Errorf("encode EC key: %w", err)
}
rk.Kty = "EC"
rk.Crv = ci.Crv
rk.X = base64.RawURLEncoding.EncodeToString(b[1 : 1+ci.KeySize])
rk.Y = base64.RawURLEncoding.EncodeToString(b[1+ci.KeySize:])
return rk, nil
case *rsa.PublicKey:
eInt := big.NewInt(int64(key.E))
rk.Kty = "RSA"
rk.N = base64.RawURLEncoding.EncodeToString(key.N.Bytes())
rk.E = base64.RawURLEncoding.EncodeToString(eInt.Bytes())
return rk, nil
case ed25519.PublicKey:
rk.Kty = "OKP"
rk.Crv = "Ed25519"
rk.X = base64.RawURLEncoding.EncodeToString([]byte(key))
return rk, nil
default:
return rawKey{}, fmt.Errorf("%T: %w", k.Key, ErrUnsupportedKeyType)
}
}
// encodePrivate converts a [PrivateKey] to its [rawKey] wire representation,
// including private key material (d, and RSA CRT components p/q/dp/dq/qi).
// Used by [PrivateKey.MarshalJSON].
func encodePrivate(k PrivateKey) (rawKey, error) {
pub, err := k.PublicKey()
if err != nil {
return rawKey{}, err
}
rk, err := encode(*pub)
if err != nil {
return rawKey{}, err
}
switch priv := k.privKey.(type) {
case *ecdsa.PrivateKey:
dBytes, err := priv.Bytes()
if err != nil {
return rawKey{}, fmt.Errorf("encode EC private key: %w", err)
}
rk.D = base64.RawURLEncoding.EncodeToString(dBytes)
case *rsa.PrivateKey:
rk.D = base64.RawURLEncoding.EncodeToString(priv.D.Bytes())
if len(priv.Primes) >= 2 {
priv.Precompute()
rk.P = base64.RawURLEncoding.EncodeToString(priv.Primes[0].Bytes())
rk.Q = base64.RawURLEncoding.EncodeToString(priv.Primes[1].Bytes())
if priv.Precomputed.Dp != nil {
rk.DP = base64.RawURLEncoding.EncodeToString(priv.Precomputed.Dp.Bytes())
rk.DQ = base64.RawURLEncoding.EncodeToString(priv.Precomputed.Dq.Bytes())
rk.QI = base64.RawURLEncoding.EncodeToString(priv.Precomputed.Qinv.Bytes())
}
}
case ed25519.PrivateKey:
rk.D = base64.RawURLEncoding.EncodeToString(priv.Seed())
default:
return rawKey{}, fmt.Errorf("%T: %w", k.privKey, ErrUnsupportedKeyType)
}
return rk, nil
}
// FromPublicKey wraps a Go crypto public key in a [PublicKey] with
// auto-computed KID (RFC 7638 thumbprint) and Alg.
//
// Supported key types: *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey.
// Returns an error for unsupported types or if the thumbprint cannot be computed.
func FromPublicKey(pub crypto.PublicKey) (*PublicKey, error) {
cpk, ok := pub.(CryptoPublicKey)
if !ok {
return nil, fmt.Errorf("%T: %w", pub, ErrUnsupportedKeyType)
}
pk := &PublicKey{Key: cpk}
// Derive Alg from key type.
switch key := pub.(type) {
case *ecdsa.PublicKey:
ci, err := ecInfo(key.Curve)
if err != nil {
return nil, err
}
pk.Alg = ci.Alg
case *rsa.PublicKey:
pk.Alg = "RS256"
case ed25519.PublicKey:
pk.Alg = "EdDSA"
default:
return nil, fmt.Errorf("%T: %w", pub, ErrUnsupportedKeyType)
}
kid, err := pk.Thumbprint()
if err != nil {
return nil, fmt.Errorf("compute thumbprint: %w", err)
}
pk.KID = kid
return pk, nil
}
// FromPrivateKey wraps a [crypto.Signer] in a [PrivateKey] with
// the given KID and auto-derived Alg.
//
// Returns [ErrUnsupportedKeyType] if the signer is not a supported type.
// If kid is empty, [NewSigner] will auto-compute it from the key's
// RFC 7638 JWK Thumbprint. For standalone use, call [PrivateKey.Thumbprint]
// and set KID manually.
func FromPrivateKey(signer crypto.Signer, kid string) (*PrivateKey, error) {
alg, _, _, err := signingParams(signer)
if err != nil {
return nil, err
}
return &PrivateKey{privKey: signer, KID: kid, Alg: alg}, nil
}
// decodeOne parses a single rawKey wire struct into a [PublicKey].
// If the JWK has no "kid" field, the KID is auto-computed via [PublicKey.Thumbprint].
//
// Supported key types:
// - "RSA" - minimum 1024-bit (RS256)
// - "EC" - P-256, P-384, P-521 (ES256, ES384, ES512)
// - "OKP" - Ed25519 crv (EdDSA, RFC 8037) https://www.rfc-editor.org/rfc/rfc8037.html
func decodeOne(kj rawKey) (*PublicKey, error) {
var pk *PublicKey
switch kj.Kty {
case "RSA":
key, err := decodeRSA(kj)
if err != nil {
return nil, fmt.Errorf("parse RSA key %q: %w", kj.KID, err)
}
pk = kj.newPublicKey(key)
case "EC":
key, err := decodeEC(kj)
if err != nil {
return nil, fmt.Errorf("parse EC key %q: %w", kj.KID, err)
}
pk = kj.newPublicKey(key)
case "OKP":
key, err := decodeOKP(kj)
if err != nil {
return nil, fmt.Errorf("parse OKP key %q: %w", kj.KID, err)
}
pk = kj.newPublicKey(key)
default:
return nil, fmt.Errorf("kid %q: kty %q: %w", kj.KID, kj.Kty, ErrUnsupportedKeyType)
}
if pk.KID == "" {
kid, err := pk.Thumbprint()
if err != nil {
return nil, fmt.Errorf("compute thumbprint: %w", err)
}
pk.KID = kid
}
return pk, nil
}
// decodePrivate parses a rawKey wire struct that contains private key material
// into a [PrivateKey]. If the JWK has no "kid" field, the KID is auto-computed
// via [PublicKey.Thumbprint]. Returns an error if the "d" field is missing.
func decodePrivate(kj rawKey) (*PrivateKey, error) {
if kj.D == "" {
return nil, fmt.Errorf("\"d\" field missing: %w", ErrMissingKeyData)
}
var pk *PrivateKey
switch kj.Kty {
case "EC":
ci, err := ecInfoByCrv(kj.Crv)
if err != nil {
return nil, fmt.Errorf("parse EC private key %q: %w", kj.KID, err)
}
dBytes, err := decodeB64Field("EC", kj.KID, "d", kj.D)
if err != nil {
return nil, err
}
// ParseRawPrivateKey validates the scalar and derives the public key.
priv, err := ecdsa.ParseRawPrivateKey(ci.Curve, dBytes)
if err != nil {
return nil, fmt.Errorf("parse EC private key %q: %w: %w", kj.KID, ErrInvalidKey, err)
}
pk = kj.newPrivateKey(priv)
case "RSA":
pub, err := decodeRSA(kj)
if err != nil {
return nil, fmt.Errorf("parse RSA private key %q: %w", kj.KID, err)
}
dBytes, err := decodeB64Field("RSA", kj.KID, "d", kj.D)
if err != nil {
return nil, err
}
priv := &rsa.PrivateKey{
PublicKey: *pub,
D: new(big.Int).SetBytes(dBytes),
}
if kj.P != "" && kj.Q != "" {
p, err := decodeB64Field("RSA", kj.KID, "p", kj.P)
if err != nil {
return nil, err
}
q, err := decodeB64Field("RSA", kj.KID, "q", kj.Q)
if err != nil {
return nil, err
}
priv.Primes = []*big.Int{
new(big.Int).SetBytes(p),
new(big.Int).SetBytes(q),
}
priv.Precompute()
}
if err := priv.Validate(); err != nil {
return nil, fmt.Errorf("parse RSA private key %q: %w: %w", kj.KID, ErrInvalidKey, err)
}
pk = kj.newPrivateKey(priv)
case "OKP":
if kj.Crv != "Ed25519" {
return nil, fmt.Errorf("parse OKP private key %q: crv %q: %w", kj.KID, kj.Crv, ErrUnsupportedCurve)
}
seed, err := decodeB64Field("Ed25519", kj.KID, "d", kj.D)
if err != nil {
return nil, err
}
if len(seed) != ed25519.SeedSize {
return nil, fmt.Errorf("parse Ed25519 private key %q: seed size %d, want %d: %w", kj.KID, len(seed), ed25519.SeedSize, ErrInvalidKey)
}
priv := ed25519.NewKeyFromSeed(seed)
pk = kj.newPrivateKey(priv)
default:
return nil, fmt.Errorf("kid %q: kty %q: %w", kj.KID, kj.Kty, ErrUnsupportedKeyType)
}
if pk.KID == "" {
kid, err := pk.Thumbprint()
if err != nil {
return nil, fmt.Errorf("compute thumbprint: %w", err)
}
pk.KID = kid
}
return pk, nil
}
// newPublicKey creates a [PublicKey] from a crypto key, copying metadata
// (KID, Use, Alg, KeyOps) from the rawKey.
func (kj rawKey) newPublicKey(key CryptoPublicKey) *PublicKey {
return &PublicKey{Key: key, KID: kj.KID, Use: kj.Use, Alg: kj.Alg, KeyOps: kj.KeyOps}
}
// newPrivateKey creates a [PrivateKey] from a crypto.Signer, copying metadata
// (KID, Use, Alg, KeyOps) from the rawKey.
func (kj rawKey) newPrivateKey(signer crypto.Signer) *PrivateKey {
return &PrivateKey{privKey: signer, KID: kj.KID, Use: kj.Use, Alg: kj.Alg, KeyOps: kj.KeyOps}
}
// decodeB64Field decodes a base64url-encoded JWK field value, returning a
// descriptive error that includes the key type, KID, and field name.
func decodeB64Field(kty, kid, field, value string) ([]byte, error) {
b, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return nil, fmt.Errorf("parse %s private key %q: invalid %s: %w: %w", kty, kid, field, ErrInvalidKey, err)
}
return b, nil
}
func decodeRSA(kj rawKey) (*rsa.PublicKey, error) {
n, err := base64.RawURLEncoding.DecodeString(kj.N)
if err != nil {
return nil, fmt.Errorf("invalid n: %w: %w", ErrInvalidKey, err)
}
e, err := base64.RawURLEncoding.DecodeString(kj.E)
if err != nil {
return nil, fmt.Errorf("invalid e: %w: %w", ErrInvalidKey, err)
}
eInt := new(big.Int).SetBytes(e)
if !eInt.IsInt64() {
return nil, fmt.Errorf("RSA exponent too large: %w", ErrInvalidKey)
}
eVal := eInt.Int64()
// Minimum exponent of 3 rejects degenerate keys (e=1 makes RSA trivial).
// Cap at MaxInt32 so the value fits in an int on 32-bit platforms.
if eVal < 3 {
return nil, fmt.Errorf("RSA exponent must be at least 3, got %d: %w", eVal, ErrInvalidKey)
}
if eVal > 1<<31-1 {
return nil, fmt.Errorf("RSA exponent too large for 32-bit platforms: %d: %w", eVal, ErrInvalidKey)
}
key := &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(eVal),
}
// 1024-bit minimum: lower than the 2048-bit industry recommendation,
// but allows real-world compatibility with older keys and is useful
// for testing. Production deployments should use 2048+ bits.
if key.Size() < 128 { // 1024 bits minimum
return nil, fmt.Errorf("%d bits: %w", key.Size()*8, ErrKeyTooSmall)
}
return key, nil
}
func decodeEC(kj rawKey) (*ecdsa.PublicKey, error) {
x, err := base64.RawURLEncoding.DecodeString(kj.X)
if err != nil {
return nil, fmt.Errorf("invalid x: %w: %w", ErrInvalidKey, err)
}
y, err := base64.RawURLEncoding.DecodeString(kj.Y)
if err != nil {
return nil, fmt.Errorf("invalid y: %w: %w", ErrInvalidKey, err)
}
ci, err := ecInfoByCrv(kj.Crv)
if err != nil {
return nil, err
}
// Build the uncompressed point (0x04 || X || Y), left-padding each
// coordinate to the expected byte length. ParseUncompressedPublicKey
// validates that the point is on the curve.
if len(x) > ci.KeySize {
return nil, fmt.Errorf("x coordinate too long for %s: got %d bytes, want %d: %w", kj.Crv, len(x), ci.KeySize, ErrInvalidKey)
}
if len(y) > ci.KeySize {
return nil, fmt.Errorf("y coordinate too long for %s: got %d bytes, want %d: %w", kj.Crv, len(y), ci.KeySize, ErrInvalidKey)
}
uncompressed := make([]byte, 1+2*ci.KeySize)
uncompressed[0] = 0x04
copy(uncompressed[1+ci.KeySize-len(x):1+ci.KeySize], x) // left-pad X
copy(uncompressed[1+2*ci.KeySize-len(y):], y) // left-pad Y
key, err := ecdsa.ParseUncompressedPublicKey(ci.Curve, uncompressed)
if err != nil {
return nil, fmt.Errorf("EC point not on curve %s: %w: %w", kj.Crv, ErrInvalidKey, err)
}
return key, nil
}
// ParsePublicJWK parses a single JWK JSON object into a [PublicKey].
// KID is auto-computed from the RFC 7638 thumbprint if not present in the JWK.
func ParsePublicJWK(data []byte) (*PublicKey, error) {
var pk PublicKey
if err := json.Unmarshal(data, &pk); err != nil {
return nil, err
}
return &pk, nil
}
// ParsePrivateJWK parses a single JWK JSON object with private key material
// into a [PrivateKey]. The "d" field must be present.
// KID is auto-computed from the RFC 7638 thumbprint if not present in the JWK.
func ParsePrivateJWK(data []byte) (*PrivateKey, error) {
var pk PrivateKey
if err := json.Unmarshal(data, &pk); err != nil {
return nil, err
}
return &pk, nil
}
// ParseWellKnownJWKs parses a JWKS document ({"keys": [...]}) into a [WellKnownJWKs].
// Each key's KID is auto-computed from the RFC 7638 thumbprint if not present.
func ParseWellKnownJWKs(data []byte) (WellKnownJWKs, error) {
var jwks WellKnownJWKs
if err := json.Unmarshal(data, &jwks); err != nil {
return WellKnownJWKs{}, err
}
return jwks, nil
}
func decodeOKP(kj rawKey) (ed25519.PublicKey, error) {
if kj.Crv != "Ed25519" {
return nil, fmt.Errorf("crv %q (only Ed25519 supported): %w", kj.Crv, ErrUnsupportedCurve)
}
x, err := base64.RawURLEncoding.DecodeString(kj.X)
if err != nil {
return nil, fmt.Errorf("invalid x: %w: %w", ErrInvalidKey, err)
}
if len(x) != ed25519.PublicKeySize {
return nil, fmt.Errorf("Ed25519 key size %d bytes, want %d: %w", len(x), ed25519.PublicKeySize, ErrInvalidKey)
}
return ed25519.PublicKey(x), nil
}

View File

@ -1,4 +1,4 @@
// Copyright 2025 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -9,475 +9,349 @@
package jwt
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"math/big"
"slices"
"strings"
"time"
)
type Keypair struct {
Thumbprint string
PrivateKey *ecdsa.PrivateKey
// VerifiableJWT is the read-only interface implemented by [*JWT] and any
// custom JWT type. It exposes only the parsed header and payload - no mutation.
//
// Use [Verifier.VerifyJWT] to get a verified [*JWT], then call
// [RawJWT.UnmarshalClaims] to decode the payload. Or use [Decode] + [Verifier.Verify]
// for routing by header fields before verifying the signature.
type VerifiableJWT interface {
GetProtected() []byte
GetPayload() []byte
GetSignature() []byte
// GetHeader returns a copy of the decoded JOSE header fields.
GetHeader() RFCHeader
}
type JWK struct {
Kty string `json:"kty"`
Crv string `json:"crv"`
D string `json:"d"`
X string `json:"x"`
Y string `json:"y"`
// SignableJWT extends [VerifiableJWT] with the two hooks [Signer.SignJWT] needs.
// [*JWT] satisfies both [VerifiableJWT] and [SignableJWT].
//
// Custom JWT types implement SetHeader to merge the signer's standard
// fields (alg, kid, typ) with any custom header fields and store the
// encoded protected bytes. They implement SetSignature to store the result.
// No cryptographic knowledge is required - the [Signer] handles all of that.
type SignableJWT interface {
VerifiableJWT
// SetHeader encodes hdr as base64url and stores it as the protected
// header. The signer reads the result via [GetProtected].
SetHeader(hdr Header) error
// SetSignature stores the computed signature bytes.
SetSignature(sig []byte)
}
type JWT string
func (jwt JWT) Split() (string, string, string, error) {
parts := strings.Split(string(jwt), ".")
if len(parts) != 3 {
return "", "", "", fmt.Errorf("invalid JWT format")
// RawJWT holds the three base64url-encoded segments of a compact JWT.
// Embed it in custom JWT types to get [RawJWT.GetProtected],
// [RawJWT.GetPayload], [RawJWT.GetSignature], and [RawJWT.SetClaims]
// for free. Custom types only need to add GetHeader to satisfy
// [VerifiableJWT], plus SetHeader and SetSignature for [SignableJWT].
type RawJWT struct {
Protected []byte // base64url-encoded header
Payload []byte // base64url-encoded claims
Signature []byte // decoded signature bytes
}
rawHeader, rawPayload, rawSig := parts[0], parts[1], parts[2]
return rawHeader, rawPayload, rawSig, nil
// GetProtected implements [VerifiableJWT].
func (raw *RawJWT) GetProtected() []byte { return raw.Protected }
// GetPayload implements [VerifiableJWT].
func (raw *RawJWT) GetPayload() []byte { return raw.Payload }
// GetSignature implements [VerifiableJWT].
func (raw *RawJWT) GetSignature() []byte { return raw.Signature }
// SetSignature implements [SignableJWT].
func (raw *RawJWT) SetSignature(sig []byte) { raw.Signature = sig }
// MarshalJSON encodes the RawJWT as a flattened JWS JSON object
// (RFC 7515 appendix A.7):
//
// {"protected":"...","payload":"...","signature":"..."}
//
// Protected and Payload are already base64url strings and are written as-is.
// Signature is raw bytes and is base64url-encoded for the JSON output.
func (raw *RawJWT) MarshalJSON() ([]byte, error) {
return json.Marshal(flatJWS{
Protected: string(raw.Protected),
Payload: string(raw.Payload),
Signature: base64.RawURLEncoding.EncodeToString(raw.Signature),
})
}
func (jwt JWT) Decode() (JWS, error) {
h64, p64, s64, err := jwt.Split()
// UnmarshalJSON decodes a flattened JWS JSON object into the RawJWT.
func (raw *RawJWT) UnmarshalJSON(data []byte) error {
var v flatJWS
if err := json.Unmarshal(data, &v); err != nil {
return err
}
raw.Protected = []byte(v.Protected)
raw.Payload = []byte(v.Payload)
sig, err := base64.RawURLEncoding.DecodeString(v.Signature)
if err != nil {
return JWS{}, err
return fmt.Errorf("signature base64: %w", err)
}
var jws JWS
var sigEnc string
jws.Protected, jws.Payload, sigEnc = h64, p64, s64
header, err := base64.RawURLEncoding.DecodeString(jws.Protected)
if err != nil {
return jws, fmt.Errorf("invalid header encoding: %v", err)
}
if err := json.Unmarshal(header, &jws.Header); err != nil {
return jws, fmt.Errorf("invalid header JSON: %v", err)
}
payload, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return jws, fmt.Errorf("invalid claims encoding: %v", err)
}
if err := json.Unmarshal(payload, &jws.Claims); err != nil {
return jws, fmt.Errorf("invalid claims JSON: %v", err)
}
if err := jws.Signature.UnmarshalJSON([]byte(sigEnc)); err != nil {
return jws, fmt.Errorf("invalid signature encoding: %v", err)
}
return jws, nil
}
type JWS struct {
Protected string `json:"-"` // base64
Header MyHeader `json:"headers"`
Payload string `json:"-"` // base64
Claims MyClaims `json:"claims"`
Signature URLBase64 `json:"signature"`
Verified bool `json:"-"`
}
type MyHeader struct {
StandardHeader
}
type StandardHeader struct {
Alg string `json:"alg"`
Kid string `json:"kid"`
Typ string `json:"typ"`
}
type MyClaims struct {
StandardClaims
Email string `json:"email"`
EmployeeID string `json:"employee_id"`
FamilyName string `json:"family_name"`
GivenName string `json:"given_name"`
Roles []string `json:"roles"`
}
type StandardClaims struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
Aud string `json:"aud"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
AuthTime int64 `json:"auth_time"`
Nonce string `json:"nonce,omitempty"`
Amr []string `json:"amr"`
Azp string `json:"azp,omitempty"`
Jti string `json:"jti"`
}
func UnmarshalJWK(jwk JWK) (*ecdsa.PrivateKey, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("invalid JWK X: %v", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("invalid JWK Y: %v", err)
}
d, err := base64.RawURLEncoding.DecodeString(jwk.D)
if err != nil {
return nil, fmt.Errorf("invalid JWK D: %v", err)
}
return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
},
D: new(big.Int).SetBytes(d),
}, nil
}
func NewJWS(email, employeeID, issuer, thumbprint string, roles []string) (JWS, error) {
var jws JWS
jws.Header.StandardHeader = StandardHeader{
Alg: "ES256",
Kid: thumbprint,
Typ: "JWT",
}
headerJSON, _ := json.Marshal(jws.Header)
jws.Protected = base64.RawURLEncoding.EncodeToString(headerJSON)
now := time.Now().Unix()
jtiBytes := make([]byte, 16)
if _, err := rand.Read(jtiBytes); err != nil {
return JWS{}, fmt.Errorf("failed to generate Jti: %v", err)
}
jti := base64.RawURLEncoding.EncodeToString(jtiBytes)
emailName := strings.Split(email, "@")[0]
jws.Claims = MyClaims{
StandardClaims: StandardClaims{
AuthTime: now,
Exp: now + 15*60*37, // TODO remove
Iat: now,
Iss: issuer,
Jti: jti,
Sub: email,
Amr: []string{"pwd"},
},
Email: email,
EmployeeID: employeeID,
FamilyName: "McTestface",
GivenName: strings.ToUpper(emailName),
Roles: roles,
}
claimsJSON, _ := json.Marshal(jws.Claims)
jws.Payload = base64.RawURLEncoding.EncodeToString(claimsJSON)
return jws, nil
}
func (jws *JWS) Sign(key *ecdsa.PrivateKey) ([]byte, error) {
var err error
jws.Signature, err = SignJWS(jws.Protected, jws.Payload, key)
return jws.Signature, err
}
// UnsafeVerify only checks the signature, use Validate to check all values
func (jws *JWS) UnsafeVerify(pub *ecdsa.PublicKey) bool {
hash := sha256.Sum256([]byte(jws.Protected + "." + jws.Payload))
n := len(jws.Signature)
if n != 64 {
// return fmt.Errorf("expected a 64-byte signature consisting of two 32-byte r and s components, but got %d instead (perhaps ASN.1 or other format)", n)
return false
}
r := new(big.Int).SetBytes(jws.Signature[:32])
s := new(big.Int).SetBytes(jws.Signature[32:])
jws.Verified = ecdsa.Verify(pub, hash[:], r, s)
return jws.Verified
}
// ValidateParams holds validation configuration.
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
type ValidateParams struct {
Now time.Time
IgnoreIss bool
Iss string
IgnoreSub bool
Sub string
IgnoreAud bool
Aud string
IgnoreExp bool
IgnoreJti bool
Jti string
IgnoreIat bool
IgnoreAuthTime bool
MaxAge time.Duration
IgnoreNonce bool
Nonce string
IgnoreAmr bool
RequiredAmrs []string
IgnoreAzp bool
Azp string
IgnoreSig bool
}
// Validate checks common JWS fields and issuer, collecting all errors.
func (jws *JWS) Validate(params ValidateParams) ([]string, error) {
var errs []string
if params.Now.IsZero() {
params.Now = time.Now()
}
// Required to exist and match
if len(params.Iss) > 0 || !params.IgnoreIss {
if len(jws.Claims.Iss) == 0 {
errs = append(errs, ("missing or malformed 'iss' (token issuer, identifier for public key)"))
} else if jws.Claims.Iss != params.Iss {
errs = append(errs, fmt.Sprintf("'iss' (token issuer) mismatch: got %s, expected %s", jws.Claims.Iss, params.Iss))
}
}
// Required to exist, optional match
if len(jws.Claims.Sub) == 0 {
if !params.IgnoreSub {
errs = append(errs, ("missing or malformed 'sub' (subject, typically pairwise user id)"))
}
} else if len(params.Sub) > 0 {
if params.Sub != jws.Claims.Sub {
errs = append(errs, fmt.Sprintf("'sub' (subject, typically pairwise user id) mismatch: got %s, expected %s", jws.Claims.Sub, params.Sub))
}
}
// Required to exist and match
if len(params.Aud) > 0 || !params.IgnoreAud {
if len(jws.Claims.Aud) == 0 {
errs = append(errs, ("missing or malformed 'aud' (audience receiving token)"))
} else if jws.Claims.Aud != params.Aud {
errs = append(errs, fmt.Sprintf("'aud' (audience receiving token) mismatch: got %s, expected %s", jws.Claims.Aud, params.Aud))
}
}
// Required to exist and not be in the past
if !params.IgnoreExp {
if jws.Claims.Exp <= 0 {
errs = append(errs, ("missing or malformed 'exp' (expiration date in seconds)"))
} else if jws.Claims.Exp < params.Now.Unix() {
duration := time.Since(time.Unix(jws.Claims.Exp, 0))
expTime := time.Unix(jws.Claims.Exp, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf("token expired %s ago (%s)", formatDuration(duration), expTime))
}
}
// Required to exist and not be in the future
if !params.IgnoreIat {
if jws.Claims.Iat <= 0 {
errs = append(errs, ("missing or malformed 'iat' (issued at, when token was signed)"))
} else if jws.Claims.Iat > params.Now.Unix() {
duration := time.Unix(jws.Claims.Iat, 0).Sub(params.Now)
iatTime := time.Unix(jws.Claims.Iat, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf("'iat' (issued at, when token was signed) is %s in the future (%s)", formatDuration(duration), iatTime))
}
}
// Should exist, in the past, with optional max age
if params.MaxAge > 0 || !params.IgnoreAuthTime {
if jws.Claims.AuthTime == 0 {
errs = append(errs, ("missing or malformed 'auth_time' (time of real-world user authentication, in seconds)"))
} else {
authTime := time.Unix(jws.Claims.AuthTime, 0)
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
age := params.Now.Sub(authTime)
diff := age - params.MaxAge
if jws.Claims.AuthTime > params.Now.Unix() {
fromNow := time.Unix(jws.Claims.AuthTime, 0).Sub(params.Now)
authTimeStr := time.Unix(jws.Claims.AuthTime, 0).Format("2006-01-02 15:04:05 MST")
errs = append(errs, fmt.Sprintf(
"'auth_time' (time of real-world user authentication) of %s is %s in the future (server time %s)",
authTimeStr, formatDuration(fromNow), params.Now.Format("2006-01-02 15:04:05 MST")),
)
} else if age > params.MaxAge {
errs = append(errs, fmt.Sprintf(
"'auth_time' (time of real-world user authentication) of %s is %s old, which exceeds the max age of %s (%ds) by %s",
authTimeStr, formatDuration(age), formatDuration(params.MaxAge), params.MaxAge/time.Second, formatDuration(diff)),
)
}
}
}
// Optional
if params.Jti != jws.Claims.Jti {
if len(params.Jti) > 0 {
errs = append(errs, fmt.Sprintf("'jti' (jwt id) mismatch: got %s, expected %s", jws.Claims.Jti, params.Jti))
} else if !params.IgnoreJti {
errs = append(errs, fmt.Sprintf("unchecked 'jti' (jwt id): %s", jws.Claims.Jti))
}
}
// Optional
if params.Nonce != jws.Claims.Nonce {
if len(params.Nonce) > 0 {
errs = append(errs, fmt.Sprintf("'nonce' (one-time random salt, as string) mismatch: got %s, expected %s", jws.Claims.Nonce, params.Nonce))
} else if !params.IgnoreNonce {
errs = append(errs, fmt.Sprintf("unchecked 'nonce' (one-time random salt): %s", jws.Claims.Nonce))
}
}
// Acr check not implemented because the use case is not yet clear
// Should exist, optional match
if !params.IgnoreAmr {
if len(jws.Claims.Amr) == 0 {
errs = append(errs, ("missing or malformed 'amr' (authorization methods, as json list)"))
} else {
if len(params.RequiredAmrs) > 0 {
for _, required := range params.RequiredAmrs {
if !slices.Contains(jws.Claims.Amr, required) {
errs = append(errs, fmt.Sprintf("missing required '%s' from 'amr' (authorization methods, as json list)", required))
}
}
}
// TODO specify multiple amrs in a tiered list (must have at least one from each list)
// count := 0
// if len(params.AcceptableAmrs) > 0 {
// for _, amr := range jws.Claims.Amr {
// if slices.Contains(params.AcceptableAmrs, amr) {
// count += 1
// }
// }
// }
}
}
// Optional, should match if exists
if params.Azp != jws.Claims.Azp {
if len(params.Azp) > 0 {
errs = append(errs, ("missing or malformed 'azp' (authorized party which presents token)"))
} else if !params.IgnoreAzp {
errs = append(errs, fmt.Sprintf("'azp' mismatch (authorized party which presents token): got %s, expected %s", jws.Claims.Azp, params.Azp))
}
}
// Must be checked
if !params.IgnoreSig {
if !jws.Verified {
errs = append(errs, ("signature was not checked"))
}
}
if len(errs) > 0 {
timeInfo := fmt.Sprintf("info: server time is %s", params.Now.Format("2006-01-02 15:04:05 MST"))
if loc, err := time.LoadLocation("Local"); err == nil {
timeInfo += fmt.Sprintf(" %s", loc)
}
errs = append(errs, timeInfo)
return errs, fmt.Errorf("has errors")
}
return nil, nil
}
func SignJWS(header, payload string, key *ecdsa.PrivateKey) ([]byte, error) {
hash := sha256.Sum256([]byte(header + "." + payload))
r, s, err := ecdsa.Sign(rand.Reader, key, hash[:])
if err != nil {
return nil, fmt.Errorf("failed to sign: %v", err)
}
return append(r.Bytes(), s.Bytes()...), nil
}
func (jws JWS) Encode() string {
sigEnc := base64.RawURLEncoding.EncodeToString(jws.Signature)
return jws.Protected + "." + jws.Payload + "." + sigEnc
}
func EncodeToJWT(signingInput string, signature []byte) string {
sigEnc := base64.RawURLEncoding.EncodeToString(signature)
return signingInput + "." + sigEnc
}
func (jwk JWK) Thumbprint() (string, error) {
data := map[string]string{
"crv": jwk.Crv,
"kty": jwk.Kty,
"x": jwk.X,
"y": jwk.Y,
}
jsonData, err := json.Marshal(data)
if err != nil {
return "", err
}
hash := sha256.Sum256(jsonData)
return base64.RawURLEncoding.EncodeToString(hash[:]), nil
}
// URLBase64 unmarshals to bytes and marshals to a raw url base64 string
type URLBase64 []byte
func (s URLBase64) String() string {
encoded := base64.RawURLEncoding.EncodeToString(s)
return encoded
}
// MarshalJSON implements JSON marshaling to URL-safe base64.
func (s URLBase64) MarshalJSON() ([]byte, error) {
encoded := base64.RawURLEncoding.EncodeToString(s)
return json.Marshal(encoded)
}
// UnmarshalJSON implements JSON unmarshaling from URL-safe base64.
func (s *URLBase64) UnmarshalJSON(data []byte) error {
dst, err := base64.RawURLEncoding.AppendDecode([]byte{}, data)
if err != nil {
return fmt.Errorf("decode base64url signature: %w", err)
}
*s = dst
raw.Signature = sig
return nil
}
func formatDuration(d time.Duration) string {
if d < 0 {
d = -d
}
days := int(d / (24 * time.Hour))
d -= time.Duration(days) * 24 * time.Hour
hours := int(d / time.Hour)
d -= time.Duration(hours) * time.Hour
minutes := int(d / time.Minute)
d -= time.Duration(minutes) * time.Minute
seconds := int(d / time.Second)
var parts []string
if days > 0 {
parts = append(parts, fmt.Sprintf("%dd", days))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%dh", hours))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%dm", minutes))
}
if seconds > 0 || len(parts) == 0 {
parts = append(parts, fmt.Sprintf("%ds", seconds))
}
if seconds == 0 || len(parts) == 0 {
d -= time.Duration(seconds) * time.Second
millis := int(d / time.Millisecond)
parts = append(parts, fmt.Sprintf("%dms", millis))
// flatJWS is the flattened JWS JSON serialization (RFC 7515 appendix A.7).
type flatJWS struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}
return strings.Join(parts, " ")
// SetClaims JSON-encodes claims and stores the result as the
// base64url-encoded payload. This is the payload counterpart of
// [SetHeader] -- use it when constructing a custom JWT type before signing.
func (raw *RawJWT) SetClaims(claims Claims) error {
data, err := json.Marshal(claims)
if err != nil {
return fmt.Errorf("marshal claims: %w", err)
}
raw.Payload = []byte(base64.RawURLEncoding.EncodeToString(data))
return nil
}
// UnmarshalHeader decodes the protected header into v.
//
// Use this to extract custom JOSE header fields beyond alg/kid/typ.
// v must satisfy [Header] - typically a pointer to a struct that embeds
// [RFCHeader] so the standard fields are captured alongside custom ones:
//
// type DPoPHeader struct {
// jwt.RFCHeader
// JWK json.RawMessage `json:"jwk"`
// }
//
// raw, err := jwt.DecodeRaw(tokenStr)
// var h DPoPHeader
// if err := raw.UnmarshalHeader(&h); err != nil { /* ... */ }
//
// Promoted to [*JWT] via embedding, so it works after [Decode] too.
func (raw *RawJWT) UnmarshalHeader(h Header) error {
data, err := base64.RawURLEncoding.AppendDecode([]byte{}, raw.Protected)
if err != nil {
return fmt.Errorf("header base64: %w: %w", ErrInvalidHeader, err)
}
if err := json.Unmarshal(data, h); err != nil {
return fmt.Errorf("header json: %w: %w", ErrInvalidHeader, err)
}
return nil
}
// JWT is a decoded JSON Web Token.
//
// Technically this is a JWS (JSON Web Signature, RFC 7515) - the signed
// compact serialization that carries a header, payload, and signature.
// The term "JWT" (RFC 7519) strictly refers to the encoded string, but
// in practice everyone calls the decoded structure a JWT too, so we do
// the same.
//
// It holds only the parsed structure - header, raw base64url fields, and
// decoded signature bytes. It carries no Claims interface and no Verified flag;
// use [Verifier.VerifyJWT] or [Decode]+[Verifier.Verify] to authenticate the token
// and [RawJWT.UnmarshalClaims] to decode the payload into a typed struct.
//
// *JWT satisfies [VerifiableJWT] and [SignableJWT].
type JWT struct {
RawJWT
header jwsHeader
}
// GetHeader returns a copy of the decoded JOSE header fields.
// Implements [VerifiableJWT]. The returned value is a copy - mutations do not affect the JWT.
func (jws *JWT) GetHeader() RFCHeader { return jws.header.RFCHeader }
// Encode produces the compact JWT string (header.payload.signature).
// It is a convenience wrapper around the package-level [Encode] function.
func (jws *JWT) Encode() (string, error) { return Encode(jws) }
// SetHeader merges hdr into the internal header, encodes it as
// base64url, and stores the result. Implements [SignableJWT].
//
// Custom JWT types override this to merge hdr with their own additional
// header fields before encoding.
func (jws *JWT) SetHeader(hdr Header) error {
jws.header.RFCHeader = *hdr.GetRFCHeader()
data, err := json.Marshal(jws.header)
if err != nil {
return err
}
jws.Protected = []byte(base64.RawURLEncoding.EncodeToString(data))
return nil
}
// SetTyp overrides the JOSE "typ" header field. The new value takes effect
// when [Signer.SignJWT] re-encodes the protected header. Use this after [New]
// to change the token type before signing:
//
// tok, _ := jwt.New(claims)
// tok.SetTyp(jwt.AccessTokenTyp)
// signer.SignJWT(tok)
func (jws *JWT) SetTyp(typ string) { jws.header.Typ = typ }
// jwsHeader is an example of the pattern callers use when embedding [RFCHeader]
// in a custom JWT type. Embed [RFCHeader], and all its fields are promoted
// through the struct. To implement a custom JWT type, copy this struct and
// replace [RFCHeader] embedding with whatever custom header fields you need.
type jwsHeader struct {
RFCHeader
}
// Header is satisfied for free by any struct that embeds [RFCHeader].
//
// type DPoPHeader struct {
// jwt.RFCHeader
// JWK json.RawMessage `json:"jwk"`
// }
// // *DPoPHeader satisfies Header via promoted GetRFCHeader().
type Header interface {
GetRFCHeader() *RFCHeader
}
// RFCHeader holds the standard JOSE header fields used in the JOSE protected header.
type RFCHeader struct {
Alg string `json:"alg"`
KID string `json:"kid,omitempty"`
Typ string `json:"typ,omitempty"`
}
// GetRFCHeader implements [Header].
// Any struct embedding RFCHeader gets this method for free via promotion.
func (h *RFCHeader) GetRFCHeader() *RFCHeader { return h }
// DecodeRaw splits a compact JWT string into its three base64url segments
// and decodes the signature bytes, but does not parse the header JSON.
//
// Use this when you need to unmarshal the header into a custom struct
// with fields beyond alg/kid/typ. Call [RawJWT.UnmarshalHeader] to decode
// the protected header, or build a full [*JWT] with [Decode] instead.
func DecodeRaw(tokenStr string) (*RawJWT, error) {
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
if len(parts) == 1 && parts[0] == "" {
parts = nil
}
return nil, fmt.Errorf("%w: expected 3 segments but got %d", ErrMalformedToken, len(parts))
}
sig, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("signature base64: %w: %w", ErrSignatureInvalid, err)
}
return &RawJWT{
Protected: []byte(parts[0]),
Payload: []byte(parts[1]),
Signature: sig,
}, nil
}
// Decode parses a compact JWT string (header.payload.signature) into a JWS.
//
// It does not unmarshal the claims payload - call [RawJWT.UnmarshalClaims] after
// [Verifier.VerifyJWT] or [Verifier.Verify] to populate a typed claims struct.
func Decode(tokenStr string) (*JWT, error) {
raw, err := DecodeRaw(tokenStr)
if err != nil {
return nil, err
}
var jws JWT
jws.RawJWT = *raw
if err := jws.UnmarshalHeader(&jws.header); err != nil {
return nil, err
}
return &jws, nil
}
// UnmarshalClaims decodes the payload into claims.
//
// Always call [Verifier.VerifyJWT] or [Decode]+[Verifier.Verify] before
// UnmarshalClaims - the signature must be authenticated before trusting the
// payload.
//
// Promoted to [*JWT] via embedding, so it works after [Decode] too.
func (raw *RawJWT) UnmarshalClaims(claims Claims) error {
payload, err := base64.RawURLEncoding.AppendDecode([]byte{}, raw.Payload)
if err != nil {
return fmt.Errorf("payload base64: %w: %w", ErrInvalidPayload, err)
}
if err := json.Unmarshal(payload, claims); err != nil {
return fmt.Errorf("payload json: %w: %w", ErrInvalidPayload, err)
}
return nil
}
// New creates an unsigned JWT from the provided claims.
//
// The "alg" and "kid" header fields are set automatically by [Signer.SignJWT]
// based on the key type and [PrivateKey.KID]. Call [Encode] or [JWT.Encode] to
// produce the compact JWT string after signing.
func New(claims Claims) (*JWT, error) {
var jws JWT
jws.header.RFCHeader = RFCHeader{
// Alg and KID are set by Sign from the key type and PrivateKey.KID.
Typ: "JWT",
}
headerJSON, err := json.Marshal(jws.header)
if err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
jws.Protected = []byte(base64.RawURLEncoding.EncodeToString(headerJSON))
claimsJSON, err := json.Marshal(claims)
if err != nil {
return nil, fmt.Errorf("marshal claims: %w", err)
}
jws.Payload = []byte(base64.RawURLEncoding.EncodeToString(claimsJSON))
return &jws, nil
}
// NewAccessToken creates an unsigned JWT from claims with "typ" set to
// "at+jwt" per RFC 9068 §2.1. Sign with [Signer.SignJWT]:
//
// tok, err := jwt.NewAccessToken(&claims)
// if err := signer.SignJWT(tok); err != nil { /* ... */ }
// token := tok.Encode()
//
// https://www.rfc-editor.org/rfc/rfc9068.html
func NewAccessToken(claims Claims) (*JWT, error) {
jws, err := New(claims)
if err != nil {
return nil, err
}
jws.SetTyp(AccessTokenTyp)
return jws, nil
}
// Encode produces the compact JWT string (header.payload.signature).
//
// Returns an error if the protected header's alg field is empty,
// indicating the token was never signed.
func Encode(jws VerifiableJWT) (string, error) {
h := jws.GetHeader()
if h.Alg == "" {
return "", fmt.Errorf("encode: %w: alg is empty (unsigned token)", ErrInvalidHeader)
}
protected := jws.GetProtected()
payload := jws.GetPayload()
sig := base64.RawURLEncoding.EncodeToString(jws.GetSignature())
out := make([]byte, 0, len(protected)+1+len(payload)+1+len(sig))
out = append(out, protected...)
out = append(out, '.')
out = append(out, payload...)
out = append(out, '.')
out = append(out, sig...)
return string(out), nil
}

335
auth/jwt/keyfetch/fetch.go Normal file
View File

@ -0,0 +1,335 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package keyfetch
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/therootcompany/golib/auth/jwt"
)
// Error sentinels for fetch operations.
var (
// ErrFetchFailed indicates a network, HTTP, or parsing failure during
// a JWKS fetch. The wrapped message includes details (status code,
// network error, parse error, etc).
ErrFetchFailed = errors.New("fetch failed")
// ErrKeysExpired indicates the cached keys are past their hard expiry.
// Returned alongside expired keys when RefreshTimeout fires before the
// refresh completes. Use [errors.Is] to check.
ErrKeysExpired = errors.New("cached keys expired")
// ErrEmptyKeySet indicates the JWKS document contains no keys.
ErrEmptyKeySet = errors.New("empty key set")
)
// maxResponseBody is the maximum JWKS response body size (1 MiB).
// A realistic JWKS with dozens of keys is well under 100 KiB.
const maxResponseBody = 1 << 20
// Default cache policy values, used when the corresponding [KeyFetcher]
// field is zero.
const (
defaultMinTTL = 1 * time.Minute // floor - server values below this are raised
defaultMaxTTL = 24 * time.Hour // ceiling - server values above this are clamped
defaultTTL = 15 * time.Minute // used when no cache headers are present
)
// defaultTimeout is the timeout used when no HTTP client is provided.
const defaultTimeout = 30 * time.Second
// asset holds the response body and computed cache timing from a fetch.
type asset struct {
data []byte
expiry time.Time // hard expiry - do not use after this time
stale time.Time // background refresh should start at this time
etag string // opaque validator for conditional re-fetch
lastModified string // date-based validator for conditional re-fetch
}
// fetchRaw retrieves raw bytes from a URL using the given HTTP client.
// If prev is non-nil, conditional request headers (If-None-Match,
// If-Modified-Since) are sent; a 304 response refreshes the cache
// timing on prev and returns it without re-downloading the body.
// If client is nil, a default client with a 30s timeout is used.
//
// The returned *http.Response has its Body consumed and closed; headers
// remain accessible.
func fetchRaw(ctx context.Context, url string, client *http.Client, p cachePolicy, prev *asset) (*asset, *http.Response, error) {
resp, err := doGET(ctx, url, client, prev)
if err != nil {
return nil, nil, fmt.Errorf("fetch %q: %w", url, err)
}
defer func() { _ = resp.Body.Close() }()
now := time.Now()
// 304 Not Modified - reuse previous body with refreshed cache timing.
if resp.StatusCode == http.StatusNotModified && prev != nil {
expiry, stale := cacheTimings(now, resp, p)
etag := resp.Header.Get("ETag")
if etag == "" {
etag = prev.etag
}
lastMod := resp.Header.Get("Last-Modified")
if lastMod == "" {
lastMod = prev.lastModified
}
return &asset{
data: prev.data,
expiry: expiry,
stale: stale,
etag: etag,
lastModified: lastMod,
}, resp, nil
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBody+1))
if err != nil {
return nil, nil, fmt.Errorf("fetch %q: read body: %w: %w", url, ErrFetchFailed, err)
}
if len(body) > maxResponseBody {
return nil, nil, fmt.Errorf("fetch %q: response exceeds %d byte limit: %w", url, maxResponseBody, ErrFetchFailed)
}
expiry, stale := cacheTimings(now, resp, p)
return &asset{
data: body,
expiry: expiry,
stale: stale,
etag: resp.Header.Get("ETag"),
lastModified: resp.Header.Get("Last-Modified"),
}, resp, nil
}
// cachePolicy holds resolved cache tuning parameters.
type cachePolicy struct {
minTTL time.Duration
maxTTL time.Duration
defaultTTL time.Duration
}
// defaultPolicy returns a cachePolicy using the package defaults.
func defaultPolicy() cachePolicy {
return cachePolicy{
minTTL: defaultMinTTL,
maxTTL: defaultMaxTTL,
defaultTTL: defaultTTL,
}
}
// cacheTimings computes expiry and stale times from the response headers.
// Stale time is always 3/4 of the TTL.
//
// Policy:
// - No usable max-age => defaultTTL (15m), stale at 3/4
// - max-age < minTTL (1m) => minTTL*2 expiry, minTTL stale
// - max-age > maxTTL (24h) => clamped to maxTTL, stale at 3/4
// - Otherwise => server value, stale at 3/4
func cacheTimings(now time.Time, resp *http.Response, p cachePolicy) (expiry, stale time.Time) {
serverTTL := parseCacheControlMaxAge(resp.Header.Get("Cache-Control"))
if age := parseAge(resp.Header.Get("Age")); age > 0 {
serverTTL -= age
}
var ttl time.Duration
switch {
case serverTTL <= 0:
// No cache headers or max-age=0 or Age consumed it all
ttl = p.defaultTTL
case serverTTL < p.minTTL:
// Server says cache briefly - use floor
return now.Add(p.minTTL * 2), now.Add(p.minTTL)
case serverTTL > p.maxTTL:
ttl = p.maxTTL
default:
ttl = serverTTL
}
return now.Add(ttl), now.Add(ttl * 3 / 4)
}
// FetchURL retrieves and parses a JWKS document from the given JWKS endpoint URL.
//
// The response body is limited to 1 MiB. If client is nil, a default client
// with a 30s timeout is used.
//
// The returned [*http.Response] has its Body already consumed and closed.
// Headers such as ETag, Last-Modified, and Cache-Control remain accessible
// and are used internally by [KeyFetcher] for cache management.
func FetchURL(ctx context.Context, jwksURL string, client *http.Client) ([]jwt.PublicKey, *http.Response, error) {
a, resp, err := fetchRaw(ctx, jwksURL, client, defaultPolicy(), nil)
if err != nil {
return nil, nil, err
}
keys, err := parseJWKS(a.data)
if err != nil {
return nil, nil, err
}
return keys, resp, nil
}
// fetchJWKS fetches and parses a JWKS document, returning the asset for
// cache timing and the parsed keys. prev is passed through to fetchRaw
// for conditional requests.
func fetchJWKS(ctx context.Context, jwksURL string, client *http.Client, p cachePolicy, prev *asset) (*asset, []jwt.PublicKey, error) {
a, _, err := fetchRaw(ctx, jwksURL, client, p, prev)
if err != nil {
return nil, nil, err
}
keys, err := parseJWKS(a.data)
if err != nil {
return nil, nil, err
}
return a, keys, nil
}
// parseJWKS unmarshals a JWKS document into public keys.
// Returns [ErrEmptyKeySet] if the key set is empty.
func parseJWKS(data []byte) ([]jwt.PublicKey, error) {
var jwks jwt.WellKnownJWKs
if err := json.Unmarshal(data, &jwks); err != nil {
return nil, fmt.Errorf("parse JWKS: %w: %w", ErrFetchFailed, err)
}
if len(jwks.Keys) == 0 {
return nil, fmt.Errorf("parse JWKS: %w", ErrEmptyKeySet)
}
return jwks.Keys, nil
}
// parseCacheControlMaxAge extracts the max-age value from a Cache-Control header.
// Returns 0 if the header is absent or does not contain a valid max-age directive.
func parseCacheControlMaxAge(header string) time.Duration {
for part := range strings.SplitSeq(header, ",") {
part = strings.TrimSpace(part)
if val, ok := strings.CutPrefix(part, "max-age="); ok {
n, err := strconv.Atoi(val)
if err == nil && n > 0 {
return time.Duration(n) * time.Second
}
}
}
return 0
}
// parseAge extracts the Age header value as a Duration.
// Returns 0 if the header is absent or unparseable.
func parseAge(header string) time.Duration {
if header == "" {
return 0
}
n, err := strconv.Atoi(strings.TrimSpace(header))
if err != nil || n <= 0 {
return 0
}
return time.Duration(n) * time.Second
}
// FetchOIDC fetches JWKS via OIDC discovery from the given base URL.
//
// It fetches {baseURL}/.well-known/openid-configuration, reads the jwks_uri
// field, then fetches and parses the JWKS from that URI.
//
// client is used for all HTTP requests; if nil, a default 30s-timeout client is used.
func FetchOIDC(ctx context.Context, baseURL string, client *http.Client) ([]jwt.PublicKey, *http.Response, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/openid-configuration"
jwksURI, err := fetchDiscoveryURI(ctx, discoveryURL, client)
if err != nil {
return nil, nil, err
}
return FetchURL(ctx, jwksURI, client)
}
// FetchOAuth2 fetches JWKS via OAuth 2.0 authorization server metadata from the
// given base URL.
//
// https://www.rfc-editor.org/rfc/rfc8414.html
//
// It fetches {baseURL}/.well-known/oauth-authorization-server, reads the
// jwks_uri field, then fetches and parses the JWKS from that URI.
//
// client is used for all HTTP requests; if nil, a default 30s-timeout client is used.
func FetchOAuth2(ctx context.Context, baseURL string, client *http.Client) ([]jwt.PublicKey, *http.Response, error) {
discoveryURL := strings.TrimRight(baseURL, "/") + "/.well-known/oauth-authorization-server"
jwksURI, err := fetchDiscoveryURI(ctx, discoveryURL, client)
if err != nil {
return nil, nil, err
}
return FetchURL(ctx, jwksURI, client)
}
// fetchDiscoveryURI fetches a discovery document and returns the validated
// jwks_uri from it. The URI is required to use HTTPS to prevent SSRF via a
// malicious discovery document pointing at an internal endpoint.
func fetchDiscoveryURI(ctx context.Context, discoveryURL string, client *http.Client) (string, error) {
resp, err := doGET(ctx, discoveryURL, client, nil)
if err != nil {
return "", fmt.Errorf("fetch discovery: %w", err)
}
defer func() { _ = resp.Body.Close() }()
var doc struct {
JWKsURI string `json:"jwks_uri"`
}
if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBody)).Decode(&doc); err != nil {
return "", fmt.Errorf("parse discovery doc: %w: %w", ErrFetchFailed, err)
}
if doc.JWKsURI == "" {
return "", fmt.Errorf("discovery doc missing jwks_uri: %w", ErrFetchFailed)
}
if !strings.HasPrefix(doc.JWKsURI, "https://") {
return "", fmt.Errorf("jwks_uri must be https, got %q: %w", doc.JWKsURI, ErrFetchFailed)
}
return doc.JWKsURI, nil
}
// doGET performs an HTTP GET request and returns the response. It follows
// redirects (Go's default of up to 10), handles nil client defaults, and
// checks the final status code. If prev is non-nil, conditional request
// headers are sent and a 304 response is allowed. Callers must close
// resp.Body.
func doGET(ctx context.Context, url string, client *http.Client, prev *asset) (*http.Response, error) {
if client == nil {
client = &http.Client{Timeout: defaultTimeout}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchFailed, err)
}
if prev != nil {
if prev.etag != "" {
req.Header.Set("If-None-Match", prev.etag)
}
if prev.lastModified != "" {
req.Header.Set("If-Modified-Since", prev.lastModified)
}
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchFailed, err)
}
if resp.StatusCode == http.StatusNotModified && prev != nil {
return resp, nil
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
return nil, fmt.Errorf("status %d: %w", resp.StatusCode, ErrFetchFailed)
}
return resp, nil
}

View File

@ -0,0 +1,574 @@
package keyfetch
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/therootcompany/golib/auth/jwt"
)
// testJWKS generates a fresh Ed25519 key and returns the public key plus the
// serialized JWKS document bytes.
func testJWKS(t *testing.T) (jwt.PublicKey, []byte) {
t.Helper()
priv, err := jwt.NewPrivateKey()
if err != nil {
t.Fatalf("NewPrivateKey: %v", err)
}
pub, err := priv.PublicKey()
if err != nil {
t.Fatalf("PublicKey: %v", err)
}
jwks := jwt.WellKnownJWKs{Keys: []jwt.PublicKey{*pub}}
data, err := json.Marshal(jwks)
if err != nil {
t.Fatalf("marshal JWKS: %v", err)
}
return *pub, data
}
// --- FetchURL tests ---
func TestFetchURL_Success(t *testing.T) {
pub, jwksData := testJWKS(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("ETag", `"test-etag"`)
w.Header().Set("Cache-Control", "max-age=300")
w.Write(jwksData)
}))
defer srv.Close()
keys, resp, err := FetchURL(context.Background(), srv.URL, nil)
if err != nil {
t.Fatalf("FetchURL: %v", err)
}
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
if keys[0].KID != pub.KID {
t.Errorf("KID mismatch: got %q, want %q", keys[0].KID, pub.KID)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
if got := resp.Header.Get("ETag"); got != `"test-etag"` {
t.Errorf("ETag header: got %q, want %q", got, `"test-etag"`)
}
}
func TestFetchURL_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer srv.Close()
_, _, err := FetchURL(context.Background(), srv.URL, nil)
if err == nil {
t.Fatal("expected error for 404")
}
if !errorContains(err, ErrFetchFailed) {
t.Errorf("expected ErrFetchFailed, got: %v", err)
}
}
func TestFetchURL_500(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()
_, _, err := FetchURL(context.Background(), srv.URL, nil)
if err == nil {
t.Fatal("expected error for 500")
}
if !errorContains(err, ErrFetchFailed) {
t.Errorf("expected ErrFetchFailed, got: %v", err)
}
}
func TestFetchURL_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{not valid json`))
}))
defer srv.Close()
_, _, err := FetchURL(context.Background(), srv.URL, nil)
if err == nil {
t.Fatal("expected error for malformed JSON")
}
if !errorContains(err, ErrFetchFailed) {
t.Errorf("expected ErrFetchFailed, got: %v", err)
}
}
func TestFetchURL_EmptyJWKS(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"keys":[]}`))
}))
defer srv.Close()
_, _, err := FetchURL(context.Background(), srv.URL, nil)
if err == nil {
t.Fatal("expected error for empty JWKS")
}
if !errors.Is(err, ErrEmptyKeySet) {
t.Errorf("expected ErrEmptyKeySet, got: %v", err)
}
}
// --- FetchOIDC tests ---
func TestFetchOIDC_Success(t *testing.T) {
_, jwksData := testJWKS(t)
var srvURL string
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
doc := fmt.Sprintf(`{"jwks_uri": "%s/jwks.json"}`, srvURL)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(doc))
case "/jwks.json":
w.Header().Set("Content-Type", "application/json")
w.Write(jwksData)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
srvURL = srv.URL
keys, _, err := FetchOIDC(context.Background(), srv.URL, srv.Client())
if err != nil {
t.Fatalf("FetchOIDC: %v", err)
}
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
}
// srv is declared at function scope above; this variable name is fine in a
// separate test function.
func TestFetchOIDC_NonHTTPSJwksURI(t *testing.T) {
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return a jwks_uri with http:// instead of https://
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"jwks_uri": "http://example.com/jwks.json"}`))
}))
defer srv.Close()
_, _, err := FetchOIDC(context.Background(), srv.URL, srv.Client())
if err == nil {
t.Fatal("expected error for non-https jwks_uri")
}
if !errorContains(err, ErrFetchFailed) {
t.Errorf("expected ErrFetchFailed, got: %v", err)
}
}
// --- FetchOAuth2 tests ---
func TestFetchOAuth2_Success(t *testing.T) {
_, jwksData := testJWKS(t)
var srvURL string
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/oauth-authorization-server":
doc := fmt.Sprintf(`{"jwks_uri": "%s/jwks.json"}`, srvURL)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(doc))
case "/jwks.json":
w.Header().Set("Content-Type", "application/json")
w.Write(jwksData)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
srvURL = srv.URL
keys, _, err := FetchOAuth2(context.Background(), srv.URL, srv.Client())
if err != nil {
t.Fatalf("FetchOAuth2: %v", err)
}
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
}
// --- KeyFetcher.Verifier() tests ---
func TestKeyFetcher_Verifier_CachesResult(t *testing.T) {
_, jwksData := testJWKS(t)
var fetchCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fetchCount++
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "max-age=3600")
w.Write(jwksData)
}))
defer srv.Close()
kf := &KeyFetcher{URL: srv.URL}
// First call fetches
v1, err := kf.Verifier()
if err != nil {
t.Fatalf("first Verifier call: %v", err)
}
if v1 == nil {
t.Fatal("expected non-nil verifier")
}
if fetchCount != 1 {
t.Fatalf("expected 1 fetch, got %d", fetchCount)
}
// Second call returns cached (within TTL)
v2, err := kf.Verifier()
if err != nil {
t.Fatalf("second Verifier call: %v", err)
}
if v2 != v1 {
t.Error("expected same verifier instance from cache")
}
if fetchCount != 1 {
t.Errorf("expected still 1 fetch, got %d", fetchCount)
}
}
func TestKeyFetcher_Verifier_InitialKeys(t *testing.T) {
pub, jwksData := testJWKS(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "max-age=3600")
w.Write(jwksData)
}))
defer srv.Close()
kf := &KeyFetcher{
URL: srv.URL,
InitialKeys: []jwt.PublicKey{pub},
RefreshTimeout: 5 * time.Second,
}
// First call should return immediately with initial keys (they are
// marked expired, but RefreshTimeout lets them be served while refresh
// runs in background).
v, err := kf.Verifier()
if err != nil && !errorContains(err, ErrKeysExpired) {
t.Fatalf("first Verifier call: %v", err)
}
if v == nil {
t.Fatal("expected non-nil verifier from InitialKeys")
}
// Wait for background refresh to complete
time.Sleep(500 * time.Millisecond)
// Now should have fresh keys
v2, err := kf.Verifier()
if err != nil {
t.Fatalf("second Verifier call after refresh: %v", err)
}
if v2 == nil {
t.Fatal("expected non-nil verifier after refresh")
}
}
func TestKeyFetcher_RefreshedAt(t *testing.T) {
_, jwksData := testJWKS(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "max-age=3600")
w.Write(jwksData)
}))
defer srv.Close()
kf := &KeyFetcher{URL: srv.URL}
// Before any fetch
if !kf.RefreshedAt().IsZero() {
t.Error("RefreshedAt should be zero before first fetch")
}
before := time.Now()
_, err := kf.Verifier()
if err != nil {
t.Fatalf("Verifier: %v", err)
}
after := time.Now()
rat := kf.RefreshedAt()
if rat.Before(before) || rat.After(after) {
t.Errorf("RefreshedAt %v not between %v and %v", rat, before, after)
}
}
// --- cacheTimings tests ---
func TestCacheTimings_MaxAge(t *testing.T) {
now := time.Now()
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Cache-Control", "max-age=600")
p := defaultPolicy()
expiry, stale := cacheTimings(now, resp, p)
wantExpiry := now.Add(600 * time.Second)
wantStale := now.Add(600 * time.Second * 3 / 4)
if !timesClose(expiry, wantExpiry, time.Second) {
t.Errorf("expiry: got %v, want ~%v", expiry, wantExpiry)
}
if !timesClose(stale, wantStale, time.Second) {
t.Errorf("stale: got %v, want ~%v", stale, wantStale)
}
}
func TestCacheTimings_NoHeaders(t *testing.T) {
now := time.Now()
resp := &http.Response{Header: http.Header{}}
p := defaultPolicy()
expiry, stale := cacheTimings(now, resp, p)
wantExpiry := now.Add(defaultTTL)
wantStale := now.Add(defaultTTL * 3 / 4)
if !timesClose(expiry, wantExpiry, time.Second) {
t.Errorf("expiry: got %v, want ~%v", expiry, wantExpiry)
}
if !timesClose(stale, wantStale, time.Second) {
t.Errorf("stale: got %v, want ~%v", stale, wantStale)
}
}
func TestCacheTimings_BelowMinTTL(t *testing.T) {
now := time.Now()
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Cache-Control", "max-age=5") // 5s < 1m minTTL
p := defaultPolicy()
expiry, stale := cacheTimings(now, resp, p)
// Below min: expiry = minTTL*2, stale = minTTL
wantExpiry := now.Add(p.minTTL * 2)
wantStale := now.Add(p.minTTL)
if !timesClose(expiry, wantExpiry, time.Second) {
t.Errorf("expiry: got %v, want ~%v", expiry, wantExpiry)
}
if !timesClose(stale, wantStale, time.Second) {
t.Errorf("stale: got %v, want ~%v", stale, wantStale)
}
}
func TestCacheTimings_AboveMaxTTL(t *testing.T) {
now := time.Now()
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Cache-Control", "max-age=200000") // ~55h > 24h maxTTL
p := defaultPolicy()
expiry, stale := cacheTimings(now, resp, p)
wantExpiry := now.Add(p.maxTTL)
wantStale := now.Add(p.maxTTL * 3 / 4)
if !timesClose(expiry, wantExpiry, time.Second) {
t.Errorf("expiry: got %v, want ~%v", expiry, wantExpiry)
}
if !timesClose(stale, wantStale, time.Second) {
t.Errorf("stale: got %v, want ~%v", stale, wantStale)
}
}
func TestCacheTimings_AgeHeader(t *testing.T) {
now := time.Now()
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Cache-Control", "max-age=600")
resp.Header.Set("Age", "100")
p := defaultPolicy()
expiry, stale := cacheTimings(now, resp, p)
// Effective TTL = 600 - 100 = 500s
wantExpiry := now.Add(500 * time.Second)
wantStale := now.Add(500 * time.Second * 3 / 4)
if !timesClose(expiry, wantExpiry, time.Second) {
t.Errorf("expiry: got %v, want ~%v", expiry, wantExpiry)
}
if !timesClose(stale, wantStale, time.Second) {
t.Errorf("stale: got %v, want ~%v", stale, wantStale)
}
}
// --- Conditional request (304) tests ---
func TestConditionalRequest_304(t *testing.T) {
_, jwksData := testJWKS(t)
var requestCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
if r.Header.Get("If-None-Match") == `"test-etag"` {
w.Header().Set("Cache-Control", "max-age=600")
w.WriteHeader(http.StatusNotModified)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("ETag", `"test-etag"`)
w.Header().Set("Cache-Control", "max-age=600")
w.Write(jwksData)
}))
defer srv.Close()
p := defaultPolicy()
// First fetch: gets full body
a1, resp1, err := fetchRaw(context.Background(), srv.URL, nil, p, nil)
if err != nil {
t.Fatalf("first fetch: %v", err)
}
if resp1.StatusCode != http.StatusOK {
t.Fatalf("first fetch status: got %d, want 200", resp1.StatusCode)
}
if a1.etag != `"test-etag"` {
t.Errorf("etag not captured: got %q", a1.etag)
}
// Second fetch with prev: should get 304
a2, resp2, err := fetchRaw(context.Background(), srv.URL, nil, p, a1)
if err != nil {
t.Fatalf("conditional fetch: %v", err)
}
if resp2.StatusCode != http.StatusNotModified {
t.Fatalf("conditional fetch status: got %d, want 304", resp2.StatusCode)
}
// Body should be reused from prev
if string(a2.data) != string(a1.data) {
t.Error("304 response did not reuse previous body")
}
// Cache timing should be refreshed
if a2.expiry.Equal(a1.expiry) {
t.Error("304 response should have refreshed expiry")
}
if requestCount != 2 {
t.Errorf("expected 2 requests, got %d", requestCount)
}
}
func TestConditionalRequest_LastModified(t *testing.T) {
_, jwksData := testJWKS(t)
lastMod := "Wed, 01 Jan 2025 00:00:00 GMT"
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("If-Modified-Since") == lastMod {
w.Header().Set("Cache-Control", "max-age=600")
w.WriteHeader(http.StatusNotModified)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Last-Modified", lastMod)
w.Header().Set("Cache-Control", "max-age=600")
w.Write(jwksData)
}))
defer srv.Close()
p := defaultPolicy()
a1, _, err := fetchRaw(context.Background(), srv.URL, nil, p, nil)
if err != nil {
t.Fatalf("first fetch: %v", err)
}
if a1.lastModified != lastMod {
t.Errorf("lastModified not captured: got %q", a1.lastModified)
}
a2, resp2, err := fetchRaw(context.Background(), srv.URL, nil, p, a1)
if err != nil {
t.Fatalf("conditional fetch: %v", err)
}
if resp2.StatusCode != http.StatusNotModified {
t.Fatalf("conditional fetch status: got %d, want 304", resp2.StatusCode)
}
if string(a2.data) != string(a1.data) {
t.Error("304 response did not reuse previous body")
}
}
// --- parseCacheControlMaxAge tests ---
func TestParseCacheControlMaxAge(t *testing.T) {
tests := []struct {
header string
want time.Duration
}{
{"max-age=300", 300 * time.Second},
{"public, max-age=600", 600 * time.Second},
{"max-age=0", 0},
{"no-cache", 0},
{"", 0},
{"max-age=abc", 0},
{"max-age=-1", 0},
}
for _, tt := range tests {
got := parseCacheControlMaxAge(tt.header)
if got != tt.want {
t.Errorf("parseCacheControlMaxAge(%q) = %v, want %v", tt.header, got, tt.want)
}
}
}
// --- parseAge tests ---
func TestParseAge(t *testing.T) {
tests := []struct {
header string
want time.Duration
}{
{"100", 100 * time.Second},
{"0", 0},
{"-5", 0},
{"", 0},
{"abc", 0},
}
for _, tt := range tests {
got := parseAge(tt.header)
if got != tt.want {
t.Errorf("parseAge(%q) = %v, want %v", tt.header, got, tt.want)
}
}
}
// --- helpers ---
func errorContains(err, target error) bool {
return errors.Is(err, target)
}
func timesClose(a, b time.Time, tolerance time.Duration) bool {
diff := a.Sub(b)
if diff < 0 {
diff = -diff
}
return diff <= tolerance
}

View File

@ -0,0 +1,350 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
// Package keyfetch lazily fetches and caches JWKS keys from remote URLs.
//
// [KeyFetcher] returns a [jwt.Verifier] on demand, refreshing keys in the
// background when they expire. For one-shot fetches without caching, use
// [FetchURL], [FetchOIDC], or [FetchOAuth2].
package keyfetch
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/therootcompany/golib/auth/jwt"
)
// cachedVerifier bundles a [*jwt.Verifier] with its freshness window.
// Stored atomically in [KeyFetcher]; immutable after creation.
type cachedVerifier struct {
verifier *jwt.Verifier
staleAt time.Time // background refresh should start
hardExp time.Time // do not serve after this time
refreshedAt time.Time // when this verifier was fetched
}
// KeyFetcher lazily fetches and caches JWKS keys from a remote URL,
// returning a [*jwt.Verifier] on demand.
//
// Cache timing is derived from the server's Cache-Control headers with
// sane defaults: 15m when absent, floored at 2m, capped at 24h. The
// "stale" point (3/4 of expiry) triggers a background refresh while
// serving cached keys. See [cacheTimings] for the full policy.
//
// When cached keys are fresh (before stale time), [KeyFetcher.Verifier]
// returns immediately with no network call. When stale but not expired,
// the cached Verifier is returned immediately and a background refresh
// is started. Only when there are no cached keys does [KeyFetcher.Verifier]
// block until the first successful fetch.
//
// When keys are past their hard expiry, RefreshTimeout controls the
// behavior: if zero (default), Verifier blocks until the refresh completes.
// If positive, Verifier waits up to that duration and returns the expired
// keys if the refresh hasn't finished.
//
// InitialKeys, if set, pre-populate the cache as immediately expired on
// the first call to [KeyFetcher.Verifier]. Combined with a positive
// RefreshTimeout, they are served immediately while a background refresh
// fetches fresh keys - useful for bootstrapping auth without a blocking
// fetch at startup.
//
// There is no persistent background goroutine: refreshes are started on
// demand and run until the HTTP client's timeout fires or the fetch succeeds.
//
// Use [KeyFetcher.RefreshedAt] to detect when keys have been updated
// (e.g. to persist them to disk for faster restarts).
//
// Use [NewKeyFetcher] to validate the URL upfront. Fields must be set
// before the first call to [KeyFetcher.Verifier]; do not modify them
// concurrently.
//
// KeyFetcher is safe for concurrent use. Multiple goroutines may call
// [KeyFetcher.Verifier] simultaneously without additional synchronization.
//
// Typical usage:
//
// fetcher, err := keyfetch.NewKeyFetcher("https://accounts.example.com/.well-known/jwks.json")
// // ...
// v, err := fetcher.Verifier()
type KeyFetcher struct {
// URL is the JWKS endpoint to fetch keys from.
URL string
// RefreshTimeout controls how long Verifier waits for a refresh when
// cached keys are past their hard expiry. If zero (default), Verifier
// blocks until the fetch completes. If positive, Verifier waits up to
// this duration and returns expired keys if the fetch hasn't finished.
//
// Has no effect when keys are stale but not expired (Verifier always
// returns immediately in that case) or when no cached keys exist
// (the first fetch always blocks).
RefreshTimeout time.Duration
// HTTPClient is the HTTP client used for all JWKS fetches. If nil, a
// default client with a 30s timeout is used. Providing a reusable client
// enables TCP connection pooling across refreshes.
//
// The client's Timeout controls how long a refresh may run. A long value
// (e.g. 120s) is appropriate - JWKS fetching is not tied to individual
// request lifetimes and should be allowed to eventually succeed.
HTTPClient *http.Client
// MinTTL is the minimum cache duration. Server values below this are raised.
// The floor case uses MinTTL as the stale time and MinTTL*2 as expiry.
// Defaults to 1 minute.
MinTTL time.Duration
// MaxTTL is the maximum cache duration. Server values above this are clamped.
// Defaults to 24 hours.
MaxTTL time.Duration
// TTL is the cache duration used when the server provides no Cache-Control
// max-age header. Defaults to 15 minutes.
TTL time.Duration
// InitialKeys pre-populate the cache as immediately expired on the first
// call to Verifier. Combined with a positive RefreshTimeout, they are
// served while a background refresh fetches fresh keys.
InitialKeys []jwt.PublicKey
fetchMu sync.Mutex // held during HTTP fetch
ctrlMu sync.Mutex // held briefly for refreshing/lastErr
cached atomic.Pointer[cachedVerifier]
lastAsset *asset // previous fetch result for conditional requests; guarded by fetchMu
initOnce sync.Once
initErr error // stored by initOnce for subsequent callers
refreshing bool // true while a background refresh goroutine is running
refreshDone chan struct{} // closed when the current refresh completes
lastErr error // last background refresh error; cleared on success
}
// NewKeyFetcher creates a [KeyFetcher] with the given JWKS endpoint URL.
// Returns an error if the URL is not a valid absolute URL.
func NewKeyFetcher(jwksURL string) (*KeyFetcher, error) {
u, err := url.Parse(jwksURL)
if err != nil {
return nil, fmt.Errorf("keyfetch: invalid URL: %w", err)
}
if !u.IsAbs() {
return nil, fmt.Errorf("keyfetch: URL must be absolute: %q", jwksURL)
}
return &KeyFetcher{URL: jwksURL}, nil
}
// RefreshedAt returns the time the cached keys were last successfully
// fetched. Returns the zero time if no fetch has completed yet.
//
// Use this to detect when keys have changed - for example, to persist
// them to disk only when they've been updated.
func (f *KeyFetcher) RefreshedAt() time.Time {
ci := f.cached.Load()
if ci == nil {
return time.Time{}
}
return ci.refreshedAt
}
// Verifier returns a [*jwt.Verifier] for verifying tokens.
//
// Verifier intentionally does not take a [context.Context]: the background
// JWKS refresh must not be canceled when a single client request finishes
// or times out. The HTTP client's own Timeout (or a 30s default) bounds
// the fetch duration instead.
//
// Cache staleness and expiry are determined by the wall clock (time.Now).
// This is intentional: cache management is runtime infrastructure, not
// claim validation. For testable time-based claim checks, see
// [jwt.Validator.Validate] which accepts a caller-supplied time.
//
// Fresh (before stale time): returned immediately, no network call.
//
// Stale (past stale time, before hard expiry): returned immediately.
// A background refresh is started if one is not already running.
//
// Expired (past hard expiry, RefreshTimeout > 0): a refresh is started
// and Verifier waits up to RefreshTimeout. If the refresh completes in
// time, fresh keys are returned. Otherwise, expired keys are returned.
//
// No cache or expired with RefreshTimeout=0: blocks until a fetch completes.
func (f *KeyFetcher) Verifier() (*jwt.Verifier, error) {
if len(f.InitialKeys) > 0 {
f.initOnce.Do(func() {
v, err := jwt.NewVerifier(f.InitialKeys)
if err != nil {
f.initErr = err
return
}
now := time.Now()
ci := &cachedVerifier{
verifier: v,
staleAt: now, // immediately expired - triggers refresh
hardExp: now,
}
f.cached.Store(ci)
})
f.ctrlMu.Lock()
if f.initErr != nil {
err := f.initErr
f.initErr = nil // allow subsequent calls to fall through to fetch
f.ctrlMu.Unlock()
return nil, fmt.Errorf("InitialKeys: %w", err)
}
f.ctrlMu.Unlock()
}
now := time.Now()
ci := f.cached.Load()
// Fast path: fresh keys (before stale time).
if ci != nil && now.Before(ci.staleAt) {
return ci.verifier, nil
}
// Stale path: keys still valid, return immediately and refresh in
// the background. No lock contention - ensureRefreshing holds ctrlMu
// only briefly.
if ci != nil && now.Before(ci.hardExp) {
f.ensureRefreshing()
return ci.verifier, nil
}
// Expired with timeout: wait for the refresh, fall back to expired keys.
if ci != nil && f.RefreshTimeout > 0 {
done := f.ensureRefreshing()
timer := time.NewTimer(f.RefreshTimeout)
defer timer.Stop()
select {
case <-done:
if newCI := f.cached.Load(); newCI != nil && time.Now().Before(newCI.staleAt) {
return newCI.verifier, nil
}
case <-timer.C:
}
// Timeout or refresh failed - return expired keys with error.
f.ctrlMu.Lock()
lastErr := f.lastErr
f.ctrlMu.Unlock()
if lastErr != nil {
return ci.verifier, errors.Join(ErrKeysExpired, lastErr)
}
return ci.verifier, ErrKeysExpired
}
// Blocking path: no usable cache - wait for a fetch.
// fetchMu serializes concurrent blocking callers; the re-check prevents
// a redundant fetch if another goroutine already refreshed.
f.fetchMu.Lock()
defer f.fetchMu.Unlock()
// Re-check after acquiring lock - another goroutine may have refreshed.
now = time.Now()
if ci := f.cached.Load(); ci != nil && now.Before(ci.staleAt) {
return ci.verifier, nil
}
return f.fetch()
}
// ensureRefreshing starts a background refresh if one is not already running.
// Returns a channel that is closed when the current refresh completes.
func (f *KeyFetcher) ensureRefreshing() <-chan struct{} {
f.ctrlMu.Lock()
defer f.ctrlMu.Unlock()
if !f.refreshing {
f.refreshing = true
f.refreshDone = make(chan struct{})
go f.backgroundRefresh()
}
return f.refreshDone
}
// backgroundRefresh fetches fresh keys without blocking callers.
// It acquires fetchMu for the duration of the HTTP request so that any
// concurrent blocking callers wait for this fetch rather than issuing
// a redundant request.
func (f *KeyFetcher) backgroundRefresh() {
f.fetchMu.Lock()
defer f.fetchMu.Unlock()
// Re-check: a blocking caller may have already refreshed while we
// were waiting to acquire the lock.
now := time.Now()
if ci := f.cached.Load(); ci != nil && now.Before(ci.staleAt) {
f.ctrlMu.Lock()
f.refreshing = false
close(f.refreshDone)
f.ctrlMu.Unlock()
return
}
_, err := f.fetch()
f.ctrlMu.Lock()
f.refreshing = false
close(f.refreshDone)
if err != nil {
f.lastErr = err
} else {
f.lastErr = nil
}
f.ctrlMu.Unlock()
}
// policy returns a cachePolicy from the fetcher's fields, falling back
// to package defaults for zero values.
func (f *KeyFetcher) policy() cachePolicy {
p := defaultPolicy()
if f.MinTTL > 0 {
p.minTTL = f.MinTTL
}
if f.MaxTTL > 0 {
p.maxTTL = f.MaxTTL
}
if f.TTL > 0 {
p.defaultTTL = f.TTL
}
return p
}
// fetch performs the HTTP request and stores the result. Must be called with fetchMu held.
func (f *KeyFetcher) fetch() (*jwt.Verifier, error) {
// Apply a context timeout only when no HTTPClient timeout is set,
// avoiding a redundant double-timeout.
ctx := context.Background()
if f.HTTPClient == nil || f.HTTPClient.Timeout <= 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
defer cancel()
}
a, keys, err := fetchJWKS(ctx, f.URL, f.HTTPClient, f.policy(), f.lastAsset)
if err != nil {
return nil, fmt.Errorf("fetch JWKS from %s: %w", f.URL, err)
}
f.lastAsset = a
v, err := jwt.NewVerifier(keys)
if err != nil {
return nil, fmt.Errorf("fetch JWKS from %s: %w", f.URL, err)
}
ci := &cachedVerifier{
verifier: v,
staleAt: a.stale,
hardExp: a.expiry,
refreshedAt: time.Now(),
}
f.cached.Store(ci)
return ci.verifier, nil
}

260
auth/jwt/keyfile/keyfile.go Normal file
View File

@ -0,0 +1,260 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
// Package keyfile loads cryptographic keys from local files in JWK, PEM,
// or DER format. All functions auto-compute KID from the RFC 7638 thumbprint
// when not already set.
//
// The Load* functions accept a file path and read from the local filesystem.
// The Parse* functions accept raw bytes, suitable for use with [embed.FS]
// or any other byte source.
//
// For JWK JSON format, the Parse functions in the [jwt] package
// ([jwt.ParsePublicJWK], [jwt.ParsePrivateJWK], [jwt.ParseWellKnownJWKs])
// can also be used directly.
//
// For fetching keys from remote URLs, use [keyfetch.FetchURL] (JWKS endpoints)
// or fetch the bytes yourself and pass them to Parse*.
package keyfile
import (
"crypto"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"github.com/therootcompany/golib/auth/jwt"
"os"
)
// --- Parse functions (bytes => key) ---
// ParsePublicPEM parses a PEM-encoded public key (SPKI "PUBLIC KEY" or
// PKCS#1 "RSA PUBLIC KEY") into a [jwt.PublicKey] with auto-computed KID.
func ParsePublicPEM(data []byte) (*jwt.PublicKey, error) {
block, _ := pem.Decode(data)
if block == nil {
return nil, fmt.Errorf("no PEM block found: %w", jwt.ErrInvalidKey)
}
return parsePublicPEMBlock(block)
}
// ParsePrivatePEM parses a PEM-encoded private key (PKCS#8 "PRIVATE KEY",
// PKCS#1 "RSA PRIVATE KEY", or SEC 1 "EC PRIVATE KEY") into a
// [jwt.PrivateKey] with auto-computed KID.
func ParsePrivatePEM(data []byte) (*jwt.PrivateKey, error) {
block, _ := pem.Decode(data)
if block == nil {
return nil, fmt.Errorf("no PEM block found: %w", jwt.ErrInvalidKey)
}
return parsePrivatePEMBlock(block)
}
// ParsePublicDER parses a DER-encoded public key into a [jwt.PublicKey] with
// auto-computed KID. It tries SPKI (PKIX) first, then PKCS#1 RSA.
func ParsePublicDER(data []byte) (*jwt.PublicKey, error) {
// Try SPKI / PKIX (most common modern format).
if pub, err := x509.ParsePKIXPublicKey(data); err == nil {
return jwt.FromPublicKey(pub)
}
// Try PKCS#1 RSA.
if pub, err := x509.ParsePKCS1PublicKey(data); err == nil {
return jwt.FromPublicKey(pub)
}
return nil, fmt.Errorf("unrecognized DER public key encoding (tried PKIX, PKCS1): %w", jwt.ErrInvalidKey)
}
// ParsePrivateDER parses a DER-encoded private key into a [jwt.PrivateKey]
// with auto-computed KID. It tries PKCS#8 first, then SEC 1 EC, then PKCS#1 RSA.
func ParsePrivateDER(data []byte) (*jwt.PrivateKey, error) {
// Try PKCS#8 (most common modern format, any algorithm).
if key, err := x509.ParsePKCS8PrivateKey(data); err == nil {
signer, ok := key.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("PKCS#8 key does not implement crypto.Signer: %w", jwt.ErrUnsupportedKeyType)
}
return fullPrivateKey(signer)
}
// Try SEC 1 EC.
if key, err := x509.ParseECPrivateKey(data); err == nil {
return fullPrivateKey(key)
}
// Try PKCS#1 RSA.
if key, err := x509.ParsePKCS1PrivateKey(data); err == nil {
return fullPrivateKey(key)
}
return nil, fmt.Errorf("unrecognized DER private key encoding (tried PKCS8, EC, PKCS1): %w", jwt.ErrInvalidKey)
}
// --- Load functions (file path => key) ---
// LoadPublicJWK loads a single JWK from a local file.
func LoadPublicJWK(path string) (*jwt.PublicKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return jwt.ParsePublicJWK(data)
}
// LoadPrivateJWK loads a single private JWK from a local file.
func LoadPrivateJWK(path string) (*jwt.PrivateKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return jwt.ParsePrivateJWK(data)
}
// LoadWellKnownJWKs loads a JWKS document from a local file.
func LoadWellKnownJWKs(path string) (jwt.WellKnownJWKs, error) {
data, err := os.ReadFile(path)
if err != nil {
return jwt.WellKnownJWKs{}, err
}
return jwt.ParseWellKnownJWKs(data)
}
// --- Save functions (key => file) ---
// SavePublicJWK writes a single public key as a JWK JSON file.
// The file is created with mode 0644 (world-readable).
func SavePublicJWK(path string, pub *jwt.PublicKey) error {
data, err := json.Marshal(pub)
if err != nil {
return fmt.Errorf("marshal public JWK: %w", err)
}
data = append(data, '\n')
return os.WriteFile(path, data, 0644)
}
// SavePublicJWKs writes a JWKS document ({"keys": [...]}) as a JSON file.
// The file is created with mode 0644 (world-readable).
func SavePublicJWKs(path string, keys []jwt.PublicKey) error {
jwks := jwt.WellKnownJWKs{Keys: keys}
data, err := json.Marshal(jwks)
if err != nil {
return fmt.Errorf("marshal JWKS: %w", err)
}
data = append(data, '\n')
return os.WriteFile(path, data, 0644)
}
// SavePrivateJWK writes a single private key as a JWK JSON file.
// The file is created with mode 0600 (owner-only).
func SavePrivateJWK(path string, priv *jwt.PrivateKey) error {
data, err := json.Marshal(priv)
if err != nil {
return fmt.Errorf("marshal private JWK: %w", err)
}
data = append(data, '\n')
return os.WriteFile(path, data, 0600)
}
// LoadPublicPEM loads a PEM-encoded public key from a local file.
func LoadPublicPEM(path string) (*jwt.PublicKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return ParsePublicPEM(data)
}
// LoadPrivatePEM loads a PEM-encoded private key from a local file.
func LoadPrivatePEM(path string) (*jwt.PrivateKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return ParsePrivatePEM(data)
}
// LoadPublicDER loads a DER-encoded public key from a local file.
func LoadPublicDER(path string) (*jwt.PublicKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return ParsePublicDER(data)
}
// LoadPrivateDER loads a DER-encoded private key from a local file.
func LoadPrivateDER(path string) (*jwt.PrivateKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return ParsePrivateDER(data)
}
// --- Internal helpers ---
// parsePublicPEMBlock parses a decoded PEM block into a [jwt.PublicKey].
func parsePublicPEMBlock(block *pem.Block) (*jwt.PublicKey, error) {
switch block.Type {
case "PUBLIC KEY":
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse SPKI public key: %w: %w", jwt.ErrInvalidKey, err)
}
return jwt.FromPublicKey(pub)
case "RSA PUBLIC KEY":
pub, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse PKCS#1 public key: %w: %w", jwt.ErrInvalidKey, err)
}
return jwt.FromPublicKey(pub)
default:
return nil, fmt.Errorf("PEM block type %q: %w", block.Type, jwt.ErrUnsupportedFormat)
}
}
// parsePrivatePEMBlock parses a decoded PEM block into a [jwt.PrivateKey].
func parsePrivatePEMBlock(block *pem.Block) (*jwt.PrivateKey, error) {
switch block.Type {
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse PKCS#8 private key: %w: %w", jwt.ErrInvalidKey, err)
}
signer, ok := key.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("PKCS#8 key does not implement crypto.Signer: %w", jwt.ErrUnsupportedKeyType)
}
return fullPrivateKey(signer)
case "RSA PRIVATE KEY":
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse PKCS#1 private key: %w: %w", jwt.ErrInvalidKey, err)
}
return fullPrivateKey(key)
case "EC PRIVATE KEY":
key, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse SEC 1 EC private key: %w: %w", jwt.ErrInvalidKey, err)
}
return fullPrivateKey(key)
default:
return nil, fmt.Errorf("PEM block type %q: %w", block.Type, jwt.ErrUnsupportedFormat)
}
}
// fullPrivateKey wraps a crypto.Signer into a *PrivateKey with
// auto-computed KID (thumbprint) for file-loaded keys.
func fullPrivateKey(signer crypto.Signer) (*jwt.PrivateKey, error) {
pk, err := jwt.FromPrivateKey(signer, "")
if err != nil {
return nil, err
}
kid, err := pk.Thumbprint()
if err != nil {
return nil, fmt.Errorf("compute thumbprint: %w", err)
}
pk.KID = kid
return pk, nil
}

View File

@ -0,0 +1,570 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package keyfile_test
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"os"
"path/filepath"
"testing"
"github.com/therootcompany/golib/auth/jwt"
"github.com/therootcompany/golib/auth/jwt/keyfile"
)
func mustFromPrivateKey(t *testing.T, signer crypto.Signer) *jwt.PrivateKey {
t.Helper()
pk, err := jwt.FromPrivateKey(signer, "")
if err != nil {
t.Fatal(err)
}
return pk
}
// --- PEM round-trip tests ---
func TestParsePrivatePEM_Ed25519(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
pk, err := keyfile.ParsePrivatePEM(pemBytes)
if err != nil {
t.Fatalf("ParsePrivatePEM: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed from thumbprint")
}
if pk.Alg != "EdDSA" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "EdDSA")
}
if _, err := pk.PublicKey(); err != nil {
t.Error("should be able to derive public key from loaded private key")
}
}
func TestParsePrivatePEM_ECDSA(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalECPrivateKey(priv)
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
pk, err := keyfile.ParsePrivatePEM(pemBytes)
if err != nil {
t.Fatalf("ParsePrivatePEM: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed")
}
if pk.Alg != "ES256" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "ES256")
}
}
func TestParsePrivatePEM_RSA(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})
pk, err := keyfile.ParsePrivatePEM(pemBytes)
if err != nil {
t.Fatalf("ParsePrivatePEM: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed")
}
if pk.Alg != "RS256" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "RS256")
}
}
func TestParsePublicPEM(t *testing.T) {
_, pub, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalPKIXPublicKey(pub.Public())
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der})
pk, err := keyfile.ParsePublicPEM(pemBytes)
if err != nil {
t.Fatalf("ParsePublicPEM: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed")
}
if pk.Alg != "EdDSA" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "EdDSA")
}
}
func TestParsePrivatePEM_UnsupportedBlockType(t *testing.T) {
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("dummy")})
_, err := keyfile.ParsePrivatePEM(pemBytes)
if !errors.Is(err, jwt.ErrUnsupportedFormat) {
t.Fatalf("expected ErrUnsupportedFormat, got: %v", err)
}
}
func TestParsePrivatePEM_NoPEMBlock(t *testing.T) {
_, err := keyfile.ParsePrivatePEM([]byte("not pem data"))
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
}
// --- DER round-trip tests ---
func TestParsePrivateDER_PKCS8(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
t.Fatal(err)
}
pk, err := keyfile.ParsePrivateDER(der)
if err != nil {
t.Fatalf("ParsePrivateDER: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed")
}
if pk.Alg != "EdDSA" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "EdDSA")
}
}
func TestParsePublicDER_SPKI(t *testing.T) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
if err != nil {
t.Fatal(err)
}
pk, err := keyfile.ParsePublicDER(der)
if err != nil {
t.Fatalf("ParsePublicDER: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed")
}
if pk.Alg != "ES384" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "ES384")
}
}
func TestParsePrivateDER_InvalidData(t *testing.T) {
_, err := keyfile.ParsePrivateDER([]byte("not der data"))
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
}
// --- JWK parse wrapper tests ---
func TestParsePublicJWK(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
signer := mustFromPrivateKey(t, priv)
pub, err := signer.PublicKey()
if err != nil {
t.Fatal(err)
}
data, err := json.Marshal(pub)
if err != nil {
t.Fatal(err)
}
pk, err := jwt.ParsePublicJWK(data)
if err != nil {
t.Fatalf("ParsePublicJWK: %v", err)
}
if pk.KID == "" {
t.Error("KID should be set")
}
if pk.Key == nil {
t.Error("Key should be set")
}
}
func TestParsePrivateJWK(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pk, err := jwt.FromPrivateKey(priv, "test-kid")
if err != nil {
t.Fatal(err)
}
data, err := json.Marshal(pk)
if err != nil {
t.Fatal(err)
}
parsed, err := jwt.ParsePrivateJWK(data)
if err != nil {
t.Fatalf("ParsePrivateJWK: %v", err)
}
if parsed.KID != "test-kid" {
t.Errorf("KID: got %q, want %q", parsed.KID, "test-kid")
}
if _, err := parsed.PublicKey(); err != nil {
t.Error("should be able to derive public key from parsed private key")
}
}
// --- Load function tests (file-based) ---
func TestLoadPrivatePEM_FromFile(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
path := filepath.Join(t.TempDir(), "key.pem")
if err := os.WriteFile(path, pemBytes, 0600); err != nil {
t.Fatal(err)
}
pk, err := keyfile.LoadPrivatePEM(path)
if err != nil {
t.Fatalf("LoadPrivatePEM: %v", err)
}
if pk.KID == "" {
t.Error("KID should be auto-computed")
}
if pk.Alg != "EdDSA" {
t.Errorf("Alg: got %q, want %q", pk.Alg, "EdDSA")
}
}
func TestLoadPublicJWK_FromFile(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
signer := mustFromPrivateKey(t, priv)
pub, err := signer.PublicKey()
if err != nil {
t.Fatal(err)
}
data, err := json.Marshal(pub)
if err != nil {
t.Fatal(err)
}
path := filepath.Join(t.TempDir(), "key.jwk")
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(err)
}
pk, err := keyfile.LoadPublicJWK(path)
if err != nil {
t.Fatalf("LoadPublicJWK: %v", err)
}
if pk.KID == "" {
t.Error("KID should be set")
}
}
// --- LoadWellKnownJWKs tests ---
func TestLoadWellKnownJWKs_SingleJWK(t *testing.T) {
// A JWKS with a single key should parse and return one key.
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pub, err := jwt.FromPublicKey(priv.Public())
if err != nil {
t.Fatal(err)
}
jwks := jwt.WellKnownJWKs{Keys: []jwt.PublicKey{*pub}}
data, err := json.Marshal(jwks)
if err != nil {
t.Fatal(err)
}
path := filepath.Join(t.TempDir(), "jwks.json")
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(err)
}
loaded, err := keyfile.LoadWellKnownJWKs(path)
if err != nil {
t.Fatalf("LoadPublicJWKs: %v", err)
}
if len(loaded.Keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(loaded.Keys))
}
if loaded.Keys[0].KID != pub.KID {
t.Errorf("KID: got %q, want %q", loaded.Keys[0].KID, pub.KID)
}
}
func TestLoadWellKnownJWKs_MultipleKeys(t *testing.T) {
// A JWKS with multiple keys of different types.
_, edPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
edPub, err := jwt.FromPublicKey(edPriv.Public())
if err != nil {
t.Fatal(err)
}
ecPriv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
ecPub, err := jwt.FromPublicKey(&ecPriv.PublicKey)
if err != nil {
t.Fatal(err)
}
rsaPriv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
rsaPub, err := jwt.FromPublicKey(&rsaPriv.PublicKey)
if err != nil {
t.Fatal(err)
}
jwks := jwt.WellKnownJWKs{Keys: []jwt.PublicKey{*edPub, *ecPub, *rsaPub}}
data, err := json.Marshal(jwks)
if err != nil {
t.Fatal(err)
}
path := filepath.Join(t.TempDir(), "jwks.json")
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(err)
}
loaded, err := keyfile.LoadWellKnownJWKs(path)
if err != nil {
t.Fatalf("LoadPublicJWKs: %v", err)
}
if len(loaded.Keys) != 3 {
t.Fatalf("expected 3 keys, got %d", len(loaded.Keys))
}
// Verify each key retained its KID.
wantKIDs := []string{edPub.KID, ecPub.KID, rsaPub.KID}
for i, want := range wantKIDs {
if loaded.Keys[i].KID != want {
t.Errorf("key[%d] KID: got %q, want %q", i, loaded.Keys[i].KID, want)
}
}
}
func TestLoadWellKnownJWKs_PEMFile(t *testing.T) {
// LoadWellKnownJWKs expects JWKS JSON, so a PEM file should fail.
_, pub, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
der, err := x509.MarshalPKIXPublicKey(pub.Public())
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der})
path := filepath.Join(t.TempDir(), "key.pem")
if err := os.WriteFile(path, pemBytes, 0644); err != nil {
t.Fatal(err)
}
_, err = keyfile.LoadWellKnownJWKs(path)
if err == nil {
t.Fatal("expected error for PEM file, got nil")
}
}
func TestLoadWellKnownJWKs_FileNotFound(t *testing.T) {
_, err := keyfile.LoadWellKnownJWKs(filepath.Join(t.TempDir(), "nonexistent.json"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("expected os.ErrNotExist, got: %v", err)
}
}
func TestLoadWellKnownJWKs_InvalidContent(t *testing.T) {
path := filepath.Join(t.TempDir(), "bad.json")
if err := os.WriteFile(path, []byte("this is not json"), 0644); err != nil {
t.Fatal(err)
}
_, err := keyfile.LoadWellKnownJWKs(path)
if err == nil {
t.Fatal("expected error for corrupt content, got nil")
}
}
// --- Save round-trip tests ---
func TestSavePublicJWK_RoundTrip(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pub, err := jwt.FromPublicKey(priv.Public())
if err != nil {
t.Fatal(err)
}
path := filepath.Join(t.TempDir(), "pub.jwk")
if err := keyfile.SavePublicJWK(path, pub); err != nil {
t.Fatalf("SavePublicJWK: %v", err)
}
loaded, err := keyfile.LoadPublicJWK(path)
if err != nil {
t.Fatalf("LoadPublicJWK: %v", err)
}
if loaded.KID != pub.KID {
t.Errorf("KID mismatch: saved %q, loaded %q", pub.KID, loaded.KID)
}
// Verify file permissions are world-readable.
info, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if perm := info.Mode().Perm(); perm != 0644 {
t.Errorf("file mode: got %o, want 0644", perm)
}
}
func TestSavePrivateJWK_RoundTrip(t *testing.T) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
pk, err := jwt.FromPrivateKey(priv, "")
if err != nil {
t.Fatal(err)
}
kid, err := pk.Thumbprint()
if err != nil {
t.Fatal(err)
}
pk.KID = kid
path := filepath.Join(t.TempDir(), "priv.jwk")
if err := keyfile.SavePrivateJWK(path, pk); err != nil {
t.Fatalf("SavePrivateJWK: %v", err)
}
loaded, err := keyfile.LoadPrivateJWK(path)
if err != nil {
t.Fatalf("LoadPrivateJWK: %v", err)
}
if loaded.KID != pk.KID {
t.Errorf("KID mismatch: saved %q, loaded %q", pk.KID, loaded.KID)
}
// Verify file permissions are owner-only.
info, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if perm := info.Mode().Perm(); perm != 0600 {
t.Errorf("file mode: got %o, want 0600", perm)
}
}
// --- KID consistency test ---
func TestKIDConsistency_PEM_vs_JWK(t *testing.T) {
// Verify that a key loaded from PEM gets the same KID as the same key
// loaded from JWK format.
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
// Load via PEM.
der, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
t.Fatal(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
pemKey, err := keyfile.ParsePrivatePEM(pemBytes)
if err != nil {
t.Fatalf("ParsePrivatePEM: %v", err)
}
// Load via JWK.
jwkPK, err := jwt.FromPrivateKey(priv, "")
if err != nil {
t.Fatal(err)
}
jwkJSON, err := json.Marshal(jwkPK)
if err != nil {
t.Fatal(err)
}
jwkKey, err := jwt.ParsePrivateJWK(jwkJSON)
if err != nil {
t.Fatalf("ParsePrivateJWK: %v", err)
}
if pemKey.KID != jwkKey.KID {
t.Errorf("KID mismatch: PEM=%q, JWK=%q", pemKey.KID, jwkKey.KID)
}
}

View File

@ -1,203 +0,0 @@
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"os"
"time"
)
type PublicKey interface {
Equal(x crypto.PublicKey) bool
}
// PublicJWK represents a parsed public key (RSA or ECDSA)
type PublicJWK struct {
PublicKey
KID string
Use string
}
// PublicJWKJSON represents a JSON Web Key as defined in the provided code
type PublicJWKJSON struct {
Kty string `json:"kty"`
KID string `json:"kid"`
N string `json:"n,omitempty"` // RSA modulus
E string `json:"e,omitempty"` // RSA exponent
Crv string `json:"crv,omitempty"`
X string `json:"x,omitempty"`
Y string `json:"y,omitempty"`
Use string `json:"use,omitempty"`
}
type JWKsJSON struct {
Keys []PublicJWKJSON `json:"keys"`
}
func UnmarshalPublicJWKs(data []byte) ([]PublicJWK, error) {
var jwks JWKsJSON
if err := json.Unmarshal(data, &jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
}
pubkeys, err := DecodePublicJWKsJSON(jwks)
if err != nil {
return nil, err
}
return pubkeys, nil
}
func DecodePublicJWKs(r io.Reader) ([]PublicJWK, error) {
var jwks JWKsJSON
if err := json.NewDecoder(r).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to parse JWKS JSON: %w", err)
}
pubkeys, err := DecodePublicJWKsJSON(jwks)
if err != nil {
return nil, err
}
return pubkeys, nil
}
// DecodePublicJWKsJSON parses JWKS from a Reader
func DecodePublicJWKsJSON(jwks JWKsJSON) ([]PublicJWK, error) {
// Process keys
var publicKeys []PublicJWK
for _, jwk := range jwks.Keys {
publicKey, err := DecodePublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse public jwk '%s': %w", jwk.KID, err)
}
publicKeys = append(publicKeys, *publicKey)
}
if len(publicKeys) == 0 {
return nil, fmt.Errorf("no valid RSA or ECDSA keys found")
}
return publicKeys, nil
}
// DecodePublicJWK parses JWKS from a Reader
func DecodePublicJWK(jwk PublicJWKJSON) (*PublicJWK, error) {
switch jwk.Kty {
case "RSA":
key, err := decodeRSAPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse RSA key '%s': %w", jwk.KID, err)
}
// Ensure RSA key meets minimum size requirement
if key.Size() < 128 { // 1024 bits / 8 = 128 bytes
return nil, fmt.Errorf("RSA key '%s' too small: %d bytes", jwk.KID, key.Size())
}
return &PublicJWK{PublicKey: key, KID: jwk.KID, Use: jwk.Use}, nil
case "EC":
key, err := decodeECDSAPublicJWK(jwk)
if err != nil {
return nil, fmt.Errorf("failed to parse EC key '%s': %w", jwk.KID, err)
}
return &PublicJWK{KID: jwk.KID, PublicKey: key, Use: jwk.Use}, nil
default:
return nil, fmt.Errorf("failed to parse unknown key type '%s': %s", jwk.Kty, jwk.KID)
}
}
// ReadPublicJWKs reads and parses JWKS from a file
func ReadPublicJWKs(filePath string) ([]PublicJWK, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("failed to open JWKS file '%s': %w", filePath, err)
}
defer func() { _ = file.Close() }()
return DecodePublicJWKs(file)
}
// FetchPublicJWKs retrieves and parses JWKS from a given URL
func FetchPublicJWKs(url string) ([]PublicJWK, error) {
// Set up HTTP client with timeout
client := &http.Client{
Timeout: 10 * time.Second,
}
// Make HTTP request
resp, err := client.Get(url)
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// Check response status
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
return DecodePublicJWKs(resp.Body)
}
// decodeRSAPublicJWK parses an RSA public key from a JWK
func decodeRSAPublicJWK(jwk PublicJWKJSON) (*rsa.PublicKey, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("invalid RSA modulus: %w", err)
}
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("invalid RSA exponent: %w", err)
}
// Convert exponent to int
eInt := new(big.Int).SetBytes(e).Int64()
if eInt > int64(^uint(0)>>1) || eInt < 0 {
return nil, fmt.Errorf("RSA exponent too large or negative")
}
return &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(eInt),
}, nil
}
// decodeECDSAPublicJWK parses an ECDSA public key from a JWK
func decodeECDSAPublicJWK(jwk PublicJWKJSON) (*ecdsa.PublicKey, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("invalid ECDSA X: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("invalid ECDSA Y: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported ECDSA curve: %s", jwk.Crv)
}
return &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}, nil
}

View File

@ -1,51 +0,0 @@
package jwt
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/base64"
"math/big"
"testing"
)
// TestDecodeJWKsJSON tests parsing a specific set of ECDSA P-256 JWKS
func TestDecodeJWKJSON(t *testing.T) {
// Create a temporary file with the test JWKS
kid := "KGx1KSmDRd_dwuwmZmWiEsl9Dh4c5dQtFLLtTl-UvlI"
jwkX := "WVBcjUpllgeGbGavZ9Bbq4ps3Zk73mgRRPpbfebkC3U"
jwkY := "aTmrRia2eiJsJwzuj7DIUVmMVGrjEzQJkxxiQMgVLOw"
jwkUse := "sig"
jwksJSON := []byte(`{"keys":[{"kty":"EC","crv":"P-256","x":"` + jwkX + `","y":"` + jwkY + `","kid":"` + kid + `","use":"` + jwkUse + `"}]}`)
// Decode from bytes to JSON to Public JWKs
keys, err := UnmarshalPublicJWKs(jwksJSON)
if err != nil {
t.Fatalf("ReadJWKs failed: %v", err)
}
// Verify results
if len(keys) != 1 {
t.Errorf("Expected 1 key, got %d", len(keys))
}
key := keys[0]
if key.KID != kid {
t.Errorf("Expected KID '%s', got '%s'", kid, key.KID)
}
if key.Use != jwkUse {
t.Errorf("Expected Use 'sig', got '%s'", key.Use)
}
expectedX, _ := base64.RawURLEncoding.DecodeString(jwkX)
expectedY, _ := base64.RawURLEncoding.DecodeString(jwkY)
// Verify Equal method
sameKey := &ecdsa.PublicKey{
Curve: elliptic.P256(),
X: new(big.Int).SetBytes(expectedX),
Y: new(big.Int).SetBytes(expectedY),
}
if !key.Equal(sameKey) {
t.Errorf("Equal method failed: key should equal itself")
}
}

347
auth/jwt/sign.go Normal file
View File

@ -0,0 +1,347 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"crypto"
"encoding/base64"
"encoding/json"
"fmt"
"sync/atomic"
)
// Signer manages one or more private signing keys and issues JWTs by
// round-robining across them. It is the issuing side of a JWT issuer -
// the party that signs tokens with a private key and publishes the
// corresponding public keys.
//
// Signer has [WellKnownJWKs], so the JWKS endpoint response is just:
//
// json.Marshal(&signer.WellKnownJWKs)
//
// The embedded WellKnownJWKs includes both the active signing keys' public keys
// and any RetiredKeys passed to [NewSigner]. Retired keys appear in the
// JWKS endpoint so that relying parties can still verify tokens signed
// before rotation, but they are never used for signing.
//
// Do not copy a Signer after first use - it contains an atomic counter.
type Signer struct {
WellKnownJWKs // Keys []PublicKey - promoted; marshals as {"keys":[...]}.
// Note: Keys is exported because json.Marshal needs it for the JWKS
// endpoint. Callers should not mutate the slice after construction.
keys []PrivateKey
signerIdx atomic.Uint64
}
// NewSigner creates a Signer from the provided signing keys.
//
// NewSigner normalises each key:
// - Alg: derived from the key type (ES256/ES384/ES512/RS256/EdDSA).
// Returns an error if the caller set an incompatible Alg.
// - Use: defaults to "sig" if empty; returns an error if set to anything else.
// - KID: auto-computed from the RFC 7638 thumbprint if empty.
//
// retiredKeys are public keys that appear in the JWKS endpoint for
// verification by relying parties but are no longer used for signing.
// This supports graceful key rotation: retire old keys so tokens signed
// before the rotation remain verifiable.
//
// Returns an error if the slice is empty, any key has no Signer,
// the key type is unsupported, or a thumbprint cannot be computed.
//
// https://www.rfc-editor.org/rfc/rfc7638.html
func NewSigner(keys []*PrivateKey, retiredKeys ...PublicKey) (*Signer, error) {
if len(keys) == 0 {
return nil, fmt.Errorf("NewSigner: %w", ErrNoSigningKey)
}
// Copy so the caller can't mutate after construction.
ss := make([]PrivateKey, len(keys))
for i, k := range keys {
if k == nil || k.privKey == nil {
return nil, fmt.Errorf("NewSigner: key[%d]: %w", i, ErrNoSigningKey)
}
ss[i] = *k
// Derive algorithm from key type; validate caller's Alg if already set.
alg, _, _, err := signingParams(ss[i].privKey)
if err != nil {
return nil, fmt.Errorf("NewSigner: key[%d]: %w", i, err)
}
if ss[i].Alg != "" && ss[i].Alg != alg {
return nil, fmt.Errorf("NewSigner: key[%d] alg %q expected %s: %w", i, ss[i].Alg, alg, ErrAlgConflict)
}
ss[i].Alg = alg
// Default Use to "sig" for signing keys; reject anything else.
if ss[i].Use == "" {
ss[i].Use = "sig"
} else if ss[i].Use != "sig" {
return nil, fmt.Errorf("NewSigner: key[%d] kid %q: use %q, want \"sig\"", i, ss[i].KID, ss[i].Use)
}
// Auto-compute KID from thumbprint if empty.
if ss[i].KID == "" {
thumb, err := ss[i].Thumbprint()
if err != nil {
return nil, fmt.Errorf("NewSigner: compute thumbprint for key[%d]: %w", i, err)
}
ss[i].KID = thumb
}
}
pubs := make([]PublicKey, len(ss), len(ss)+len(retiredKeys))
for i := range ss {
pub, err := ss[i].PublicKey()
if err != nil {
return nil, fmt.Errorf("NewSigner: key[%d] kid %q: %w", i, ss[i].KID, err)
}
pubs[i] = *pub
}
// Validate each key by performing a test sign+verify round-trip.
// This catches bad keys at construction rather than first use.
for i := range ss {
if err := validateSigningKey(&ss[i], &pubs[i]); err != nil {
return nil, fmt.Errorf("NewSigner: key[%d] kid %q: %w", i, ss[i].KID, err)
}
}
// Append retired keys so they appear in the JWKS endpoint but are
// never selected for signing.
pubs = append(pubs, retiredKeys...)
return &Signer{
WellKnownJWKs: WellKnownJWKs{Keys: pubs},
keys: ss,
}, nil
}
// nextKey returns the next signing key in round-robin order.
func (s *Signer) nextKey() *PrivateKey {
n := uint64(len(s.keys))
var idx uint64
for {
cur := s.signerIdx.Load()
next := (cur + 1) % n
if s.signerIdx.CompareAndSwap(cur, next) {
idx = cur
break
}
}
return &s.keys[idx]
}
// SignJWT signs jws in-place.
//
// Key selection: if the header already has a KID, the signer uses the key
// with that KID (returning [ErrUnknownKID] if none match). Otherwise the
// next key in round-robin order is selected and its KID is written into
// the header.
//
// The alg header field is set automatically from the selected key.
//
// Use this when you need the full signed JWT for further processing
// (e.g., inspecting headers before encoding). For the common one-step cases,
// prefer [Signer.Sign] or [Signer.SignToString].
func (s *Signer) SignJWT(jws SignableJWT) error {
hdr := jws.GetHeader()
var pk *PrivateKey
if hdr.KID != "" {
for i := range s.keys {
if s.keys[i].KID == hdr.KID {
pk = &s.keys[i]
break
}
}
if pk == nil {
return fmt.Errorf("kid %q: %w", hdr.KID, ErrUnknownKID)
}
} else {
pk = s.nextKey()
hdr.KID = pk.KID
}
if pk.privKey == nil {
return fmt.Errorf("kid %q: %w", pk.KID, ErrNoSigningKey)
}
alg, hash, ecKeySize, err := signingParams(pk.privKey)
if err != nil {
return err
}
// Validate and set header algorithm.
if hdr.Alg != "" && hdr.Alg != alg {
return fmt.Errorf("key %s vs header %q: %w", alg, hdr.Alg, ErrAlgConflict)
}
hdr.Alg = alg
if err := jws.SetHeader(&hdr); err != nil {
return err
}
input := signingInputBytes(jws.GetProtected(), jws.GetPayload())
sig, err := signBytes(pk.privKey, alg, hash, ecKeySize, input)
if err != nil {
return err
}
jws.SetSignature(sig)
return nil
}
// SignRaw signs an arbitrary protected header and payload, returning
// the result as a [*RawJWT] suitable for [json.Marshal] (flattened JWS)
// or [Encode] (compact serialization).
//
// Unlike [Signer.SignJWT], SignRaw does not set or validate the KID
// field -- the caller controls it entirely. This supports protocols
// like ACME (RFC 8555) where the kid is an account URL, or where kid
// must be absent (newAccount uses jwk instead).
//
// The alg field is always set from the key type. If hdr already has a
// non-empty Alg that conflicts with the key, SignRaw returns an error.
//
// payload is the raw bytes to encode as the JWS payload. A nil payload
// produces an empty payload segment (used by ACME POST-as-GET).
func (s *Signer) SignRaw(hdr Header, payload []byte) (*RawJWT, error) {
pk := s.nextKey()
if pk.privKey == nil {
return nil, fmt.Errorf("kid %q: %w", pk.KID, ErrNoSigningKey)
}
rfc := hdr.GetRFCHeader()
alg, hash, ecKeySize, err := signingParams(pk.privKey)
if err != nil {
return nil, err
}
if rfc.Alg != "" && rfc.Alg != alg {
return nil, fmt.Errorf("key %s vs header %q: %w", alg, rfc.Alg, ErrAlgConflict)
}
rfc.Alg = alg
headerJSON, err := json.Marshal(hdr)
if err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
protectedB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
payloadB64 := base64.RawURLEncoding.EncodeToString(payload)
input := signingInputBytes([]byte(protectedB64), []byte(payloadB64))
sig, err := signBytes(pk.privKey, alg, hash, ecKeySize, input)
if err != nil {
return nil, err
}
return &RawJWT{
Protected: []byte(protectedB64),
Payload: []byte(payloadB64),
Signature: sig,
}, nil
}
// Sign creates a JWT from claims, signs it with the next signing key,
// and returns the signed JWT.
//
// Use this when you need access to the signed JWT object (e.g., to inspect
// headers or read the raw signature). For the common case of producing a
// compact token string, use [Signer.SignToString].
func (s *Signer) Sign(claims Claims) (*JWT, error) {
jws, err := New(claims)
if err != nil {
return nil, err
}
if err := s.SignJWT(jws); err != nil {
return nil, err
}
return jws, nil
}
// SignToString creates and signs a JWT from claims and returns the compact
// token string (header.payload.signature).
//
// This is the most convenient form for the common case of signing and
// immediately transmitting a token. The caller is responsible for setting
// the "iss" field in claims if issuer identification is needed.
func (s *Signer) SignToString(claims Claims) (string, error) {
jws, err := s.Sign(claims)
if err != nil {
return "", err
}
return Encode(jws)
}
// Verifier returns a new [*Verifier] containing the public keys of all
// signing keys plus any retired keys passed to [NewSigner].
//
// Panics if NewVerifier fails, which indicates an invariant violation
// since [NewSigner] already validated the keys.
func (s *Signer) Verifier() *Verifier {
v, err := NewVerifier(s.Keys)
if err != nil {
panic(fmt.Sprintf("jwt: Signer.Verifier: NewVerifier failed on previously validated keys: %v", err))
}
return v
}
// signBytes signs input using the given crypto.Signer with the appropriate
// hash and ECDSA DER-to-P1363 conversion. It handles pre-hashing for EC/RSA
// and raw signing for Ed25519.
func signBytes(signer crypto.Signer, alg string, hash crypto.Hash, ecKeySize int, input []byte) ([]byte, error) {
var sig []byte
var err error
if hash != 0 {
digest, derr := digestFor(hash, input)
if derr != nil {
return nil, derr
}
sig, err = signer.Sign(nil, digest, hash)
} else {
sig, err = signer.Sign(nil, input, crypto.Hash(0))
}
if err != nil {
return nil, fmt.Errorf("sign %s: %w", alg, err)
}
// ECDSA: crypto.Signer returns ASN.1 DER, but JWS (RFC 7515 §A.3)
// requires IEEE P1363 format (raw r||s concatenation).
if ecKeySize > 0 {
sig, err = ecdsaDERToP1363(sig, ecKeySize)
if err != nil {
return nil, err
}
}
return sig, nil
}
// validateSigningKey performs a test sign+verify round-trip to catch bad
// keys at construction time rather than on first use.
func validateSigningKey(pk *PrivateKey, pub *PublicKey) error {
alg, hash, ecKeySize, err := signingParams(pk.privKey)
if err != nil {
return err
}
testInput := []byte("jwt-key-validation")
sig, err := signBytes(pk.privKey, alg, hash, ecKeySize, testInput)
if err != nil {
return fmt.Errorf("test sign: %w", err)
}
// Verify against the public key.
h := RFCHeader{Alg: alg, KID: pk.KID}
if err := verifyOneKey(h, pub.Key, testInput, sig); err != nil {
return fmt.Errorf("test verify: %w", err)
}
return nil
}

160
auth/jwt/types.go Normal file
View File

@ -0,0 +1,160 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"encoding/json"
"fmt"
"strings"
)
// JOSE "typ" header values. The signer sets [DefaultTokenTyp] automatically.
// Use [NewAccessToken] or [JWT.SetTyp] to produce an OAuth 2.1 access token
// with [AccessTokenTyp].
const (
DefaultTokenTyp = "JWT" // standard JWT
AccessTokenTyp = "at+jwt" // OAuth 2.1 access token (RFC 9068 §2.1)
)
// Listish handles the JWT "aud" claim quirk: RFC 7519 §4.1.3 allows
// it to be either a single string or an array of strings.
//
// It unmarshals from both a single string ("https://auth.example.com") and
// an array (["https://api.example.com", "https://app.example.com"]).
// It marshals to a plain string for a single value and to an array for
// multiple values.
//
// https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3
type Listish []string
// UnmarshalJSON decodes both the string and []string forms of the "aud" claim.
// An empty string unmarshals to an empty (non-nil) slice, round-tripping with
// [Listish.MarshalJSON].
func (l *Listish) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
if s == "" {
*l = Listish{}
return nil
}
*l = Listish{s}
return nil
}
var ss []string
if err := json.Unmarshal(data, &ss); err != nil {
return fmt.Errorf("aud must be a string or array of strings: %w: %w", ErrInvalidPayload, err)
}
*l = ss
return nil
}
// IsZero reports whether the list is empty (nil or zero-length).
// Used by encoding/json with the omitzero tag option.
func (l Listish) IsZero() bool { return len(l) == 0 }
// MarshalJSON encodes the list as a plain string when there is one
// value, or as a JSON array for multiple values. An empty or nil Listish
// marshals as JSON null.
func (l Listish) MarshalJSON() ([]byte, error) {
switch len(l) {
case 0:
return []byte("null"), nil
case 1:
return json.Marshal(l[0])
default:
return json.Marshal([]string(l))
}
}
// SpaceDelimited is a slice of strings that marshals as a single
// space-separated string in JSON, per RFC 6749 §3.3.
//
// It has three-state semantics:
// - nil: absent - the field is not present (omitted via omitzero)
// - non-nil empty (SpaceDelimited{}): present but empty - marshals as ""
// - populated (SpaceDelimited{"openid", "profile"}): marshals as "openid profile"
//
// UnmarshalJSON decodes a space-separated string back into a slice,
// preserving the distinction between nil (absent) and empty non-nil (present as "").
//
// https://www.rfc-editor.org/rfc/rfc6749.html#section-3.3
type SpaceDelimited []string
// UnmarshalJSON decodes a space-separated string into a slice.
// An empty string "" unmarshals to a non-nil empty SpaceDelimited{},
// preserving the distinction from a nil (absent) SpaceDelimited.
func (s *SpaceDelimited) UnmarshalJSON(data []byte) error {
var str string
if err := json.Unmarshal(data, &str); err != nil {
return fmt.Errorf("space-delimited must be a string: %w: %w", ErrInvalidPayload, err)
}
if str == "" {
*s = SpaceDelimited{}
return nil
}
*s = strings.Fields(str)
return nil
}
// IsZero reports whether the slice is absent (nil).
// Used by encoding/json with the omitzero tag option to omit the field
// when it is nil, while still marshaling a non-nil empty SpaceDelimited as "".
func (s SpaceDelimited) IsZero() bool { return s == nil }
// MarshalJSON encodes the slice as a single space-separated string.
// A nil SpaceDelimited marshals as JSON null (but is typically omitted via omitzero).
// A non-nil empty SpaceDelimited marshals as the empty string "".
func (s SpaceDelimited) MarshalJSON() ([]byte, error) {
if s == nil {
return []byte("null"), nil
}
return json.Marshal(strings.Join(s, " "))
}
// NullBool represents a boolean that can be null, true, or false.
// Used for OIDC *_verified fields where null means "not applicable"
// (the corresponding value is absent), false means "present but not
// verified", and true means "verified".
type NullBool struct {
Bool bool
Valid bool // Valid is true if Bool is not NULL
}
// IsZero reports whether nb is the zero value (not valid).
// Used by encoding/json with the omitzero tag option.
func (nb NullBool) IsZero() bool { return !nb.Valid }
// MarshalJSON encodes the NullBool as JSON. If !Valid, it outputs null;
// otherwise it outputs true or false.
func (nb NullBool) MarshalJSON() ([]byte, error) {
if !nb.Valid {
return []byte("null"), nil
}
if nb.Bool {
return []byte("true"), nil
}
return []byte("false"), nil
}
// UnmarshalJSON decodes a JSON value into a NullBool.
// null -> {false, false}; true -> {true, true}; false -> {false, true}.
func (nb *NullBool) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
nb.Bool = false
nb.Valid = false
return nil
}
var b bool
if err := json.Unmarshal(data, &b); err != nil {
return err
}
nb.Bool = b
nb.Valid = true
return nil
}

653
auth/jwt/validate.go Normal file
View File

@ -0,0 +1,653 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"errors"
"fmt"
"slices"
"strings"
"time"
)
// ValidationError represents a single claim validation failure with a
// machine-readable code suitable for API responses.
//
// Code values and their meanings:
//
// token_expired - exp claim is in the past
// token_not_yet_valid - nbf claim is in the future
// future_issued_at - iat claim is in the future
// future_auth_time - auth_time claim is in the future
// auth_time_exceeded - auth_time exceeds max age
// insufficient_scope - required scopes not granted
// missing_claim - a required claim is absent
// invalid_claim - a claim value is wrong (bad iss, aud, etc.)
// server_error - server-side validator config error (treat as 500)
// unknown_error - unrecognized sentinel (should not occur)
//
// ValidationError satisfies [error] and supports [errors.Is] via [Unwrap]
// against the underlying sentinel (e.g., [ErrAfterExp], [ErrMissingClaim]).
//
// JSON serialization produces {"code": "...", "description": "..."}
// for direct use in API error responses.
//
// Use [ValidationErrors] to extract these from the error returned by
// [Validator.Validate]:
//
// err := v.Validate(nil, &claims, time.Now())
// for _, ve := range jwt.ValidationErrors(err) {
// log.Printf("code=%s: %s", ve.Code, ve.Description)
// }
type ValidationError struct {
Code string `json:"code"` // machine-readable code (see table above)
Description string `json:"description"` // human-readable detail, prefixed with claim name
Err error `json:"-"` // sentinel for errors.Is / Unwrap
}
// Error implements [error]. Returns the human-readable description.
func (e *ValidationError) Error() string { return e.Description }
// Unwrap returns the underlying sentinel error for use with [errors.Is].
func (e *ValidationError) Unwrap() error { return e.Err }
// ValidationErrors extracts structured [*ValidationError] values from the
// error returned by [Validator.Validate] or [TokenClaims.Errors].
//
// Non-ValidationError entries (such as the server-time context line) are
// skipped. Returns nil if err is nil or contains no ValidationError values.
func ValidationErrors(err error) []*ValidationError {
if err == nil {
return nil
}
var errs []error
if joined, ok := err.(interface{ Unwrap() []error }); ok {
errs = joined.Unwrap()
} else {
errs = []error{err}
}
result := make([]*ValidationError, 0, len(errs))
for _, e := range errs {
if ve, ok := e.(*ValidationError); ok {
result = append(result, ve)
}
}
if len(result) == 0 {
return nil
}
return result
}
// GetOAuth2Error returns the OAuth 2.0 error code for the validation error
// returned by [Validator.Validate] or [TokenClaims.Errors].
//
// Returns one of:
//
// - "invalid_token" - the token is expired, malformed, or otherwise invalid
// - "insufficient_scope" - the token lacks required scopes
// - "server_error" - server-side misconfiguration (treat as HTTP 500)
//
// per RFC 6750 §3.1. When multiple validation failures exist, the most severe
// code wins (server_error > insufficient_scope > invalid_token).
//
// Returns "" if err is nil or contains no [*ValidationError] values.
// Use err.Error() for the human-readable description:
//
// err := v.Validate(nil, &claims, time.Now())
// if code := jwt.GetOAuth2Error(err); code != "" {
// vals := url.Values{"error": {code}, "error_description": {err.Error()}}
// http.Redirect(w, r, redirectURI+"?"+vals.Encode(), http.StatusFound)
// }
func GetOAuth2Error(err error) (oauth2Error string) {
ves := ValidationErrors(err)
if len(ves) == 0 {
return ""
}
// Pick the most severe OAuth code across all errors.
code := "invalid_token"
for _, ve := range ves {
switch {
case errors.Is(ve.Err, ErrMisconfigured):
code = "server_error"
case errors.Is(ve.Err, ErrInsufficientScope) && code != "server_error":
code = "insufficient_scope"
}
}
return code
}
// appendError constructs a [*ValidationError] and appends it to the slice.
// sentinel is the error for [errors.Is] matching; format and args produce the
// human-readable description (conventionally prefixed with the claim name,
// e.g. "exp: expired 5m ago").
func appendError(errs []error, sentinel error, format string, args ...any) []error {
return append(errs, &ValidationError{
Code: codeFor(sentinel),
Description: fmt.Sprintf(format, args...),
Err: sentinel,
})
}
// isTimeSentinel reports whether the sentinel is a time-related claim error.
func isTimeSentinel(sentinel error) bool {
return errors.Is(sentinel, ErrAfterExp) ||
errors.Is(sentinel, ErrBeforeNBf) ||
errors.Is(sentinel, ErrBeforeIAt) ||
errors.Is(sentinel, ErrBeforeAuthTime) ||
errors.Is(sentinel, ErrAfterAuthMaxAge)
}
// codeFor maps a sentinel error to a machine-readable code string.
func codeFor(sentinel error) string {
switch {
case errors.Is(sentinel, ErrAfterExp):
return "token_expired"
case errors.Is(sentinel, ErrBeforeNBf):
return "token_not_yet_valid"
case errors.Is(sentinel, ErrBeforeIAt):
return "future_issued_at"
case errors.Is(sentinel, ErrBeforeAuthTime):
return "future_auth_time"
case errors.Is(sentinel, ErrAfterAuthMaxAge):
return "auth_time_exceeded"
case errors.Is(sentinel, ErrInsufficientScope):
return "insufficient_scope"
case errors.Is(sentinel, ErrMissingClaim):
return "missing_claim"
case errors.Is(sentinel, ErrInvalidTyp):
return "invalid_typ"
case errors.Is(sentinel, ErrInvalidClaim):
return "invalid_claim"
case errors.Is(sentinel, ErrMisconfigured):
return "server_error"
default:
return "unknown_error"
}
}
// formatDuration formats a duration as a human-readable string with days,
// hours, minutes, seconds, and milliseconds.
func formatDuration(d time.Duration) string {
if d < 0 {
d = -d
}
days := int(d / (24 * time.Hour))
d -= time.Duration(days) * 24 * time.Hour
hours := int(d / time.Hour)
d -= time.Duration(hours) * time.Hour
minutes := int(d / time.Minute)
d -= time.Duration(minutes) * time.Minute
seconds := int(d / time.Second)
var parts []string
if days > 0 {
parts = append(parts, fmt.Sprintf("%dd", days))
}
if hours > 0 {
parts = append(parts, fmt.Sprintf("%dh", hours))
}
if minutes > 0 {
parts = append(parts, fmt.Sprintf("%dm", minutes))
}
if seconds > 0 {
parts = append(parts, fmt.Sprintf("%ds", seconds))
}
if len(parts) == 0 {
// Sub-second duration: fall back to milliseconds.
millis := int(d / time.Millisecond)
parts = append(parts, fmt.Sprintf("%dms", millis))
}
return strings.Join(parts, " ")
}
// defaultGracePeriod is the tolerance applied to exp, iat, and auth_time checks
// when Validator.GracePeriod is zero.
//
// It should be set to at least 2s in most cases to account for practical edge
// cases of corresponding systems having even a millisecond of clock skew between
// them and the offset of their respective implementations truncating, flooring,
// ceiling, or rounding seconds differently.
//
// For example: If 1.999 is truncated to 1 and 2.001 is ceiled to 3, then there
// is a 2 second difference.
//
// This will very rarely affect calculations on exp (and hopefully a client knows
// better than to ride the very millisecond of expiration), but it can very
// frequently affect calculations on iat and nbf on distributed production
// systems.
const defaultGracePeriod = 2 * time.Second
// Checks is a bitmask that selects which claim validations [Validator]
// performs. Combine with OR:
//
// v := &jwt.Validator{
// Checks: jwt.CheckIss | jwt.CheckExp,
// Iss: []string{"https://example.com"},
// }
//
// Use [NewIDTokenValidator] or [NewAccessTokenValidator] for sensible defaults.
type Checks uint32
const (
// ChecksConfigured is a sentinel bit that distinguishes a deliberately
// configured Checks value from the zero value. Constructors like
// [NewIDTokenValidator] set it automatically. Struct-literal users
// should include it so that [Validator.Validate] does not reject the
// Validator as unconfigured.
ChecksConfigured Checks = 1 << iota
CheckIss // validate issuer
CheckSub // validate subject presence
CheckAud // validate audience
CheckExp // validate expiration
CheckNBf // validate not-before
CheckIAt // validate issued-at is not in the future
CheckClientID // validate client_id presence
CheckJTI // validate jti presence
CheckAuthTime // validate auth_time
CheckAzP // validate authorized party
CheckScope // validate scope presence
)
// resolveSkew converts a GracePeriod configuration value to a skew duration.
// Zero means use [defaultGracePeriod]; negative means no tolerance.
func resolveSkew(gracePeriod time.Duration) time.Duration {
if gracePeriod == 0 {
return defaultGracePeriod
}
if gracePeriod < 0 {
return 0
}
return gracePeriod
}
// Validator checks JWT claims for both ID tokens and access tokens.
//
// Use [NewIDTokenValidator] or [NewAccessTokenValidator] to create one with
// sensible defaults for the token type. You can also construct a Validator
// literal with a custom [Checks] bitmask - but you must OR at least one
// Check* flag or set Iss/Aud/AzP/RequiredScopes/MaxAge, otherwise Validate
// returns a misconfiguration error (a zero-value Validator is never valid).
//
// Iss, Aud, and AzP distinguish nil from empty: nil means unconfigured
// (no check unless the corresponding Check* flag is set), a non-nil empty
// slice is always a misconfiguration error (the empty set allows nothing),
// and ["*"] accepts any non-empty value. A non-nil slice forces its check
// regardless of the Checks bitmask.
//
// GracePeriod is applied to exp, nbf, iat, and auth_time (including maxAge)
// to tolerate minor clock differences between distributed systems. If zero,
// the default grace period (2s) is used. Set to a negative value to disable
// skew tolerance entirely.
//
// Explicit configuration (non-nil Iss/Aud/AzP, non-empty RequiredScopes,
// MaxAge > 0) forces the corresponding check regardless of the Checks bitmask.
type Validator struct {
Checks Checks
GracePeriod time.Duration // 0 = default (2s); negative = no tolerance
MaxAge time.Duration
Iss []string // nil=unchecked, []=misconfigured, ["*"]=any, ["x"]=must match
Aud []string // nil=unchecked, []=misconfigured, ["*"]=any, ["x"]=must intersect
AzP []string // nil=unchecked, []=misconfigured, ["*"]=any, ["x"]=must match
RequiredScopes []string // all of these must appear in the token's scope
}
// NewIDTokenValidator returns a [Validator] configured for OIDC Core §2 ID Tokens.
//
// Pass the allowed issuers and audiences, or nil to skip that check.
// Use []string{"*"} to require the claim be present without restricting its value.
//
// Checks enabled by default: iss, sub, aud, exp, iat, auth_time, azp
// Not checked: amr, nonce, nbf, jti, client_id, and scope.
// Adjust by OR-ing or masking the returned Validator's Checks field.
//
// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
func NewIDTokenValidator(iss, aud, azp []string) *Validator {
checks := ChecksConfigured | CheckSub | CheckExp | CheckIAt | CheckAuthTime
if iss != nil {
checks |= CheckIss
}
if aud != nil {
checks |= CheckAud
}
if azp != nil {
checks |= CheckAzP
}
return &Validator{
Checks: checks,
GracePeriod: defaultGracePeriod,
Iss: iss,
Aud: aud,
AzP: azp,
}
}
// NewAccessTokenValidator returns a [Validator] configured for OAuth 2.1 JWT
// access tokens per RFC 9068 §2.2.
//
// Pass the allowed issuers and audiences, or nil to skip that check.
// Use []string{"*"} to require the claim be present without restricting its value.
//
// Checks enabled by default: iss, exp, aud, sub, client_id, iat, jti. and scope.
// Not checked: nbf, auth_time, and, azp.
// Populate RequiredScopes to enforce specific scope values (overrides CheckScope).
//
// https://www.rfc-editor.org/rfc/rfc9068.html#section-2.2
func NewAccessTokenValidator(iss, aud, scopes []string) *Validator {
checks := ChecksConfigured | CheckSub | CheckExp | CheckIAt | CheckJTI | CheckClientID
if iss != nil {
checks |= CheckIss
}
if aud != nil {
checks |= CheckAud
}
if scopes != nil {
checks |= CheckScope
}
return &Validator{
Checks: checks,
GracePeriod: defaultGracePeriod,
Iss: iss,
Aud: aud,
RequiredScopes: scopes,
}
}
// Validate checks JWT claims according to the configured [Checks] bitmask.
//
// Each Check* flag enables its check. Explicit configuration
// (non-nil Iss/Aud/AzP, non-empty RequiredScopes, MaxAge > 0) forces
// the corresponding check regardless of the Checks bitmask.
//
// The individual check methods on [TokenClaims] are exported so that custom
// validators can call them directly without going through Validate.
//
// The errs parameter lets callers thread in errors from earlier checks
// (e.g. [RFCHeader.IsAllowedTyp]) so that all findings appear in a single
// joined error. Pass nil when there are no prior errors.
//
// now is caller-supplied (not time.Now()) so that validation is
// deterministic and testable.
//
// Returns nil on success. On failure, the returned error is a joined
// error that supports [errors.Is] for individual sentinels (e.g.
// [ErrAfterExp], [ErrMissingClaim]). Use Unwrap() []error to iterate
// each finding.
func (v *Validator) Validate(errs []error, claims Claims, now time.Time) error {
tc := claims.GetTokenClaims()
// Detect unconfigured validator: no Check* flags and no explicit config.
if v.Checks == 0 && len(v.Iss) == 0 && len(v.Aud) == 0 && len(v.AzP) == 0 &&
len(v.RequiredScopes) == 0 && v.MaxAge == 0 {
return appendError(nil, ErrMisconfigured, "validator has no checks configured; use a constructor or set Check* flags")[0]
}
skew := resolveSkew(v.GracePeriod)
if v.Iss != nil || v.Checks&CheckIss != 0 {
errs = tc.IsAllowedIss(errs, v.Iss)
}
if v.Checks&CheckSub != 0 {
errs = tc.IsPresentSub(errs)
}
if v.Aud != nil || v.Checks&CheckAud != 0 {
errs = tc.HasAllowedAud(errs, v.Aud)
}
if v.Checks&CheckExp != 0 {
errs = tc.IsBeforeExp(errs, now, skew)
}
if v.Checks&CheckNBf != 0 {
errs = tc.IsAfterNBf(errs, now, skew)
}
if v.Checks&CheckIAt != 0 {
errs = tc.IsAfterIAt(errs, now, skew)
}
if v.Checks&CheckJTI != 0 {
errs = tc.IsPresentJTI(errs)
}
if v.MaxAge > 0 || v.Checks&CheckAuthTime != 0 {
errs = tc.IsValidAuthTime(errs, now, skew, v.MaxAge)
}
if v.AzP != nil || v.Checks&CheckAzP != 0 {
errs = tc.IsAllowedAzP(errs, v.AzP)
}
if v.Checks&CheckClientID != 0 {
errs = tc.IsPresentClientID(errs)
}
if len(v.RequiredScopes) > 0 || v.Checks&CheckScope != 0 {
errs = tc.ContainsScopes(errs, v.RequiredScopes)
}
if len(errs) > 0 {
// Annotate time-related errors with the server's clock for debugging.
serverTime := fmt.Sprintf("server time %s (%s)", now.Format("2006-01-02 15:04:05 MST"), time.Local)
for _, e := range errs {
if ve, ok := e.(*ValidationError); ok && isTimeSentinel(ve.Err) {
ve.Description = fmt.Sprintf("%s; %s", ve.Description, serverTime)
}
}
return errors.Join(errs...)
}
return nil
}
// --- Per-claim check methods on *TokenClaims ---
//
// These exported methods can be called directly by custom validators.
// Each method appends validation errors to the provided slice and returns it.
// The [Validator] decides which checks to call based on its [Checks] bitmask.
//
// Methods are named by assertion kind:
//
// - IsAllowed - value must appear in a configured list
// - HasAllowed - value must intersect a configured list
// - IsPresent - value must be non-empty
// - IsBefore - now must be before a time boundary
// - IsAfter - now must be after a time boundary
// - IsValid - composite check (presence + time bounds)
// - Contains - value must contain all required entries
// IsAllowedIss validates the issuer claim.
//
// Allowed semantics: nil = misconfigured (error), [] = misconfigured (error),
// ["*"] = any non-empty value, ["x","y"] = must match one.
//
// At the [Validator] level, passing nil Iss disables the issuer check
// entirely (the method is never called). Calling this method directly
// with nil is a misconfiguration error.
func (tc *TokenClaims) IsAllowedIss(errs []error, allowed []string) []error {
if allowed == nil {
return appendError(errs, ErrMisconfigured, "iss: issuer checking enabled but Iss is nil")
}
if len(allowed) == 0 {
return appendError(errs, ErrMisconfigured, "iss: non-nil empty Iss allows no issuers")
} else if tc.Iss == "" {
return appendError(errs, ErrMissingClaim, "iss: missing required claim")
} else if !slices.Contains(allowed, "*") && !slices.Contains(allowed, tc.Iss) {
return appendError(errs, ErrInvalidClaim, "iss %q not in allowed list", tc.Iss)
}
return errs
}
// IsPresentSub validates that the subject claim is present.
func (tc *TokenClaims) IsPresentSub(errs []error) []error {
if tc.Sub == "" {
return appendError(errs, ErrMissingClaim, "sub: missing required claim")
}
return errs
}
// HasAllowedAud validates the audience claim.
//
// Allowed semantics: nil = misconfigured (error), [] = misconfigured (error),
// ["*"] = any non-empty value, ["x","y"] = token's aud must intersect.
//
// At the [Validator] level, passing nil Aud disables the audience check
// entirely (the method is never called). Calling this method directly
// with nil is a misconfiguration error.
func (tc *TokenClaims) HasAllowedAud(errs []error, allowed []string) []error {
if allowed == nil {
return appendError(errs, ErrMisconfigured, "aud: audience checking enabled but Aud is nil")
}
if len(allowed) == 0 {
return appendError(errs, ErrMisconfigured, "aud: non-nil empty Aud allows no audiences")
} else if len(tc.Aud) == 0 {
return appendError(errs, ErrMissingClaim, "aud: missing required claim")
} else if !slices.Contains(allowed, "*") && !slices.ContainsFunc([]string(tc.Aud), func(a string) bool {
return slices.Contains(allowed, a)
}) {
return appendError(errs, ErrInvalidClaim, "aud %v not in allowed list", tc.Aud)
}
return errs
}
// IsBeforeExp validates the expiration claim.
// now is caller-supplied for testability; pass time.Now() in production.
func (tc *TokenClaims) IsBeforeExp(errs []error, now time.Time, skew time.Duration) []error {
if tc.Exp <= 0 {
return appendError(errs, ErrMissingClaim, "exp: missing required claim")
}
expTime := time.Unix(tc.Exp, 0)
if now.After(expTime.Add(skew)) {
dur := now.Sub(expTime)
return appendError(errs, ErrAfterExp, "expired %s ago (%s)",
formatDuration(dur), expTime.Format("2006-01-02 15:04:05 MST"))
}
return errs
}
// IsAfterNBf validates the not-before claim. Absence is never an error.
// now is caller-supplied for testability; pass time.Now() in production.
func (tc *TokenClaims) IsAfterNBf(errs []error, now time.Time, skew time.Duration) []error {
if tc.NBf <= 0 {
return errs
}
nbfTime := time.Unix(tc.NBf, 0)
if nbfTime.After(now.Add(skew)) {
dur := nbfTime.Sub(now)
return appendError(errs, ErrBeforeNBf, "nbf is %s in the future (%s)",
formatDuration(dur), nbfTime.Format("2006-01-02 15:04:05 MST"))
}
return errs
}
// IsAfterIAt validates that the issued-at claim is not in the future.
// now is caller-supplied for testability; pass time.Now() in production.
//
// Unlike iss or sub, absence is not an error - iat is optional per
// RFC 7519. However, when present, a future iat is rejected as a
// common-sense sanity check (the spec does not require this).
func (tc *TokenClaims) IsAfterIAt(errs []error, now time.Time, skew time.Duration) []error {
if tc.IAt <= 0 {
return errs // absence is not an error
}
iatTime := time.Unix(tc.IAt, 0)
if iatTime.After(now.Add(skew)) {
dur := iatTime.Sub(now)
return appendError(errs, ErrBeforeIAt, "iat is %s in the future (%s)",
formatDuration(dur), iatTime.Format("2006-01-02 15:04:05 MST"))
}
return errs
}
// IsPresentJTI validates that the JWT ID claim is present.
func (tc *TokenClaims) IsPresentJTI(errs []error) []error {
if tc.JTI == "" {
return appendError(errs, ErrMissingClaim, "jti: missing required claim")
}
return errs
}
// IsValidAuthTime validates the authentication time claim.
// now is caller-supplied for testability; pass time.Now() in production.
//
// When maxAge is positive, auth_time must be present and within maxAge
// of now. When maxAge is zero, only presence and future-time checks apply.
func (tc *TokenClaims) IsValidAuthTime(errs []error, now time.Time, skew time.Duration, maxAge time.Duration) []error {
if tc.AuthTime == 0 {
return appendError(errs, ErrMissingClaim, "auth_time: missing required claim")
}
authTime := time.Unix(tc.AuthTime, 0)
authTimeStr := authTime.Format("2006-01-02 15:04:05 MST")
if authTime.After(now.Add(skew)) {
dur := authTime.Sub(now)
return appendError(errs, ErrBeforeAuthTime, "auth_time %s is %s in the future",
authTimeStr, formatDuration(dur))
} else if maxAge > 0 {
age := now.Sub(authTime)
if age > maxAge+skew {
diff := age - maxAge
return appendError(errs, ErrAfterAuthMaxAge, "auth_time %s is %s old, exceeding max age %s by %s",
authTimeStr, formatDuration(age), formatDuration(maxAge), formatDuration(diff))
}
}
return errs
}
// IsAllowedAzP validates the authorized party claim.
//
// Allowed semantics: nil = misconfigured (error), [] = misconfigured (error),
// ["*"] = any non-empty value, ["x","y"] = must match one.
func (tc *TokenClaims) IsAllowedAzP(errs []error, allowed []string) []error {
if allowed == nil {
return appendError(errs, ErrMisconfigured, "azp: authorized party checking enabled but AzP is nil")
}
if len(allowed) == 0 {
return appendError(errs, ErrMisconfigured, "azp: non-nil empty AzP allows no parties")
} else if tc.AzP == "" {
return appendError(errs, ErrMissingClaim, "azp: missing required claim")
} else if !slices.Contains(allowed, "*") && !slices.Contains(allowed, tc.AzP) {
return appendError(errs, ErrInvalidClaim, "azp %q not in allowed list", tc.AzP)
}
return errs
}
// IsPresentClientID validates that the client_id claim is present.
func (tc *TokenClaims) IsPresentClientID(errs []error) []error {
if tc.ClientID == "" {
return appendError(errs, ErrMissingClaim, "client_id: missing required claim")
}
return errs
}
// ContainsScopes validates that the token's scope claim is present and
// contains all required values. When required is nil, only presence is checked.
func (tc *TokenClaims) ContainsScopes(errs []error, required []string) []error {
if len(tc.Scope) == 0 {
return appendError(errs, ErrMissingClaim, "scope: missing required claim")
}
for _, req := range required {
if !slices.Contains(tc.Scope, req) {
errs = appendError(errs, ErrInsufficientScope, "scope %q not granted", req)
}
}
return errs
}
// IsAllowedTyp validates that the JOSE "typ" header is one of the allowed
// values. Comparison is case-insensitive per RFC 7515 §4.1.9.
// Call this between [Verifier.Verify] and [Validator.Validate] to enforce
// token-type constraints (e.g. reject an access token where an ID token
// is expected).
//
// hdr := jws.GetHeader()
// errs = hdr.IsAllowedTyp(errs, []string{"JWT"})
func (h *RFCHeader) IsAllowedTyp(errs []error, allowed []string) []error {
if len(allowed) == 0 {
return appendError(errs, ErrMisconfigured, "typ: allowed list is empty")
}
for _, a := range allowed {
if strings.EqualFold(h.Typ, a) {
return errs
}
}
return appendError(errs, ErrInvalidTyp, "typ %q not in allowed list", h.Typ)
}

227
auth/jwt/verify.go Normal file
View File

@ -0,0 +1,227 @@
// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"errors"
"fmt"
"math/big"
"slices"
)
// Verifier holds the public keys of a JWT issuer and verifies token signatures.
//
// In OIDC terminology, the "issuer" is the identity provider that both signs
// tokens and publishes its public keys. Verifier represents that issuer from
// the relying party's perspective - you hold its public keys and use them to
// verify that tokens were legitimately signed by it.
//
// When a token's kid header matches a key, that key is tried. When the kid is
// empty, every key is tried in order; the first successful verification wins.
//
// Verifier is immutable after construction - safe for concurrent use with no locking.
// Use [NewVerifier] to construct with a fixed key set, or use [Signer.Verifier] or
// [keyfetch.KeyFetcher.Verifier] to obtain one from a signer or remote JWKS endpoint.
type Verifier struct {
pubKeys []PublicKey
}
// NewVerifier creates a Verifier with an explicit set of public keys.
//
// Multiple keys may share the same KID (e.g. during key rotation).
// When verifying, all keys with a matching KID are tried until one succeeds.
// Keys with identical KID and key material are deduplicated automatically.
//
// The returned Verifier is immutable - keys cannot be added or removed after
// construction. For dynamic key rotation, see keyfetch.KeyFetcher.
func NewVerifier(keys []PublicKey) (*Verifier, error) {
if len(keys) == 0 {
return nil, fmt.Errorf("NewVerifier: %w", ErrNoVerificationKey)
}
deduped := make([]PublicKey, 0, len(keys))
type seenEntry struct {
key CryptoPublicKey
index int
}
seen := make(map[string][]seenEntry, len(keys))
for _, k := range keys {
entries := seen[k.KID]
dup := false
for _, e := range entries {
if e.key.Equal(k.Key) {
dup = true
break
}
}
if dup {
continue // identical key material, skip
}
seen[k.KID] = append(entries, seenEntry{key: k.Key, index: len(deduped)})
deduped = append(deduped, k)
}
return &Verifier{
pubKeys: deduped,
}, nil
}
// PublicKeys returns a copy of the public keys held by this Verifier.
// Callers may safely modify the returned slice without affecting the Verifier.
//
// To serialize as a JWKS JSON document:
//
// json.Marshal(WellKnownJWKs{Keys: verifier.PublicKeys()})
func (v *Verifier) PublicKeys() []PublicKey {
return slices.Clone(v.pubKeys)
}
// Verify checks the signature of an already-decoded [VerifiableJWT].
//
// Key selection by KID:
// - Token has a KID: all verifier keys with a matching KID are tried
// (supports key rotation where multiple keys share a KID).
// Returns [ErrUnknownKID] if no key matches the KID.
// - Token has no KID: all verifier keys are tried.
//
// In both cases the first successful verification wins.
//
// Returns nil on success, a descriptive error on failure. Claim values
// (iss, aud, exp, etc.) are NOT checked - call [Validator.Validate] on the
// unmarshalled claims after verifying.
//
// Use [Decode] followed by Verify when you need to inspect the header
// (kid, alg) before deciding which verifier to apply:
//
// jws, err := jwt.Decode(tokenStr)
// if err != nil { /* malformed */ }
// // route by kid before verifying
// if err := chosenVerifier.Verify(jws); err != nil { /* bad sig */ }
//
// Use [Verifier.VerifyJWT] to decode and verify in one step.
func (v *Verifier) Verify(jws VerifiableJWT) error {
h := jws.GetHeader()
signingInput := signingInputBytes(jws.GetProtected(), jws.GetPayload())
sig := jws.GetSignature()
// Build the candidate key list: all keys with a matching KID, or all
// keys when the token has no KID. First successful verification wins.
// Multiple keys may share a KID during key rotation.
var candidates []PublicKey
if h.KID != "" {
for i := range v.pubKeys {
if v.pubKeys[i].KID == h.KID {
candidates = append(candidates, v.pubKeys[i])
}
}
if len(candidates) == 0 {
return fmt.Errorf("kid %q: %w", h.KID, ErrUnknownKID)
}
} else {
candidates = v.pubKeys
}
// Try each candidate key. Prefer ErrSignatureInvalid (key type matched
// but signature bytes didn't verify) over ErrAlgConflict (wrong key type
// for the token's algorithm) since it's more informative.
var bestErr error
for _, pk := range candidates {
err := verifyOneKey(h, pk.Key, signingInput, sig)
if err == nil {
return nil
}
if bestErr == nil || errors.Is(err, ErrSignatureInvalid) {
bestErr = err
}
}
return bestErr
}
// verifyOneKey checks the signature against a single key.
func verifyOneKey(h RFCHeader, key CryptoPublicKey, signingInput, sig []byte) error {
kid := h.KID
switch h.Alg {
case "ES256", "ES384", "ES512":
k, ok := key.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("kid %q alg %q: key type %T: %w", kid, h.Alg, key, ErrAlgConflict)
}
ci, err := ecInfoForAlg(k.Curve, h.Alg)
if err != nil {
return fmt.Errorf("kid %q alg %q: %w", kid, h.Alg, err)
}
if len(sig) != 2*ci.KeySize {
return fmt.Errorf("kid %q alg %q: sig len %d, want %d: %w", kid, h.Alg, len(sig), 2*ci.KeySize, ErrSignatureInvalid)
}
digest, err := digestFor(ci.Hash, signingInput)
if err != nil {
return fmt.Errorf("kid %q alg %q: %w", kid, h.Alg, err)
}
r := new(big.Int).SetBytes(sig[:ci.KeySize])
s := new(big.Int).SetBytes(sig[ci.KeySize:])
if !ecdsa.Verify(k, digest, r, s) {
return fmt.Errorf("kid %q alg %q: %w", kid, h.Alg, ErrSignatureInvalid)
}
return nil
case "RS256":
k, ok := key.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("kid %q alg %q: key type %T: %w", kid, h.Alg, key, ErrAlgConflict)
}
digest, err := digestFor(crypto.SHA256, signingInput)
if err != nil {
return fmt.Errorf("kid %q alg %q: %w", kid, h.Alg, err)
}
if err := rsa.VerifyPKCS1v15(k, crypto.SHA256, digest, sig); err != nil {
return fmt.Errorf("kid %q alg %q: %w: %w", kid, h.Alg, ErrSignatureInvalid, err)
}
return nil
case "EdDSA":
k, ok := key.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("kid %q alg %q: key type %T: %w", kid, h.Alg, key, ErrAlgConflict)
}
if !ed25519.Verify(k, signingInput, sig) {
return fmt.Errorf("kid %q alg %q: %w", kid, h.Alg, ErrSignatureInvalid)
}
return nil
default:
return fmt.Errorf("kid %q alg %q: %w", kid, h.Alg, ErrUnsupportedAlg)
}
}
// VerifyJWT decodes tokenStr and verifies its signature, returning the parsed
// [*JWT] on success.
//
// Returns (nil, err) on any failure - the caller never receives an
// unauthenticated JWT. Claim values (iss, aud, exp, etc.) are NOT checked;
// call [Validator.Validate] on the unmarshalled claims after VerifyJWT:
//
// jws, err := v.VerifyJWT(tokenStr)
// if err != nil { /* bad sig, malformed token, unknown kid */ }
// var claims AppClaims
// if err := jws.UnmarshalClaims(&claims); err != nil { /* ... */ }
// if err := v.Validate(nil, &claims, time.Now()); err != nil { /* ... */ }
//
// For routing by kid/iss before verifying, use [Decode] then [Verifier.Verify].
func (v *Verifier) VerifyJWT(tokenStr string) (*JWT, error) {
jws, err := Decode(tokenStr)
if err != nil {
return nil, err
}
if err := v.Verify(jws); err != nil {
return nil, err
}
return jws, nil
}