// Mirror of https://github.com/therootcompany/golib.git
// (synced 2026-03-29 13:13:57 +00:00; 2460 lines, 68 KiB, Go)

// Copyright 2026 AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package jwt_test
|
|
|
|
import (
|
|
"crypto/ecdsa"
|
|
"crypto/ed25519"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"crypto"
|
|
|
|
"github.com/therootcompany/golib/auth/jwt"
|
|
)
|
|
|
|
func mustPK(t testing.TB, signer crypto.Signer, kid string) *jwt.PrivateKey {
|
|
t.Helper()
|
|
pk, err := jwt.FromPrivateKey(signer, kid)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return pk
|
|
}
|
|
|
|
// AppClaims embeds TokenClaims and adds application-specific fields.
|
|
//
|
|
// Because TokenClaims is embedded, AppClaims satisfies Claims
|
|
// for free via Go's method promotion - no interface to implement.
|
|
type AppClaims struct {
|
|
jwt.TokenClaims
|
|
Email string `json:"email"`
|
|
Roles []string `json:"roles"`
|
|
}
|
|
|
|
// validateAppClaims is a plain function - not a method satisfying an interface.
|
|
// It demonstrates the Decode+Verify pattern: custom validation logic lives here,
|
|
// calling Validator.Validate and adding app-specific checks.
|
|
func validateAppClaims(c AppClaims, v *jwt.Validator, now time.Time) error {
|
|
var errs []error
|
|
if err := v.Validate(nil, &c, now); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
if c.Email == "" {
|
|
errs = append(errs, errors.New("missing email claim"))
|
|
}
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func goodClaims() AppClaims {
|
|
now := time.Now()
|
|
return AppClaims{
|
|
TokenClaims: jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user123",
|
|
Aud: jwt.Listish{"myapp"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
AMR: []string{"pwd"},
|
|
JTI: "abc123",
|
|
AzP: "myapp",
|
|
Nonce: "nonce1",
|
|
},
|
|
Email: "user@example.com",
|
|
Roles: []string{"admin"},
|
|
}
|
|
}
|
|
|
|
// goodValidator configures the ID token validator with iss set to "https://example.com".
|
|
// Iss checking is now the Validator's responsibility, not the Verifier's.
|
|
func goodValidator() *jwt.Validator {
|
|
return jwt.NewIDTokenValidator(
|
|
[]string{"https://example.com"},
|
|
[]string{"myapp"},
|
|
[]string{"myapp"},
|
|
0,
|
|
)
|
|
}
|
|
|
|
func goodVerifier(pub jwt.PublicKey) *jwt.Verifier {
|
|
v, err := jwt.NewVerifier([]jwt.PublicKey{pub})
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return v
|
|
}
|
|
|
|
// TestRoundTrip is the primary happy path using ES256.
|
|
// It demonstrates the full Verify / UnmarshalClaims / Validate flow.
|
|
func TestRoundTrip(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := signer.Sign(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.GetHeader().Alg != "ES256" {
|
|
t.Fatalf("expected ES256, got %s", jws.GetHeader().Alg)
|
|
}
|
|
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed: %v", err)
|
|
}
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"})
|
|
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
if jws2.GetHeader().Alg != "ES256" {
|
|
t.Errorf("expected ES256 alg in jws, got %s", jws2.GetHeader().Alg)
|
|
}
|
|
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("claim validation failed: %v", err)
|
|
}
|
|
// Direct field access - no type assertion needed.
|
|
if decoded.Email != claims.Email {
|
|
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
|
|
}
|
|
}
|
|
|
|
// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256.
|
|
func TestRoundTripRS256(t *testing.T) {
|
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := signer.Sign(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.GetHeader().Alg != "RS256" {
|
|
t.Fatalf("expected RS256, got %s", jws.GetHeader().Alg)
|
|
}
|
|
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed: %v", err)
|
|
}
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"})
|
|
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("claim validation failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestRoundTripEdDSA exercises Ed25519 / EdDSA (RFC 8037).
|
|
func TestRoundTripEdDSA(t *testing.T) {
|
|
pubKeyBytes, privKey, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := signer.Sign(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.GetHeader().Alg != "EdDSA" {
|
|
t.Fatalf("expected EdDSA, got %s", jws.GetHeader().Alg)
|
|
}
|
|
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed: %v", err)
|
|
}
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: pubKeyBytes, KID: "key-1"})
|
|
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("claim validation failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestRoundTripES384 exercises ECDSA P-384 / ES384.
|
|
func TestRoundTripES384(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := signer.Sign(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.GetHeader().Alg != "ES384" {
|
|
t.Fatalf("expected ES384, got %s", jws.GetHeader().Alg)
|
|
}
|
|
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed: %v", err)
|
|
}
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"})
|
|
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("claim validation failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestRoundTripES512 exercises ECDSA P-521 / ES512.
|
|
func TestRoundTripES512(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
jws, err := signer.Sign(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if jws.GetHeader().Alg != "ES512" {
|
|
t.Fatalf("expected ES512, got %s", jws.GetHeader().Alg)
|
|
}
|
|
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed: %v", err)
|
|
}
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"})
|
|
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("claim validation failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestDecodeVerifyFlow demonstrates the Decode + Verify + custom validation pattern.
|
|
// The caller owns the full validation pipeline.
|
|
func TestDecodeVerifyFlow(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")})
|
|
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &privKey.PublicKey, KID: "k"}})
|
|
|
|
jws2, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatalf("Decode failed: %v", err)
|
|
}
|
|
if err := iss.Verify(jws2); err != nil {
|
|
t.Fatalf("Verify failed: %v", err)
|
|
}
|
|
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("Validate failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestDecodeReturnsParsedOnSigFailure verifies that Decode returns a non-nil
|
|
// *StandardJWS even when the token will later fail signature verification.
|
|
// Callers can inspect the header (kid, alg) for routing before calling Verify.
|
|
func TestDecodeReturnsParsedOnSigFailure(t *testing.T) {
|
|
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, signingKey, "k")})
|
|
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
// Verifier has wrong public key - sig verification will fail.
|
|
iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &wrongKey.PublicKey, KID: "k"}})
|
|
|
|
// Decode always succeeds for well-formed tokens.
|
|
result, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatalf("Decode failed: %v", err)
|
|
}
|
|
if result == nil {
|
|
t.Fatal("Decode should return non-nil StandardJWS")
|
|
}
|
|
if result.GetHeader().KID != "k" {
|
|
t.Errorf("expected kid %q, got %q", "k", result.GetHeader().KID)
|
|
}
|
|
|
|
// Verify should fail with the wrong key.
|
|
if err := iss.Verify(result); err == nil {
|
|
t.Fatal("expected Verify to fail with wrong key")
|
|
}
|
|
}
|
|
|
|
// TestCustomValidation demonstrates that Validator.Validate is called
|
|
// explicitly and custom fields are validated alongside the standard ones.
|
|
func TestCustomValidation(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
// Token with empty Email - our custom validator should reject it.
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")})
|
|
claims := goodClaims()
|
|
claims.Email = ""
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"})
|
|
jws2, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatalf("Decode failed: %v", err)
|
|
}
|
|
if err := iss.Verify(jws2); err != nil {
|
|
t.Fatalf("Verify failed unexpectedly: %v", err)
|
|
}
|
|
|
|
var decoded AppClaims
|
|
_ = jws2.UnmarshalClaims(&decoded)
|
|
|
|
err = validateAppClaims(decoded, goodValidator(), time.Now())
|
|
if err == nil {
|
|
t.Fatal("expected validation to fail: email is empty")
|
|
}
|
|
if !strings.Contains(err.Error(), "missing email claim") {
|
|
t.Fatalf("expected 'missing email claim' in error: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestNBFValidation confirms that a token with nbf in the future is rejected,
|
|
// and that a token with nbf in the past (or absent) is accepted.
|
|
func TestNBFValidation(t *testing.T) {
|
|
now := time.Now()
|
|
|
|
base := AppClaims{
|
|
TokenClaims: jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Aud: jwt.Listish{"myapp"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
},
|
|
}
|
|
|
|
v := &jwt.Validator{
|
|
Checks: jwt.CheckIss | jwt.CheckAud | jwt.CheckExp | jwt.CheckIAt | jwt.CheckNBf,
|
|
Iss: []string{"https://example.com"},
|
|
Aud: []string{"myapp"},
|
|
}
|
|
|
|
// No nbf: should pass.
|
|
if err := v.Validate(nil, &base, now); err != nil {
|
|
t.Fatalf("expected no error without nbf: %v", err)
|
|
}
|
|
|
|
// nbf in the past: should pass.
|
|
pastNBF := base
|
|
pastNBF.NBf = now.Add(-time.Hour).Unix()
|
|
if err := v.Validate(nil, &pastNBF, now); err != nil {
|
|
t.Fatalf("expected no error with past nbf: %v", err)
|
|
}
|
|
|
|
// nbf in the future: must be rejected.
|
|
futureNBF := base
|
|
futureNBF.NBf = now.Add(time.Hour).Unix()
|
|
err := v.Validate(nil, &futureNBF, now)
|
|
if err == nil {
|
|
t.Fatal("expected error for future nbf")
|
|
}
|
|
if !errors.Is(err, jwt.ErrBeforeNBf) {
|
|
t.Fatalf("expected ErrBeforeNBf, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestVerifyWithoutValidation confirms that Verify + UnmarshalClaims succeeds
|
|
// independently of claim validation - the caller decides whether to validate.
|
|
func TestVerifyWithoutValidation(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")})
|
|
c := goodClaims()
|
|
token, _ := signer.SignToString(&c)
|
|
|
|
iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &privKey.PublicKey, KID: "k"}})
|
|
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
var claims AppClaims
|
|
if err := jws2.UnmarshalClaims(&claims); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if claims.Email != c.Email {
|
|
t.Errorf("claims not unmarshalled: email got %q, want %q", claims.Email, c.Email)
|
|
}
|
|
}
|
|
|
|
// TestVerifierWrongKey confirms that a different key's public key is rejected.
|
|
func TestVerifierWrongKey(t *testing.T) {
|
|
signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, signingKey, "k")})
|
|
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &wrongKey.PublicKey, KID: "k"})
|
|
|
|
parsed, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrSignatureInvalid) {
|
|
t.Fatalf("expected ErrSignatureInvalid, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestVerifierUnknownKid confirms that an unknown kid is rejected.
|
|
func TestVerifierUnknownKid(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "unknown-kid")})
|
|
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "known-kid"})
|
|
|
|
parsed, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrUnknownKID) {
|
|
t.Fatalf("expected ErrUnknownKID, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestVerifierIssMismatch confirms that a token with a mismatched iss is caught
|
|
// by the Validator, not the Verifier. Signature verification succeeds; the iss
|
|
// mismatch appears as a soft validation error.
|
|
func TestVerifierIssMismatch(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")})
|
|
|
|
claims := goodClaims()
|
|
claims.Iss = "https://evil.example.com"
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"})
|
|
|
|
// Decode+Verify succeeds - iss is not checked at the Verifier level.
|
|
parsed, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := iss.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify should succeed (no iss check): %v", err)
|
|
}
|
|
|
|
// VerifyJWT + Validate: signature passes but iss validation catches the mismatch.
|
|
jws2, err := iss.VerifyJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("unexpected hard error from VerifyJWT: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws2.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
err = goodValidator().Validate(nil, &decoded, time.Now())
|
|
if err == nil {
|
|
t.Fatal("expected validation errors for iss mismatch")
|
|
}
|
|
if !errors.Is(err, jwt.ErrInvalidClaim) {
|
|
t.Fatalf("expected ErrInvalidClaim for iss mismatch, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected.
|
|
// The token is reconstructed with a replaced protected header; the original
|
|
// ES256 signature is kept, making the signing input mismatch detectable.
|
|
func TestVerifyTamperedAlg(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")})
|
|
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"})
|
|
|
|
// Replace the protected header with one that has alg:"none".
|
|
// The original ES256 signature stays - the signing input will mismatch.
|
|
noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`))
|
|
parts := strings.SplitN(token, ".", 3)
|
|
tamperedToken := noneHeader + "." + parts[1] + "." + parts[2]
|
|
|
|
parsed, err := jwt.Decode(tamperedToken)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrUnsupportedAlg) {
|
|
t.Fatalf("expected ErrUnsupportedAlg for tampered alg, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestSignerRoundTrip verifies the Signer / Sign / Verifier / Verify / Validate flow.
|
|
func TestSignerRoundTrip(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
iss := signer.Verifier()
|
|
jws, err := iss.VerifyJWT(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT failed: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims failed: %v", err)
|
|
}
|
|
if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("claim validation failed: %v", err)
|
|
}
|
|
if decoded.Email != claims.Email {
|
|
t.Errorf("email: got %s, want %s", decoded.Email, claims.Email)
|
|
}
|
|
}
|
|
|
|
// TestSignerAutoKID verifies that KID is auto-computed from the key thumbprint
|
|
// when PrivateKey.KID is empty.
|
|
func TestSignerAutoKID(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
keys := signer.Keys
|
|
if len(keys) != 1 {
|
|
t.Fatalf("expected 1 key, got %d", len(keys))
|
|
}
|
|
if keys[0].KID == "" {
|
|
t.Fatal("KID should be auto-computed from thumbprint")
|
|
}
|
|
|
|
// Token should verify with the auto-KID issuer.
|
|
iss := signer.Verifier()
|
|
claims := goodClaims()
|
|
tokenStr, _ := signer.SignToString(&claims)
|
|
|
|
parsed, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := iss.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestSignerRoundRobin verifies that signing round-robins across keys and that
|
|
// all resulting tokens verify with the combined Verifier.
|
|
func TestSignerRoundRobin(t *testing.T) {
|
|
key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
key2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{
|
|
mustPK(t, key1, "k1"),
|
|
mustPK(t, key2, "k2"),
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
iss := signer.Verifier()
|
|
v := goodValidator()
|
|
|
|
for i := range 4 {
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatalf("Sign[%d] failed: %v", i, err)
|
|
}
|
|
jws, err := iss.VerifyJWT(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT[%d] failed: %v", i, err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims[%d] failed: %v", i, err)
|
|
}
|
|
if err := v.Validate(nil, &decoded, time.Now()); err != nil {
|
|
t.Fatalf("Validate[%d] failed: %v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestSignJWTSelectsByKID verifies that when the header already has a KID,
|
|
// SignJWT uses that specific key instead of round-robin.
|
|
func TestSignJWTSelectsByKID(t *testing.T) {
|
|
key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
key2, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{
|
|
mustPK(t, key1, "ec256"),
|
|
mustPK(t, key2, "ec384"),
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Request signing with key2 specifically by setting KID in the header.
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: time.Now().Add(time.Hour).Unix(),
|
|
IAt: time.Now().Unix(),
|
|
}
|
|
jws, err := jwt.New(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
jws.SetTyp("JWT")
|
|
// Pre-set the KID to select key2.
|
|
hdr := jws.GetHeader()
|
|
hdr.KID = "ec384"
|
|
if err := jws.SetHeader(&hdr); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := signer.SignJWT(jws); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Should have used ES384, not ES256.
|
|
if got := jws.GetHeader().Alg; got != "ES384" {
|
|
t.Fatalf("alg: got %s, want ES384 (should have selected key2)", got)
|
|
}
|
|
if got := jws.GetHeader().KID; got != "ec384" {
|
|
t.Fatalf("kid: got %s, want ec384", got)
|
|
}
|
|
|
|
// Verify round-trip.
|
|
verifier := signer.Verifier()
|
|
tokenStr, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
parsed, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := verifier.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestSignJWTUnknownKID verifies that SignJWT returns ErrUnknownKID when the
|
|
// header requests a KID that the signer doesn't have.
|
|
func TestSignJWTUnknownKID(t *testing.T) {
|
|
key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, key1, "k1")})
|
|
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: time.Now().Add(time.Hour).Unix(),
|
|
IAt: time.Now().Unix(),
|
|
}
|
|
jws, _ := jwt.New(claims)
|
|
hdr := jws.GetHeader()
|
|
hdr.KID = "nonexistent"
|
|
_ = jws.SetHeader(&hdr)
|
|
|
|
err := signer.SignJWT(jws)
|
|
if !errors.Is(err, jwt.ErrUnknownKID) {
|
|
t.Fatalf("expected ErrUnknownKID, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestJWKsRoundTrip verifies JWKS serialization and round-trip parsing.
|
|
func TestJWKsRoundTrip(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
jwksBytes, err := json.Marshal(signer)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Round-trip: parse the JWKS JSON and verify it produces a working Verifier.
|
|
var jwks jwt.WellKnownJWKs
|
|
if err := json.Unmarshal(jwksBytes, &jwks); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
keys := jwks.Keys
|
|
if len(keys) != 1 {
|
|
t.Fatalf("expected 1 key, got %d", len(keys))
|
|
}
|
|
if keys[0].KID != "k1" {
|
|
t.Errorf("expected kid 'k1', got %q", keys[0].KID)
|
|
}
|
|
|
|
iss2, _ := jwt.NewVerifier(keys)
|
|
claims := goodClaims()
|
|
tokenStr, _ := signer.SignToString(&claims)
|
|
|
|
parsed, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := iss2.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify on round-tripped JWKS failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestKeyType verifies that KeyType returns the correct JWK kty string for each key type.
|
|
func TestKeyType(t *testing.T) {
|
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
|
|
|
|
tests := []struct {
|
|
name string
|
|
key jwt.PublicKey
|
|
wantKty string
|
|
}{
|
|
{"EC P-256", jwt.PublicKey{Key: &ecKey.PublicKey}, "EC"},
|
|
{"RSA 2048", jwt.PublicKey{Key: &rsaKey.PublicKey}, "RSA"},
|
|
{"Ed25519", jwt.PublicKey{Key: edPub}, "OKP"},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := tt.key.KeyType(); got != tt.wantKty {
|
|
t.Errorf("KeyType() = %q, want %q", got, tt.wantKty)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestPublicKeyOps verifies that PrivateKey.PublicKey() translates key_ops to their
|
|
// public-key counterparts ("sign"=>"verify", "decrypt"=>"encrypt", "unwrapKey"=>"wrapKey").
|
|
func TestPublicKeyOps(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
privateOps []string
|
|
wantOps []string
|
|
}{
|
|
{"sign=>verify", []string{"sign"}, []string{"verify"}},
|
|
{"decrypt=>encrypt", []string{"decrypt"}, []string{"encrypt"}},
|
|
{"unwrapKey=>wrapKey", []string{"unwrapKey"}, []string{"wrapKey"}},
|
|
{"multiple", []string{"sign", "decrypt"}, []string{"verify", "encrypt"}},
|
|
{"public op passthrough", []string{"verify"}, []string{"verify"}},
|
|
{"no public equivalent dropped", []string{"deriveKey"}, nil},
|
|
{"empty", nil, nil},
|
|
}
|
|
base, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
pk := *base
|
|
pk.KeyOps = tt.privateOps
|
|
pub, err := pk.PublicKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(pub.KeyOps) != len(tt.wantOps) {
|
|
t.Fatalf("KeyOps = %v, want %v", pub.KeyOps, tt.wantOps)
|
|
}
|
|
for i, op := range pub.KeyOps {
|
|
if op != tt.wantOps[i] {
|
|
t.Errorf("KeyOps[%d] = %q, want %q", i, op, tt.wantOps[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDecodePublicJWKJSON verifies JWKS JSON parsing with real base64url-encoded
|
|
// key material from RFC 7517 / OIDC examples.
|
|
func TestDecodePublicJWKJSON(t *testing.T) {
|
|
jwksJSON := []byte(`{"keys":[
|
|
{"kty":"EC","crv":"P-256",
|
|
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
|
|
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
|
|
"kid":"ec-256","use":"sig"},
|
|
{"kty":"RSA",
|
|
"n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
|
|
"e":"AQAB","kid":"rsa-2048","use":"sig"}
|
|
]}`)
|
|
|
|
var jwks jwt.WellKnownJWKs
|
|
if err := json.Unmarshal(jwksJSON, &jwks); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
keys := jwks.Keys
|
|
if len(keys) != 2 {
|
|
t.Fatalf("expected 2 keys, got %d", len(keys))
|
|
}
|
|
|
|
var ecCount, rsaCount int
|
|
for _, k := range keys {
|
|
switch k.KeyType() {
|
|
case "EC":
|
|
ecCount++
|
|
if k.KID != "ec-256" {
|
|
t.Errorf("unexpected EC kid: %s", k.KID)
|
|
}
|
|
case "RSA":
|
|
rsaCount++
|
|
if k.KID != "rsa-2048" {
|
|
t.Errorf("unexpected RSA kid: %s", k.KID)
|
|
}
|
|
}
|
|
}
|
|
if ecCount != 1 {
|
|
t.Errorf("expected 1 EC key, got %d", ecCount)
|
|
}
|
|
if rsaCount != 1 {
|
|
t.Errorf("expected 1 RSA key, got %d", rsaCount)
|
|
}
|
|
}
|
|
|
|
// TestThumbprint verifies that Thumbprint returns a non-empty base64url string
|
|
// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint.
|
|
func TestThumbprint(t *testing.T) {
|
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
edPub, _, _ := ed25519.GenerateKey(rand.Reader)
|
|
|
|
tests := []struct {
|
|
name string
|
|
pub jwt.PublicKey
|
|
}{
|
|
{"EC P-256", jwt.PublicKey{Key: &ecKey.PublicKey}},
|
|
{"RSA 2048", jwt.PublicKey{Key: &rsaKey.PublicKey}},
|
|
{"Ed25519", jwt.PublicKey{Key: edPub}},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
thumb, err := tt.pub.Thumbprint()
|
|
if err != nil {
|
|
t.Fatalf("Thumbprint() error: %v", err)
|
|
}
|
|
if thumb == "" {
|
|
t.Fatal("Thumbprint() returned empty string")
|
|
}
|
|
// Must be valid base64url (no padding, no +/)
|
|
if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") {
|
|
t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb)
|
|
}
|
|
// Same key, same thumbprint
|
|
thumb2, _ := tt.pub.Thumbprint()
|
|
if thumb != thumb2 {
|
|
t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets
|
|
// its KID auto-populated from the RFC 7638 thumbprint.
|
|
func TestNoKidAutoThumbprint(t *testing.T) {
|
|
// EC key with no "kid" field in the JWKS
|
|
jwksJSON := []byte(`{"keys":[
|
|
{"kty":"EC","crv":"P-256",
|
|
"x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
|
|
"y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM",
|
|
"use":"sig"}
|
|
]}`)
|
|
|
|
var jwks jwt.WellKnownJWKs
|
|
if err := json.Unmarshal(jwksJSON, &jwks); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
keys := jwks.Keys
|
|
if len(keys) != 1 {
|
|
t.Fatalf("expected 1 key, got %d", len(keys))
|
|
}
|
|
if keys[0].KID == "" {
|
|
t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS")
|
|
}
|
|
|
|
// The auto-KID should be a valid base64url string.
|
|
kid := keys[0].KID
|
|
if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") {
|
|
t.Errorf("auto-KID contains non-base64url characters: %s", kid)
|
|
}
|
|
|
|
// Round-trip: compute Thumbprint directly and compare.
|
|
thumb, err := keys[0].Thumbprint()
|
|
if err != nil {
|
|
t.Fatalf("Thumbprint() error: %v", err)
|
|
}
|
|
if kid != thumb {
|
|
t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb)
|
|
}
|
|
}
|
|
|
|
// TestNewPrivateKey verifies that jwt.NewPrivateKey generates an Ed25519 key
|
|
// with a non-empty KID auto-derived from the thumbprint, and that the key
|
|
// works end-to-end for signing and verification.
|
|
func TestNewPrivateKey(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("NewPrivateKey() error: %v", err)
|
|
}
|
|
if pk.KID == "" {
|
|
t.Fatal("NewPrivateKey() returned empty KID")
|
|
}
|
|
// KID must be base64url (no +, /, or =).
|
|
if strings.Contains(pk.KID, "+") || strings.Contains(pk.KID, "/") || strings.Contains(pk.KID, "=") {
|
|
t.Errorf("KID contains non-base64url characters: %s", pk.KID)
|
|
}
|
|
// Two calls must produce different keys but always produce valid base64url KIDs.
|
|
pk2, _ := jwt.NewPrivateKey()
|
|
if pk.KID == pk2.KID {
|
|
t.Error("NewPrivateKey() produced identical KIDs for two different keys")
|
|
}
|
|
|
|
// Full sign+verify round-trip with the generated key.
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatalf("NewSigner() error: %v", err)
|
|
}
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatalf("SignToString() error: %v", err)
|
|
}
|
|
iss := signer.Verifier()
|
|
jws, err := iss.VerifyJWT(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT() error: %v", err)
|
|
}
|
|
var decoded AppClaims
|
|
if err := jws.UnmarshalClaims(&decoded); err != nil {
|
|
t.Fatalf("UnmarshalClaims() error: %v", err)
|
|
}
|
|
if decoded.Sub != claims.Sub {
|
|
t.Errorf("sub: got %q, want %q", decoded.Sub, claims.Sub)
|
|
}
|
|
}
|
|
|
|
// --- DecodeRaw + UnmarshalHeader tests ---
|
|
|
|
func TestDecodeRaw(t *testing.T) {
|
|
// Sign a real token to get a valid compact string.
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("NewPrivateKey() error: %v", err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatalf("NewSigner() error: %v", err)
|
|
}
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatalf("SignToString() error: %v", err)
|
|
}
|
|
|
|
raw, err := jwt.DecodeRaw(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("DecodeRaw() error: %v", err)
|
|
}
|
|
|
|
// protected and payload should be non-empty base64url segments.
|
|
if len(raw.GetProtected()) == 0 {
|
|
t.Error("GetProtected() is empty")
|
|
}
|
|
if len(raw.GetPayload()) == 0 {
|
|
t.Error("GetPayload() is empty")
|
|
}
|
|
if len(raw.GetSignature()) == 0 {
|
|
t.Error("GetSignature() is empty")
|
|
}
|
|
}
|
|
|
|
func TestDecodeRawErrors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
sentinel error
|
|
}{
|
|
{"empty string", "", jwt.ErrMalformedToken},
|
|
{"one segment", "abc", jwt.ErrMalformedToken},
|
|
{"two segments", "abc.def", jwt.ErrMalformedToken},
|
|
{"four segments", "a.b.c.d", jwt.ErrMalformedToken},
|
|
{"bad signature base64", "eyJhbGciOiJFZERTQSJ9.eyJpc3MiOiJ4In0.!!!bad!!!", jwt.ErrSignatureInvalid},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := jwt.DecodeRaw(tc.input)
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if !errors.Is(err, tc.sentinel) {
|
|
t.Errorf("expected %v, got: %v", tc.sentinel, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDecodeRawSegmentCount(t *testing.T) {
|
|
_, err := jwt.DecodeRaw("")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
// Empty string normalizes to 0 segments.
|
|
if !strings.Contains(err.Error(), "got 0") {
|
|
t.Errorf("expected segment count 0 in error, got: %v", err)
|
|
}
|
|
|
|
_, err = jwt.DecodeRaw("a.b")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "got 2") {
|
|
t.Errorf("expected segment count 2 in error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestUnmarshalHeader(t *testing.T) {
|
|
// Sign a token and use DecodeRaw + UnmarshalHeader to recover the header.
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("NewPrivateKey() error: %v", err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatalf("NewSigner() error: %v", err)
|
|
}
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatalf("SignToString() error: %v", err)
|
|
}
|
|
|
|
raw, err := jwt.DecodeRaw(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("DecodeRaw() error: %v", err)
|
|
}
|
|
|
|
var h jwt.RFCHeader
|
|
if err := raw.UnmarshalHeader(&h); err != nil {
|
|
t.Fatalf("UnmarshalHeader() error: %v", err)
|
|
}
|
|
|
|
if h.Alg != "EdDSA" {
|
|
t.Errorf("alg: got %q, want %q", h.Alg, "EdDSA")
|
|
}
|
|
if h.KID != pk.KID {
|
|
t.Errorf("kid: got %q, want %q", h.KID, pk.KID)
|
|
}
|
|
if h.Typ != "JWT" {
|
|
t.Errorf("typ: got %q, want %q", h.Typ, "JWT")
|
|
}
|
|
}
|
|
|
|
func TestUnmarshalHeaderCustomFields(t *testing.T) {
|
|
// Build a token whose header has a custom "nonce" field by constructing
|
|
// the compact string manually: base64(header).base64(payload).base64(sig).
|
|
type CustomHeader struct {
|
|
jwt.RFCHeader
|
|
Nonce string `json:"nonce"`
|
|
}
|
|
|
|
hdr := CustomHeader{
|
|
RFCHeader: jwt.RFCHeader{Alg: "EdDSA", KID: "test-key", Typ: "dpop+jwt"},
|
|
Nonce: "server-nonce-42",
|
|
}
|
|
hdrJSON, _ := json.Marshal(hdr)
|
|
payJSON := []byte(`{"sub":"user"}`)
|
|
fakeSig := []byte{0xDE, 0xAD}
|
|
|
|
compact := base64.RawURLEncoding.EncodeToString(hdrJSON) +
|
|
"." + base64.RawURLEncoding.EncodeToString(payJSON) +
|
|
"." + base64.RawURLEncoding.EncodeToString(fakeSig)
|
|
|
|
raw, err := jwt.DecodeRaw(compact)
|
|
if err != nil {
|
|
t.Fatalf("DecodeRaw() error: %v", err)
|
|
}
|
|
|
|
var got CustomHeader
|
|
if err := raw.UnmarshalHeader(&got); err != nil {
|
|
t.Fatalf("UnmarshalHeader() error: %v", err)
|
|
}
|
|
|
|
if got.Nonce != "server-nonce-42" {
|
|
t.Errorf("nonce: got %q, want %q", got.Nonce, "server-nonce-42")
|
|
}
|
|
if got.Alg != "EdDSA" {
|
|
t.Errorf("alg: got %q, want %q", got.Alg, "EdDSA")
|
|
}
|
|
if got.Typ != "dpop+jwt" {
|
|
t.Errorf("typ: got %q, want %q", got.Typ, "dpop+jwt")
|
|
}
|
|
}
|
|
|
|
func TestUnmarshalHeaderViaJWS(t *testing.T) {
|
|
// Verify that UnmarshalHeader is promoted from RawJWT to *JWT.
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatalf("NewPrivateKey() error: %v", err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatalf("NewSigner() error: %v", err)
|
|
}
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatalf("SignToString() error: %v", err)
|
|
}
|
|
|
|
// Use Decode (not DecodeRaw) - UnmarshalHeader should still work via promotion.
|
|
jws, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("Decode() error: %v", err)
|
|
}
|
|
|
|
var h jwt.RFCHeader
|
|
if err := jws.UnmarshalHeader(&h); err != nil {
|
|
t.Fatalf("jws.UnmarshalHeader() error: %v", err)
|
|
}
|
|
if h.Alg != "EdDSA" {
|
|
t.Errorf("alg: got %q, want %q", h.Alg, "EdDSA")
|
|
}
|
|
}
|
|
|
|
// --- SpaceDelimited tests ---
|
|
|
|
func TestSpaceDelimitedMarshalJSON(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
in jwt.SpaceDelimited
|
|
want string
|
|
}{
|
|
{"multiple", jwt.SpaceDelimited{"openid", "profile", "email"}, `"openid profile email"`},
|
|
{"single", jwt.SpaceDelimited{"openid"}, `"openid"`},
|
|
{"empty", jwt.SpaceDelimited{}, `""`},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := json.Marshal(tt.in)
|
|
if err != nil {
|
|
t.Fatalf("Marshal error: %v", err)
|
|
}
|
|
if string(got) != tt.want {
|
|
t.Errorf("got %s, want %s", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSpaceDelimitedUnmarshalJSON(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
in string
|
|
want jwt.SpaceDelimited
|
|
wantNil bool
|
|
}{
|
|
{"multiple", `"openid profile email"`, jwt.SpaceDelimited{"openid", "profile", "email"}, false},
|
|
{"single", `"openid"`, jwt.SpaceDelimited{"openid"}, false},
|
|
{"empty", `""`, jwt.SpaceDelimited{}, false},
|
|
{"extra whitespace", `"openid profile\temail"`, jwt.SpaceDelimited{"openid", "profile", "email"}, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var got jwt.SpaceDelimited
|
|
if err := json.Unmarshal([]byte(tt.in), &got); err != nil {
|
|
t.Fatalf("Unmarshal error: %v", err)
|
|
}
|
|
if tt.wantNil {
|
|
if got != nil {
|
|
t.Errorf("got %v, want nil", got)
|
|
}
|
|
return
|
|
}
|
|
if len(got) != len(tt.want) {
|
|
t.Fatalf("got %v (len %d), want %v (len %d)", got, len(got), tt.want, len(tt.want))
|
|
}
|
|
for i := range got {
|
|
if got[i] != tt.want[i] {
|
|
t.Errorf("got[%d] = %q, want %q", i, got[i], tt.want[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSpaceDelimitedRoundTrip(t *testing.T) {
|
|
type claims struct {
|
|
Scope jwt.SpaceDelimited `json:"scope,omitempty"`
|
|
}
|
|
orig := claims{Scope: jwt.SpaceDelimited{"openid", "profile"}}
|
|
data, err := json.Marshal(orig)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Verify wire format is a space-separated string.
|
|
if !strings.Contains(string(data), `"openid profile"`) {
|
|
t.Fatalf("expected space-separated scope in JSON, got %s", data)
|
|
}
|
|
|
|
var decoded claims
|
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(decoded.Scope) != 2 || decoded.Scope[0] != "openid" || decoded.Scope[1] != "profile" {
|
|
t.Errorf("round-trip failed: got %v", decoded.Scope)
|
|
}
|
|
}
|
|
|
|
// --- SetTyp / NewAccessToken tests ---
|
|
|
|
func TestSetTyp(t *testing.T) {
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user1",
|
|
Aud: jwt.Listish{"api"},
|
|
Exp: time.Now().Add(time.Hour).Unix(),
|
|
IAt: time.Now().Unix(),
|
|
}
|
|
jws, err := jwt.New(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Default typ is "JWT".
|
|
if got := jws.GetHeader().Typ; got != "JWT" {
|
|
t.Fatalf("default typ: got %q, want %q", got, "JWT")
|
|
}
|
|
|
|
jws.SetTyp(jwt.AccessTokenTyp)
|
|
|
|
if got := jws.GetHeader().Typ; got != jwt.AccessTokenTyp {
|
|
t.Fatalf("after SetTyp: got %q, want %q", got, jwt.AccessTokenTyp)
|
|
}
|
|
}
|
|
|
|
func TestSetTypSurvivesSigning(t *testing.T) {
|
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://auth.example.com",
|
|
Sub: "user1",
|
|
Aud: jwt.Listish{"https://api.example.com"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
JTI: "tok-001",
|
|
ClientID: "webapp",
|
|
Scope: jwt.SpaceDelimited{"openid", "profile"},
|
|
}
|
|
|
|
jws, err := jwt.NewAccessToken(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := signer.SignJWT(jws); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Verify typ survived signing.
|
|
if got := jws.GetHeader().Typ; got != jwt.AccessTokenTyp {
|
|
t.Fatalf("typ after signing: got %q, want %q", got, jwt.AccessTokenTyp)
|
|
}
|
|
|
|
// Decode the token and verify typ is in the wire format.
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed: %v", err)
|
|
}
|
|
decoded, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got := decoded.GetHeader().Typ; got != jwt.AccessTokenTyp {
|
|
t.Fatalf("typ after decode: got %q, want %q", got, jwt.AccessTokenTyp)
|
|
}
|
|
|
|
// Verify claims round-trip.
|
|
var rt jwt.TokenClaims
|
|
if err := decoded.UnmarshalClaims(&rt); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if rt.ClientID != "webapp" {
|
|
t.Errorf("client_id: got %q, want %q", rt.ClientID, "webapp")
|
|
}
|
|
if len(rt.Scope) != 2 || rt.Scope[0] != "openid" {
|
|
t.Errorf("scope: got %v, want [openid profile]", rt.Scope)
|
|
}
|
|
}
|
|
|
|
// --- Access token Validator tests ---
|
|
|
|
func goodAccessTokenClaims() *jwt.TokenClaims {
|
|
now := time.Now()
|
|
return &jwt.TokenClaims{
|
|
Iss: "https://auth.example.com",
|
|
Sub: "user1",
|
|
Aud: jwt.Listish{"https://api.example.com"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
JTI: "tok-001",
|
|
ClientID: "webapp",
|
|
Scope: jwt.SpaceDelimited{"openid", "profile"},
|
|
AMR: []string{"pwd"},
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorHappyPath(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
claims := goodAccessTokenClaims()
|
|
if err := v.Validate(nil, claims, time.Now()); err != nil {
|
|
t.Fatalf("valid access token rejected: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorRequiresJTI(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
claims := goodAccessTokenClaims()
|
|
claims.JTI = ""
|
|
err := v.Validate(nil, claims, time.Now())
|
|
if !errors.Is(err, jwt.ErrMissingClaim) {
|
|
t.Fatalf("expected ErrMissingClaim for missing jti, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorRequiresClientID(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
claims := goodAccessTokenClaims()
|
|
claims.ClientID = ""
|
|
err := v.Validate(nil, claims, time.Now())
|
|
if !errors.Is(err, jwt.ErrMissingClaim) {
|
|
t.Fatalf("expected ErrMissingClaim for missing client_id, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorDisableClientID(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
v.Checks &^= jwt.CheckClientID
|
|
claims := goodAccessTokenClaims()
|
|
claims.ClientID = ""
|
|
if err := v.Validate(nil, claims, time.Now()); err != nil {
|
|
t.Fatalf("disabling CheckClientID should accept empty client_id: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorExpiredToken(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
claims := goodAccessTokenClaims()
|
|
claims.Exp = time.Now().Add(-time.Hour).Unix()
|
|
err := v.Validate(nil, claims, time.Now())
|
|
if !errors.Is(err, jwt.ErrAfterExp) {
|
|
t.Fatalf("expected ErrAfterExp, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorRequiredScopes(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
v.RequiredScopes = []string{"openid", "admin"}
|
|
claims := goodAccessTokenClaims()
|
|
// claims has ["openid", "profile"] - missing "admin"
|
|
err := v.Validate(nil, claims, time.Now())
|
|
if !errors.Is(err, jwt.ErrInsufficientScope) {
|
|
t.Fatalf("expected ErrInsufficientScope for missing scope, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorExpectScope(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
v.Checks |= jwt.CheckScope // enable scope presence check
|
|
claims := goodAccessTokenClaims()
|
|
claims.Scope = nil
|
|
err := v.Validate(nil, claims, time.Now())
|
|
if !errors.Is(err, jwt.ErrMissingClaim) {
|
|
t.Fatalf("expected ErrMissingClaim for empty scope, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAccessTokenValidatorDisableJTI(t *testing.T) {
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
v.Checks &^= jwt.CheckJTI
|
|
claims := goodAccessTokenClaims()
|
|
claims.JTI = ""
|
|
if err := v.Validate(nil, claims, time.Now()); err != nil {
|
|
t.Fatalf("disabling CheckJTI should accept empty jti: %v", err)
|
|
}
|
|
}
|
|
|
|
// --- Encode validation tests ---
|
|
|
|
// stubJWT is a minimal VerifiableJWT for testing Encode validation.
|
|
type stubJWT struct {
|
|
protected []byte
|
|
payload []byte
|
|
signature []byte
|
|
header jwt.RFCHeader
|
|
}
|
|
|
|
func (s *stubJWT) GetProtected() []byte { return s.protected }
|
|
func (s *stubJWT) GetPayload() []byte { return s.payload }
|
|
func (s *stubJWT) GetSignature() []byte { return s.signature }
|
|
func (s *stubJWT) GetHeader() jwt.RFCHeader { return s.header }
|
|
|
|
// TestEncodeRejectsEmptyAlg verifies that Encode returns an error
|
|
// when the alg header field is empty (unsigned token).
|
|
func TestEncodeRejectsEmptyAlg(t *testing.T) {
|
|
// Zero-value stub: no alg set.
|
|
jws := &stubJWT{}
|
|
_, err := jwt.Encode(jws)
|
|
if err == nil {
|
|
t.Fatal("expected error for empty alg")
|
|
}
|
|
if !errors.Is(err, jwt.ErrInvalidHeader) {
|
|
t.Fatalf("expected ErrInvalidHeader, got: %v", err)
|
|
}
|
|
if !strings.Contains(err.Error(), "alg is empty") {
|
|
t.Fatalf("unexpected message: %v", err)
|
|
}
|
|
|
|
// Explicit header with typ but no alg.
|
|
jws2 := &stubJWT{
|
|
protected: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT"}`))),
|
|
payload: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"user1"}`))),
|
|
signature: []byte{0x01, 0x02, 0x03},
|
|
header: jwt.RFCHeader{Typ: "JWT"},
|
|
}
|
|
_, err = jwt.Encode(jws2)
|
|
if err == nil {
|
|
t.Fatal("expected error for empty alg with typ-only header")
|
|
}
|
|
if !errors.Is(err, jwt.ErrInvalidHeader) {
|
|
t.Fatalf("expected ErrInvalidHeader, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestEncodeSucceedsAfterSigning verifies the happy path: a signed JWT
|
|
// encodes without error.
|
|
func TestEncodeSucceedsAfterSigning(t *testing.T) {
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
claims := goodClaims()
|
|
jws, err := signer.Sign(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
token, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode failed on signed JWT: %v", err)
|
|
}
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 3 {
|
|
t.Fatalf("expected 3 segments, got %d", len(parts))
|
|
}
|
|
for i, p := range parts {
|
|
if p == "" {
|
|
t.Fatalf("segment %d is empty", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Full pipeline round-trip tests ---
|
|
|
|
// TestRoundTrip_IDToken exercises the full ID token pipeline:
|
|
// NewPrivateKey -> NewSigner -> SignToString -> Decode -> Verify -> UnmarshalClaims -> Validate.
|
|
func TestRoundTrip_IDToken(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://idp.example.com",
|
|
Sub: "user-42",
|
|
Aud: jwt.Listish{"my-client"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
AzP: "my-client",
|
|
Nonce: "n-0S6_WzA2Mj",
|
|
AMR: []string{"pwd", "otp"},
|
|
JTI: "id-tok-001",
|
|
}
|
|
|
|
tokenStr, err := signer.SignToString(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Decode + Verify
|
|
decoded, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("Decode: %v", err)
|
|
}
|
|
verifier := signer.Verifier()
|
|
if err := verifier.Verify(decoded); err != nil {
|
|
t.Fatalf("Verify: %v", err)
|
|
}
|
|
|
|
// UnmarshalClaims
|
|
var got jwt.TokenClaims
|
|
if err := decoded.UnmarshalClaims(&got); err != nil {
|
|
t.Fatalf("UnmarshalClaims: %v", err)
|
|
}
|
|
|
|
// Validate with NewIDTokenValidator
|
|
v := jwt.NewIDTokenValidator(
|
|
[]string{"https://idp.example.com"},
|
|
[]string{"my-client"},
|
|
[]string{"my-client"},
|
|
0,
|
|
)
|
|
if err := v.Validate(nil, &got, time.Now()); err != nil {
|
|
t.Fatalf("Validate: %v", err)
|
|
}
|
|
|
|
// Spot-check round-tripped fields.
|
|
if got.Sub != "user-42" {
|
|
t.Errorf("sub: got %q, want %q", got.Sub, "user-42")
|
|
}
|
|
if got.Nonce != "n-0S6_WzA2Mj" {
|
|
t.Errorf("nonce: got %q, want %q", got.Nonce, "n-0S6_WzA2Mj")
|
|
}
|
|
if len(got.AMR) != 2 || got.AMR[0] != "pwd" || got.AMR[1] != "otp" {
|
|
t.Errorf("amr: got %v, want [pwd otp]", got.AMR)
|
|
}
|
|
}
|
|
|
|
// TestRoundTrip_AccessToken exercises the full access token pipeline:
|
|
// NewAccessToken -> SignJWT -> Encode -> Decode -> Verify -> UnmarshalClaims -> Validate with RequiredScopes.
|
|
func TestRoundTrip_AccessToken(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://auth.example.com",
|
|
Sub: "svc-account",
|
|
Aud: jwt.Listish{"https://api.example.com"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
JTI: "at-001",
|
|
ClientID: "backend-svc",
|
|
Scope: jwt.SpaceDelimited{"read", "write", "admin"},
|
|
}
|
|
|
|
jws, err := jwt.NewAccessToken(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := signer.SignJWT(jws); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tokenStr, err := jwt.Encode(jws)
|
|
if err != nil {
|
|
t.Fatalf("Encode: %v", err)
|
|
}
|
|
|
|
// Decode + Verify
|
|
decoded, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("Decode: %v", err)
|
|
}
|
|
|
|
// Verify typ survived the round-trip.
|
|
if got := decoded.GetHeader().Typ; got != jwt.AccessTokenTyp {
|
|
t.Errorf("typ: got %q, want %q", got, jwt.AccessTokenTyp)
|
|
}
|
|
|
|
verifier := signer.Verifier()
|
|
if err := verifier.Verify(decoded); err != nil {
|
|
t.Fatalf("Verify: %v", err)
|
|
}
|
|
|
|
var got jwt.TokenClaims
|
|
if err := decoded.UnmarshalClaims(&got); err != nil {
|
|
t.Fatalf("UnmarshalClaims: %v", err)
|
|
}
|
|
|
|
// Validate with RequiredScopes.
|
|
v := jwt.NewAccessTokenValidator(
|
|
[]string{"https://auth.example.com"},
|
|
[]string{"https://api.example.com"},
|
|
0,
|
|
)
|
|
v.RequiredScopes = []string{"read", "write"}
|
|
if err := v.Validate(nil, &got, time.Now()); err != nil {
|
|
t.Fatalf("Validate: %v", err)
|
|
}
|
|
|
|
// Spot-check scope round-trip.
|
|
if len(got.Scope) != 3 || got.Scope[0] != "read" || got.Scope[2] != "admin" {
|
|
t.Errorf("scope: got %v, want [read write admin]", got.Scope)
|
|
}
|
|
if got.ClientID != "backend-svc" {
|
|
t.Errorf("client_id: got %q, want %q", got.ClientID, "backend-svc")
|
|
}
|
|
}
|
|
|
|
// TestRoundTrip_StandardClaims verifies that StandardClaims with NullBool fields
|
|
// survive the sign-decode round-trip.
|
|
func TestRoundTrip_StandardClaims(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.StandardClaims{
|
|
TokenClaims: jwt.TokenClaims{
|
|
Iss: "https://idp.example.com",
|
|
Sub: "user-99",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
AzP: "app",
|
|
},
|
|
Name: "Jane Doe",
|
|
Email: "jane@example.com",
|
|
EmailVerified: jwt.NullBool{Bool: true, Valid: true},
|
|
PhoneNumber: "+1-555-0100",
|
|
PhoneNumberVerified: jwt.NullBool{Bool: false, Valid: true},
|
|
Locale: "en-US",
|
|
}
|
|
|
|
tokenStr, err := signer.SignToString(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Decode + Verify
|
|
verifier := signer.Verifier()
|
|
jws, err := verifier.VerifyJWT(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT: %v", err)
|
|
}
|
|
|
|
var got jwt.StandardClaims
|
|
if err := jws.UnmarshalClaims(&got); err != nil {
|
|
t.Fatalf("UnmarshalClaims: %v", err)
|
|
}
|
|
|
|
// Verify NullBool fields survived the round-trip.
|
|
if !got.EmailVerified.Valid || !got.EmailVerified.Bool {
|
|
t.Errorf("email_verified: got %+v, want {Bool:true Valid:true}", got.EmailVerified)
|
|
}
|
|
if !got.PhoneNumberVerified.Valid || got.PhoneNumberVerified.Bool {
|
|
t.Errorf("phone_number_verified: got %+v, want {Bool:false Valid:true}", got.PhoneNumberVerified)
|
|
}
|
|
|
|
// Verify other profile fields.
|
|
if got.Name != "Jane Doe" {
|
|
t.Errorf("name: got %q, want %q", got.Name, "Jane Doe")
|
|
}
|
|
if got.Email != "jane@example.com" {
|
|
t.Errorf("email: got %q, want %q", got.Email, "jane@example.com")
|
|
}
|
|
if got.Locale != "en-US" {
|
|
t.Errorf("locale: got %q, want %q", got.Locale, "en-US")
|
|
}
|
|
|
|
// Verify that an unset NullBool (not in JSON) comes back as zero value.
|
|
// StandardClaims has no field we explicitly omitted that is a NullBool
|
|
// other than the two we set, so we create a fresh StandardClaims without
|
|
// email_verified to test omission.
|
|
claims2 := &jwt.StandardClaims{
|
|
TokenClaims: jwt.TokenClaims{
|
|
Iss: "https://idp.example.com",
|
|
Sub: "user-100",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
AzP: "app",
|
|
},
|
|
Email: "bob@example.com",
|
|
// EmailVerified left as zero value (NullBool{})
|
|
}
|
|
tok2, err := signer.SignToString(claims2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
jws2, err := verifier.VerifyJWT(tok2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var got2 jwt.StandardClaims
|
|
if err := jws2.UnmarshalClaims(&got2); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got2.EmailVerified.Valid {
|
|
t.Errorf("omitted email_verified should be invalid, got %+v", got2.EmailVerified)
|
|
}
|
|
}
|
|
|
|
// --- DPoPJWT: custom header type used by TestRoundTrip_CustomHeader ---
|
|
|
|
// dpopHeader extends the standard JOSE header with a DPoP nonce.
|
|
type dpopHeader struct {
|
|
jwt.RFCHeader
|
|
Nonce string `json:"nonce,omitempty"`
|
|
}
|
|
|
|
// dpopJWT is a custom JWT that carries a dpopHeader.
|
|
type dpopJWT struct {
|
|
jwt.RawJWT
|
|
Header dpopHeader
|
|
}
|
|
|
|
func (d *dpopJWT) GetHeader() jwt.RFCHeader { return d.Header.RFCHeader }
|
|
|
|
func (d *dpopJWT) SetHeader(hdr jwt.Header) error {
|
|
d.Header.RFCHeader = *hdr.GetRFCHeader()
|
|
data, err := json.Marshal(d.Header)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
d.Protected = []byte(base64.RawURLEncoding.EncodeToString(data))
|
|
return nil
|
|
}
|
|
|
|
// TestRoundTrip_CustomHeader verifies that custom header fields survive the
|
|
// full sign-decode-verify round-trip when using a custom SignableJWT type.
|
|
func TestRoundTrip_CustomHeader(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
verifier := signer.Verifier()
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"api"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
}
|
|
|
|
// Build and sign a DPoP JWT with a custom nonce header.
|
|
dpop := &dpopJWT{Header: dpopHeader{
|
|
RFCHeader: jwt.RFCHeader{Typ: "dpop+jwt"},
|
|
Nonce: "server-nonce-abc",
|
|
}}
|
|
if err := dpop.SetClaims(claims); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := signer.SignJWT(dpop); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tokenStr, err := jwt.Encode(dpop)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Verify the signature with the standard Decode path.
|
|
jws, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("Decode: %v", err)
|
|
}
|
|
if err := verifier.Verify(jws); err != nil {
|
|
t.Fatalf("Verify: %v", err)
|
|
}
|
|
|
|
// Recover the custom header via DecodeRaw + UnmarshalHeader.
|
|
raw, err := jwt.DecodeRaw(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("DecodeRaw: %v", err)
|
|
}
|
|
var gotHdr dpopHeader
|
|
if err := raw.UnmarshalHeader(&gotHdr); err != nil {
|
|
t.Fatalf("UnmarshalHeader: %v", err)
|
|
}
|
|
|
|
if gotHdr.Nonce != "server-nonce-abc" {
|
|
t.Errorf("nonce: got %q, want %q", gotHdr.Nonce, "server-nonce-abc")
|
|
}
|
|
if gotHdr.Typ != "dpop+jwt" {
|
|
t.Errorf("typ: got %q, want %q", gotHdr.Typ, "dpop+jwt")
|
|
}
|
|
if gotHdr.Alg != "EdDSA" {
|
|
t.Errorf("alg: got %q, want %q", gotHdr.Alg, "EdDSA")
|
|
}
|
|
if gotHdr.KID != pk.KID {
|
|
t.Errorf("kid: got %q, want %q", gotHdr.KID, pk.KID)
|
|
}
|
|
}
|
|
|
|
// TestRoundTrip_ExpiredTokenRejection signs a token with a past exp, verifies
|
|
// that Decode+Verify succeeds (signature is valid) but Validate rejects it.
|
|
func TestRoundTrip_ExpiredTokenRejection(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://idp.example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: now.Add(-time.Hour).Unix(), // expired 1 hour ago
|
|
IAt: now.Add(-2 * time.Hour).Unix(),
|
|
AuthTime: now.Add(-2 * time.Hour).Unix(),
|
|
AzP: "app",
|
|
}
|
|
|
|
tokenStr, err := signer.SignToString(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Decode + Verify should succeed - the signature is valid.
|
|
decoded, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("Decode: %v", err)
|
|
}
|
|
verifier := signer.Verifier()
|
|
if err := verifier.Verify(decoded); err != nil {
|
|
t.Fatalf("Verify should succeed for expired token: %v", err)
|
|
}
|
|
|
|
var got jwt.TokenClaims
|
|
if err := decoded.UnmarshalClaims(&got); err != nil {
|
|
t.Fatalf("UnmarshalClaims: %v", err)
|
|
}
|
|
|
|
// Validate should fail with ErrAfterExp.
|
|
v := jwt.NewIDTokenValidator(
|
|
[]string{"https://idp.example.com"},
|
|
[]string{"app"},
|
|
[]string{"app"},
|
|
0,
|
|
)
|
|
err = v.Validate(nil, &got, time.Now())
|
|
if err == nil {
|
|
t.Fatal("expected validation error for expired token")
|
|
}
|
|
if !errors.Is(err, jwt.ErrAfterExp) {
|
|
t.Fatalf("expected ErrAfterExp, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestRoundTrip_WrongAudienceRejection signs a token with one audience and
|
|
// validates against a different audience, expecting rejection.
|
|
func TestRoundTrip_WrongAudienceRejection(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://idp.example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"app-A"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
AzP: "app-A",
|
|
}
|
|
|
|
tokenStr, err := signer.SignToString(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Decode + Verify succeeds.
|
|
verifier := signer.Verifier()
|
|
jws, err := verifier.VerifyJWT(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("VerifyJWT: %v", err)
|
|
}
|
|
|
|
var got jwt.TokenClaims
|
|
if err := jws.UnmarshalClaims(&got); err != nil {
|
|
t.Fatalf("UnmarshalClaims: %v", err)
|
|
}
|
|
|
|
// Validate with a different audience - should fail.
|
|
v := jwt.NewIDTokenValidator(
|
|
[]string{"https://idp.example.com"},
|
|
[]string{"app-B"}, // wrong audience
|
|
[]string{"app-A"},
|
|
0,
|
|
)
|
|
err = v.Validate(nil, &got, time.Now())
|
|
if err == nil {
|
|
t.Fatal("expected validation error for wrong audience")
|
|
}
|
|
if !errors.Is(err, jwt.ErrInvalidClaim) {
|
|
t.Fatalf("expected ErrInvalidClaim for aud mismatch, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestDuplicateKIDRotation verifies that when multiple keys share the same KID
|
|
// (e.g. during key rotation), the verifier tries all matching keys and succeeds
|
|
// if any one of them verifies the signature.
|
|
func TestDuplicateKIDRotation(t *testing.T) {
|
|
oldKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
newKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
// Both keys share the same KID (simulating a rotation where the KID is reused).
|
|
sharedKID := "rotating-key"
|
|
|
|
// Sign a token with the OLD key.
|
|
oldSigner, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, oldKey, sharedKID)})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
claims := goodClaims()
|
|
oldToken, err := oldSigner.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Sign a token with the NEW key.
|
|
newSigner, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, newKey, sharedKID)})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
newToken, err := newSigner.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Verifier has both keys under the same KID.
|
|
verifier, err := jwt.NewVerifier([]jwt.PublicKey{
|
|
{Key: &oldKey.PublicKey, KID: sharedKID},
|
|
{Key: &newKey.PublicKey, KID: sharedKID},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Both tokens should verify successfully.
|
|
for _, tt := range []struct {
|
|
name string
|
|
token string
|
|
}{
|
|
{"old key", oldToken},
|
|
{"new key", newToken},
|
|
} {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
parsed, err := jwt.Decode(tt.token)
|
|
if err != nil {
|
|
t.Fatalf("Decode: %v", err)
|
|
}
|
|
if err := verifier.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify should succeed for %s: %v", tt.name, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestNoKIDTriesAllKeys verifies that when a token has no KID, all verifier
|
|
// keys are tried and the first successful verification wins.
|
|
func TestNoKIDTriesAllKeys(t *testing.T) {
|
|
wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
// Sign with rightKey using SignRaw with a header that has no KID.
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "any")})
|
|
hdr := &jwt.RFCHeader{} // no KID, no typ
|
|
payloadJSON := []byte(`{"sub":"user-1"}`)
|
|
raw, err := signer.SignRaw(hdr, payloadJSON)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Reconstruct as compact token: protected.payload.signature
|
|
token := string(raw.GetProtected()) + "." + string(raw.GetPayload()) +
|
|
"." + base64.RawURLEncoding.EncodeToString(raw.GetSignature())
|
|
|
|
// Verifier has wrongKey first, then rightKey.
|
|
verifier, err := jwt.NewVerifier([]jwt.PublicKey{
|
|
{Key: &wrongKey.PublicKey, KID: "wrong"},
|
|
{Key: &rightKey.PublicKey, KID: "right"},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
parsed, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if parsed.GetHeader().KID != "" {
|
|
t.Fatalf("token should have no KID, got %q", parsed.GetHeader().KID)
|
|
}
|
|
if err := verifier.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify should try all keys when token has no KID: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestEmptyKIDTokenEmptyKIDKey verifies that when both the token and a
|
|
// verifier key have empty KIDs, verification succeeds (all keys are tried).
|
|
func TestEmptyKIDTokenEmptyKIDKey(t *testing.T) {
|
|
rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
// Sign with SignRaw to produce a token with no KID.
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "any")})
|
|
raw, err := signer.SignRaw(&jwt.RFCHeader{}, []byte(`{"sub":"user-1"}`))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
token := string(raw.GetProtected()) + "." + string(raw.GetPayload()) +
|
|
"." + base64.RawURLEncoding.EncodeToString(raw.GetSignature())
|
|
|
|
// Verifier key also has empty KID.
|
|
verifier, err := jwt.NewVerifier([]jwt.PublicKey{
|
|
{Key: &rightKey.PublicKey, KID: ""},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
parsed, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := verifier.Verify(parsed); err != nil {
|
|
t.Fatalf("empty KID token + empty KID key should verify: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestKIDTokenEmptyKIDKey verifies that when the token has a KID but the
|
|
// verifier key has an empty KID, the key is not a candidate (no match).
|
|
func TestKIDTokenEmptyKIDKey(t *testing.T) {
|
|
rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
// Sign normally -- token gets a KID from the signer.
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "my-kid")})
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
// Verifier has the same key material but with an empty KID.
|
|
verifier, err := jwt.NewVerifier([]jwt.PublicKey{
|
|
{Key: &rightKey.PublicKey, KID: ""},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
parsed, err := jwt.Decode(token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = verifier.Verify(parsed)
|
|
if !errors.Is(err, jwt.ErrUnknownKID) {
|
|
t.Fatalf("token KID %q should not match empty-KID key, expected ErrUnknownKID, got: %v",
|
|
parsed.GetHeader().KID, err)
|
|
}
|
|
}
|
|
|
|
// TestMultiKeyVerifier verifies that a Verifier with keys of different algorithms
|
|
// correctly selects the right key by KID when verifying tokens.
|
|
func TestMultiKeyVerifier(t *testing.T) {
|
|
ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
edPub, edPriv, _ := ed25519.GenerateKey(rand.Reader)
|
|
|
|
ecPK := mustPK(t, ecKey, "ec-key")
|
|
rsaPK := mustPK(t, rsaKey, "rsa-key")
|
|
edPK := mustPK(t, edPriv, "ed-key")
|
|
|
|
// Create a verifier with all three public keys.
|
|
verifier, err := jwt.NewVerifier([]jwt.PublicKey{
|
|
{Key: &ecKey.PublicKey, KID: "ec-key"},
|
|
{Key: &rsaKey.PublicKey, KID: "rsa-key"},
|
|
{Key: edPub, KID: "ed-key"},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Sign a token with each key type and verify the multi-key verifier picks the right one.
|
|
for _, tt := range []struct {
|
|
name string
|
|
pk *jwt.PrivateKey
|
|
alg string
|
|
}{
|
|
{"EC/ES256", ecPK, "ES256"},
|
|
{"RSA/RS256", rsaPK, "RS256"},
|
|
{"Ed25519/EdDSA", edPK, "EdDSA"},
|
|
} {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{tt.pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
claims := goodClaims()
|
|
tokenStr, err := signer.SignToString(&claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
parsed, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatalf("Decode: %v", err)
|
|
}
|
|
if parsed.GetHeader().Alg != tt.alg {
|
|
t.Fatalf("alg: got %s, want %s", parsed.GetHeader().Alg, tt.alg)
|
|
}
|
|
if err := verifier.Verify(parsed); err != nil {
|
|
t.Fatalf("Verify failed for %s: %v", tt.alg, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAudienceSingleString verifies that a single-string "aud" claim
|
|
// (RFC 7519 §4.1.3) is correctly unmarshaled as a single-element Audience.
|
|
func TestAudienceSingleString(t *testing.T) {
|
|
pk, err := jwt.NewPrivateKey()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"single-aud"}, // single element
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
}
|
|
|
|
tokenStr, err := signer.SignToString(claims)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Verify the wire format uses a string (not an array) for single-element aud.
|
|
decoded, err := jwt.Decode(tokenStr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var rawPayload map[string]json.RawMessage
|
|
payloadBytes, _ := base64.RawURLEncoding.DecodeString(string(decoded.GetPayload()))
|
|
if err := json.Unmarshal(payloadBytes, &rawPayload); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
audRaw := string(rawPayload["aud"])
|
|
if audRaw[0] == '[' {
|
|
t.Errorf("single-element aud should be a string, got array: %s", audRaw)
|
|
}
|
|
|
|
// Unmarshal and validate.
|
|
var got jwt.TokenClaims
|
|
if err := decoded.UnmarshalClaims(&got); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(got.Aud) != 1 || got.Aud[0] != "single-aud" {
|
|
t.Errorf("aud: got %v, want [single-aud]", got.Aud)
|
|
}
|
|
|
|
// Also test unmarshaling from a manually constructed single-string aud.
|
|
singleJSON := []byte(`"just-one"`)
|
|
var aud jwt.Listish
|
|
if err := json.Unmarshal(singleJSON, &aud); err != nil {
|
|
t.Fatalf("Unmarshal single-string aud: %v", err)
|
|
}
|
|
if len(aud) != 1 || aud[0] != "just-one" {
|
|
t.Errorf("single-string unmarshal: got %v, want [just-one]", aud)
|
|
}
|
|
|
|
// And array form.
|
|
arrayJSON := []byte(`["a","b"]`)
|
|
var aud2 jwt.Listish
|
|
if err := json.Unmarshal(arrayJSON, &aud2); err != nil {
|
|
t.Fatalf("Unmarshal array aud: %v", err)
|
|
}
|
|
if len(aud2) != 2 || aud2[0] != "a" || aud2[1] != "b" {
|
|
t.Errorf("array unmarshal: got %v, want [a b]", aud2)
|
|
}
|
|
}
|
|
|
|
// TestVerifyTamperedPayload confirms that a tampered payload (modified after signing)
|
|
// is rejected by signature verification.
|
|
func TestVerifyTamperedPayload(t *testing.T) {
|
|
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")})
|
|
|
|
claims := goodClaims()
|
|
token, _ := signer.SignToString(&claims)
|
|
|
|
verifier := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"})
|
|
|
|
// Tamper with the payload: change the sub claim.
|
|
parts := strings.SplitN(token, ".", 3)
|
|
payloadBytes, _ := base64.RawURLEncoding.DecodeString(parts[1])
|
|
tampered := strings.Replace(string(payloadBytes), claims.Sub, "evil-user", 1)
|
|
parts[1] = base64.RawURLEncoding.EncodeToString([]byte(tampered))
|
|
tamperedToken := strings.Join(parts, ".")
|
|
|
|
parsed, err := jwt.Decode(tamperedToken)
|
|
if err != nil {
|
|
t.Fatalf("Decode should succeed for well-formed tampered token: %v", err)
|
|
}
|
|
if err := verifier.Verify(parsed); err == nil {
|
|
t.Fatal("expected Verify to fail for tampered payload")
|
|
} else if !errors.Is(err, jwt.ErrSignatureInvalid) {
|
|
t.Fatalf("expected ErrSignatureInvalid, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestValidationErrorAnnotation verifies that time-related validation errors
|
|
// include the server time annotation.
|
|
func TestValidationErrorAnnotation(t *testing.T) {
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: now.Add(-time.Hour).Unix(), // expired
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
}
|
|
|
|
v := jwt.NewIDTokenValidator(
|
|
[]string{"https://example.com"},
|
|
[]string{"app"},
|
|
nil,
|
|
0,
|
|
)
|
|
err := v.Validate(nil, claims, now)
|
|
if err == nil {
|
|
t.Fatal("expected validation error")
|
|
}
|
|
if !strings.Contains(err.Error(), "server time") {
|
|
t.Errorf("time error should include server time annotation: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestValidateThreadsHeaderErrors verifies that errors from header validation
|
|
// (IsAllowedTyp) are preserved when threaded into Validate.
|
|
func TestValidateThreadsHeaderErrors(t *testing.T) {
|
|
now := time.Now()
|
|
claims := &jwt.TokenClaims{
|
|
Iss: "https://example.com",
|
|
Sub: "user-1",
|
|
Aud: jwt.Listish{"app"},
|
|
Exp: now.Add(time.Hour).Unix(),
|
|
IAt: now.Unix(),
|
|
AuthTime: now.Unix(),
|
|
}
|
|
|
|
v := jwt.NewIDTokenValidator(
|
|
[]string{"https://example.com"},
|
|
[]string{"app"},
|
|
nil,
|
|
0,
|
|
)
|
|
|
|
// Simulate a header check failure.
|
|
hdr := jwt.RFCHeader{Typ: "at+jwt"} // wrong typ for ID token
|
|
var errs []error
|
|
errs = hdr.IsAllowedTyp(errs, []string{"JWT"})
|
|
if len(errs) == 0 {
|
|
t.Fatal("expected IsAllowedTyp to produce an error")
|
|
}
|
|
|
|
// Thread header errors into Validate - claims are valid, so only header error remains.
|
|
err := v.Validate(errs, claims, now)
|
|
if err == nil {
|
|
t.Fatal("expected error from threaded header validation")
|
|
}
|
|
if !strings.Contains(err.Error(), "typ") {
|
|
t.Errorf("expected typ error in output: %v", err)
|
|
}
|
|
}
|