diff --git a/auth/jwt/edge_test.go b/auth/jwt/edge_test.go new file mode 100644 index 0000000..7565c54 --- /dev/null +++ b/auth/jwt/edge_test.go @@ -0,0 +1,714 @@ +package jwt_test + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "math/big" + "testing" + + "github.com/therootcompany/golib/auth/jwt" +) + +// --- JWK parsing edge cases --- + +func TestJWKMissingKty(t *testing.T) { + data := []byte(`{"kid":"test"}`) + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for missing kty") + } + if !errors.Is(err, jwt.ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got: %v", err) + } +} + +func TestJWKUnknownKty(t *testing.T) { + data := []byte(`{"kty":"MAGIC","kid":"test"}`) + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for unknown kty") + } + if !errors.Is(err, jwt.ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got: %v", err) + } +} + +func TestPrivateKeyMissingD(t *testing.T) { + // Valid EC public key but no "d" field + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + pub := jwt.PublicKey{Key: &ecKey.PublicKey, KID: "test"} + pubJSON, err := json.Marshal(pub) + if err != nil { + t.Fatal(err) + } + + var pk jwt.PrivateKey + err = pk.UnmarshalJSON(pubJSON) + if err == nil { + t.Fatal("expected error for missing d field") + } + if !errors.Is(err, jwt.ErrMissingKeyData) { + t.Fatalf("expected ErrMissingKeyData, got: %v", err) + } +} + +func TestRSAKeyTooSmall(t *testing.T) { + tests := []struct { + name string + nLen int // modulus byte length + }{ + {"all_zeros_1024bit", 128}, // 1024 bits of zeros - Size() returns 0 + {"all_zeros_64byte", 64}, // 512 bits of zeros + {"all_zeros_1byte", 1}, // 8 bits of zeros + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := make([]byte, tt.nLen) // all zeros + data, _ := json.Marshal(map[string]string{ + "kty": "RSA", + "kid": "small", + "n": base64.RawURLEncoding.EncodeToString(n), + "e": "AQAB", + }) + + var decoded jwt.PublicKey + err := decoded.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for all-zeros RSA modulus") + } + if !errors.Is(err, jwt.ErrKeyTooSmall) { + t.Fatalf("expected ErrKeyTooSmall, got: %v", err) + } + }) + } +} + +func TestRSADegenerateExponent(t *testing.T) { + tests := []struct { + name string + e int + }{ + {"exponent_0", 0}, + {"exponent_1", 1}, + {"exponent_2", 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build a JWK with a degenerate exponent + n := make([]byte, 256) // 2048-bit modulus + n[0] = 1 // non-zero MSB + eBytes := big.NewInt(int64(tt.e)).Bytes() + data, _ := json.Marshal(map[string]string{ + "kty": "RSA", + "kid": "bad-e", + "n": base64.RawURLEncoding.EncodeToString(n), + "e": base64.RawURLEncoding.EncodeToString(eBytes), + }) + + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for degenerate RSA exponent") + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +func TestRSAEmptyFields(t *testing.T) { + tests := []struct { + name string + jwk map[string]string + }{ + {"empty_n", map[string]string{"kty": "RSA", "n": "", "e": "AQAB"}}, + {"empty_e", map[string]string{"kty": "RSA", "n": "AQAB", "e": ""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.jwk) + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for empty RSA field") + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +func TestEd25519WrongKeySize(t *testing.T) { + tests := []struct { + name string + size int + }{ + {"too_short_31", 31}, + {"too_long_33", 33}, + {"zero_length", 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + x := make([]byte, tt.size) + data, _ := json.Marshal(map[string]string{ + "kty": "OKP", + "crv": "Ed25519", + "x": base64.RawURLEncoding.EncodeToString(x), + }) + + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatalf("expected error for Ed25519 key size %d", tt.size) + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +func TestEd25519AllZerosKey(t *testing.T) { + // All-zeros is a valid encoding but represents a low-order point. + // The key should parse (it's 32 bytes), and signing should work + // but verification with the wrong key should fail. + x := make([]byte, ed25519.PublicKeySize) // all zeros + data, _ := json.Marshal(map[string]string{ + "kty": "OKP", + "crv": "Ed25519", + "kid": "zero-key", + "x": base64.RawURLEncoding.EncodeToString(x), + }) + + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err != nil { + t.Fatalf("all-zeros Ed25519 key should parse: %v", err) + } + + // Verify with this key should reject any real signature + realKey, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{realKey}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Parse and try to verify with the all-zeros key + jws, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatal(err) + } + + // Change the kid in the header to match our zero key + zeroVerifier, _ := jwt.NewVerifier([]jwt.PublicKey{pk}) + // The KID won't match, but let's verify that the system handles it + err = zeroVerifier.Verify(jws) + if err == nil { + t.Fatal("expected verification to fail with wrong key") + } +} + +func TestOKPWrongCrv(t *testing.T) { + data, _ := json.Marshal(map[string]string{ + "kty": "OKP", + "crv": "X25519", + "x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + }) + + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for X25519 crv") + } + if !errors.Is(err, jwt.ErrUnsupportedCurve) { + t.Fatalf("expected ErrUnsupportedCurve, got: %v", err) + } +} + +func TestOKPPrivateWrongCrv(t *testing.T) { + data, _ := json.Marshal(map[string]string{ + "kty": "OKP", + "crv": "X25519", + "x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + "d": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + }) + + var pk jwt.PrivateKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for X25519 crv on private key") + } + if !errors.Is(err, jwt.ErrUnsupportedCurve) { + t.Fatalf("expected ErrUnsupportedCurve, got: %v", err) + } +} + +func TestECCoordinatesTooLong(t *testing.T) { + ci := struct { + keySize int + crv string + }{32, "P-256"} // P-256 has 32-byte coordinates + + tests := []struct { + name string + xSize int + ySize int + }{ + {"x_too_long", ci.keySize + 1, ci.keySize}, + {"y_too_long", ci.keySize, ci.keySize + 1}, + {"both_too_long", ci.keySize + 1, ci.keySize + 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(map[string]string{ + "kty": "EC", + "crv": ci.crv, + "x": base64.RawURLEncoding.EncodeToString(make([]byte, tt.xSize)), + "y": base64.RawURLEncoding.EncodeToString(make([]byte, tt.ySize)), + }) + + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for oversized EC coordinates") + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +func TestECUnsupportedCurve(t *testing.T) { + data, _ := json.Marshal(map[string]string{ + "kty": "EC", + "crv": "P-192", + "x": base64.RawURLEncoding.EncodeToString(make([]byte, 24)), + "y": base64.RawURLEncoding.EncodeToString(make([]byte, 24)), + }) + + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for P-192 curve") + } + if !errors.Is(err, jwt.ErrUnsupportedCurve) { + t.Fatalf("expected ErrUnsupportedCurve, got: %v", err) + } +} + +func TestEd25519PrivateWrongSeedSize(t *testing.T) { + tests := []struct { + name string + size int + }{ + {"seed_too_short", 31}, + {"seed_too_long", 33}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(map[string]string{ + "kty": "OKP", + "crv": "Ed25519", + "x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + "d": base64.RawURLEncoding.EncodeToString(make([]byte, tt.size)), + }) + + var pk jwt.PrivateKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatalf("expected error for seed size %d", tt.size) + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +func TestInvalidBase64Fields(t *testing.T) { + tests := []struct { + name string + jwk map[string]string + }{ + {"invalid_rsa_n", map[string]string{"kty": "RSA", "n": "!!!invalid!!!", "e": "AQAB"}}, + {"invalid_rsa_e", map[string]string{"kty": "RSA", "n": "AQAB", "e": "!!!"}}, + {"invalid_ec_x", map[string]string{"kty": "EC", "crv": "P-256", "x": "!!!", "y": "AAAA"}}, + {"invalid_ec_y", map[string]string{"kty": "EC", "crv": "P-256", "x": "AAAA", "y": "!!!"}}, + {"invalid_okp_x", map[string]string{"kty": "OKP", "crv": "Ed25519", "x": "!!!"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.jwk) + var pk jwt.PublicKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for invalid base64") + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +func TestInvalidBase64PrivateFields(t *testing.T) { + tests := []struct { + name string + jwk map[string]string + }{ + {"invalid_ec_d", map[string]string{ + "kty": "EC", "crv": "P-256", + "x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + "y": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + "d": "!!!invalid!!!", + }}, + {"invalid_rsa_d", map[string]string{ + "kty": "RSA", + "n": base64.RawURLEncoding.EncodeToString(make([]byte, 256)), + "e": "AQAB", + "d": "!!!invalid!!!", + }}, + {"invalid_okp_d", map[string]string{ + "kty": "OKP", "crv": "Ed25519", + "x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)), + "d": "!!!invalid!!!", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.jwk) + var pk jwt.PrivateKey + err := pk.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for invalid base64 in private field") + } + if !errors.Is(err, jwt.ErrInvalidKey) { + t.Fatalf("expected ErrInvalidKey, got: %v", err) + } + }) + } +} + +// --- Signature verification edge cases --- + +func TestVerifyWrongKeyTypeForAlg(t *testing.T) { + // Sign with Ed25519, then try to verify with an RSA key + // that has the same KID + edKey, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{edKey}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Create an RSA key with the same KID + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + rsaPub := jwt.PublicKey{ + Key: &rsaKey.PublicKey, + KID: edKey.KID, // same KID + } + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{rsaPub}) + jws, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatal(err) + } + err = verifier.Verify(jws) + if err == nil { + t.Fatal("expected error: wrong key type for EdDSA alg") + } + if !errors.Is(err, jwt.ErrAlgConflict) { + t.Fatalf("expected ErrAlgConflict, got: %v", err) + } +} + +func TestVerifyZeroLengthSignature(t *testing.T) { + // Create a valid token then replace the signature with empty + key, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{key}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Replace signature with empty + parts := splitToken(tokenStr) + tampered := parts[0] + "." + parts[1] + "." + jws, err := jwt.Decode(tampered) + if err != nil { + t.Fatal(err) + } + + verifier := signer.Verifier() + err = verifier.Verify(jws) + if err == nil { + t.Fatal("expected error for zero-length signature") + } + if !errors.Is(err, jwt.ErrSignatureInvalid) { + t.Fatalf("expected ErrSignatureInvalid, got: %v", err) + } +} + +func TestVerifyECDSAWrongSigLength(t *testing.T) { + // Sign with P-256, verify with correct key but tampered sig length + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + pk, err := jwt.FromPrivateKey(ecKey, "") + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Replace signature with wrong-length bytes + parts := splitToken(tokenStr) + wrongSig := base64.RawURLEncoding.EncodeToString([]byte("short")) + tampered := parts[0] + "." + parts[1] + "." + wrongSig + + jws, err := jwt.Decode(tampered) + if err != nil { + t.Fatal(err) + } + + verifier := signer.Verifier() + err = verifier.Verify(jws) + if err == nil { + t.Fatal("expected error for wrong ECDSA signature length") + } + if !errors.Is(err, jwt.ErrSignatureInvalid) { + t.Fatalf("expected ErrSignatureInvalid, got: %v", err) + } +} + +func TestVerifyUnsupportedAlg(t *testing.T) { + // Build a token with an unsupported alg header + key, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{key}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Tamper the header to use an unsupported alg + header := map[string]string{"alg": "HS256", "kid": key.KID, "typ": "JWT"} + headerJSON, _ := json.Marshal(header) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + + parts := splitToken(tokenStr) + tampered := headerB64 + "." + parts[1] + "." + parts[2] + + jws, err := jwt.Decode(tampered) + if err != nil { + t.Fatal(err) + } + + verifier := signer.Verifier() + err = verifier.Verify(jws) + if err == nil { + t.Fatal("expected error for unsupported alg") + } + if !errors.Is(err, jwt.ErrUnsupportedAlg) { + t.Fatalf("expected ErrUnsupportedAlg, got: %v", err) + } +} + +// TestVerifyMissingKID verifies that tokens without a KID header try all keys +// via fallthrough. A tampered header (different signing input) still fails +// with ErrSignatureInvalid, but the lookup itself does not reject the token. +func TestVerifyMissingKID(t *testing.T) { + key, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{key}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Tamper header to remove kid - signing input changes, so sig will be invalid. + header := map[string]string{"alg": "EdDSA", "typ": "JWT"} + headerJSON, _ := json.Marshal(header) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + + parts := splitToken(tokenStr) + tampered := headerB64 + "." + parts[1] + "." + parts[2] + + jws, err := jwt.Decode(tampered) + if err != nil { + t.Fatal(err) + } + + verifier := signer.Verifier() + err = verifier.Verify(jws) + if err == nil { + t.Fatal("expected error for tampered header") + } + // With no KID, all keys are tried - fails with signature invalid, not missing KID. + if !errors.Is(err, jwt.ErrSignatureInvalid) { + t.Fatalf("expected ErrSignatureInvalid, got: %v", err) + } +} + +func TestDecodeEmptyToken(t *testing.T) { + _, err := jwt.Decode("") + if err == nil { + t.Fatal("expected error for empty token") + } + if !errors.Is(err, jwt.ErrMalformedToken) { + t.Fatalf("expected ErrMalformedToken, got: %v", err) + } +} + +func TestDecodeOnePart(t *testing.T) { + _, err := jwt.Decode("justonepart") + if err == nil { + t.Fatal("expected error for single-part token") + } + if !errors.Is(err, jwt.ErrMalformedToken) { + t.Fatalf("expected ErrMalformedToken, got: %v", err) + } +} + +func TestDecodeFourParts(t *testing.T) { + _, err := jwt.Decode("a.b.c.d") + if err == nil { + t.Fatal("expected error for four-part token") + } + if !errors.Is(err, jwt.ErrMalformedToken) { + t.Fatalf("expected ErrMalformedToken, got: %v", err) + } +} + +// --- RSA private key edge cases --- + +func TestRSAPrivateKeyTooSmall(t *testing.T) { + tests := []struct { + name string + nLen int + }{ + {"all_zeros_1024bit", 128}, + {"all_zeros_64byte", 64}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n := make([]byte, tt.nLen) // all zeros + d := make([]byte, tt.nLen) // all zeros + data, _ := json.Marshal(map[string]string{ + "kty": "RSA", + "kid": "small", + "n": base64.RawURLEncoding.EncodeToString(n), + "e": "AQAB", + "d": base64.RawURLEncoding.EncodeToString(d), + }) + + var decoded jwt.PrivateKey + err := decoded.UnmarshalJSON(data) + if err == nil { + t.Fatal("expected error for all-zeros RSA private key") + } + if !errors.Is(err, jwt.ErrKeyTooSmall) { + t.Fatalf("expected ErrKeyTooSmall, got: %v", err) + } + }) + } +} + +// --- Thumbprint edge cases --- + +func TestThumbprintNilKey(t *testing.T) { + pk := jwt.PublicKey{} // nil Key + _, err := pk.Thumbprint() + if err == nil { + t.Fatal("expected error for nil key thumbprint") + } + if !errors.Is(err, jwt.ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got: %v", err) + } +} + +// splitToken splits a compact JWT into its three dot-separated parts. +func splitToken(s string) [3]string { + var parts [3]string + idx1 := 0 + for i, c := range s { + if c == '.' { + if idx1 == 0 { + parts[0] = s[:i] + idx1 = i + 1 + } else { + parts[1] = s[idx1:i] + parts[2] = s[i+1:] + return parts + } + } + } + return parts +} diff --git a/auth/jwt/jwt_test.go b/auth/jwt/jwt_test.go new file mode 100644 index 0000000..199a8cf --- /dev/null +++ b/auth/jwt/jwt_test.go @@ -0,0 +1,2459 @@ +// Copyright 2026 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +package jwt_test + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "strings" + "testing" + "time" + + "crypto" + + "github.com/therootcompany/golib/auth/jwt" +) + +func mustPK(t testing.TB, signer crypto.Signer, kid string) *jwt.PrivateKey { + t.Helper() + pk, err := jwt.FromPrivateKey(signer, kid) + if err != nil { + t.Fatal(err) + } + return pk +} + +// AppClaims embeds TokenClaims and adds application-specific fields. +// +// Because TokenClaims is embedded, AppClaims satisfies Claims +// for free via Go's method promotion - no interface to implement. +type AppClaims struct { + jwt.TokenClaims + Email string `json:"email"` + Roles []string `json:"roles"` +} + +// validateAppClaims is a plain function - not a method satisfying an interface. +// It demonstrates the Decode+Verify pattern: custom validation logic lives here, +// calling Validator.Validate and adding app-specific checks. +func validateAppClaims(c AppClaims, v *jwt.Validator, now time.Time) error { + var errs []error + if err := v.Validate(nil, &c, now); err != nil { + errs = append(errs, err) + } + if c.Email == "" { + errs = append(errs, errors.New("missing email claim")) + } + return errors.Join(errs...) +} + +func goodClaims() AppClaims { + now := time.Now() + return AppClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user123", + Aud: jwt.Listish{"myapp"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + AMR: []string{"pwd"}, + JTI: "abc123", + AzP: "myapp", + Nonce: "nonce1", + }, + Email: "user@example.com", + Roles: []string{"admin"}, + } +} + +// goodValidator configures the ID token validator with iss set to "https://example.com". +// Iss checking is now the Validator's responsibility, not the Verifier's. +func goodValidator() *jwt.Validator { + return jwt.NewIDTokenValidator( + []string{"https://example.com"}, + []string{"myapp"}, + []string{"myapp"}, + 0, + ) +} + +func goodVerifier(pub jwt.PublicKey) *jwt.Verifier { + v, err := jwt.NewVerifier([]jwt.PublicKey{pub}) + if err != nil { + panic(err) + } + return v +} + +// TestRoundTrip is the primary happy path using ES256. +// It demonstrates the full Verify / UnmarshalClaims / Validate flow. +func TestRoundTrip(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + if jws.GetHeader().Alg != "ES256" { + t.Fatalf("expected ES256, got %s", jws.GetHeader().Alg) + } + + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) + + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + if jws2.GetHeader().Alg != "ES256" { + t.Errorf("expected ES256 alg in jws, got %s", jws2.GetHeader().Alg) + } + + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("claim validation failed: %v", err) + } + // Direct field access - no type assertion needed. + if decoded.Email != claims.Email { + t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) + } +} + +// TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256. +func TestRoundTripRS256(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + if jws.GetHeader().Alg != "RS256" { + t.Fatalf("expected RS256, got %s", jws.GetHeader().Alg) + } + + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) + + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("claim validation failed: %v", err) + } +} + +// TestRoundTripEdDSA exercises Ed25519 / EdDSA (RFC 8037). +func TestRoundTripEdDSA(t *testing.T) { + pubKeyBytes, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + if jws.GetHeader().Alg != "EdDSA" { + t.Fatalf("expected EdDSA, got %s", jws.GetHeader().Alg) + } + + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + + iss := goodVerifier(jwt.PublicKey{Key: pubKeyBytes, KID: "key-1"}) + + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("claim validation failed: %v", err) + } +} + +// TestRoundTripES384 exercises ECDSA P-384 / ES384. +func TestRoundTripES384(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + if jws.GetHeader().Alg != "ES384" { + t.Fatalf("expected ES384, got %s", jws.GetHeader().Alg) + } + + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) + + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("claim validation failed: %v", err) + } +} + +// TestRoundTripES512 exercises ECDSA P-521 / ES512. +func TestRoundTripES512(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + jws, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + if jws.GetHeader().Alg != "ES512" { + t.Fatalf("expected ES512, got %s", jws.GetHeader().Alg) + } + + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) + + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("claim validation failed: %v", err) + } +} + +// TestDecodeVerifyFlow demonstrates the Decode + Verify + custom validation pattern. +// The caller owns the full validation pipeline. +func TestDecodeVerifyFlow(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) + + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &privKey.PublicKey, KID: "k"}}) + + jws2, err := jwt.Decode(token) + if err != nil { + t.Fatalf("Decode failed: %v", err) + } + if err := iss.Verify(jws2); err != nil { + t.Fatalf("Verify failed: %v", err) + } + + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("Validate failed: %v", err) + } +} + +// TestDecodeReturnsParsedOnSigFailure verifies that Decode returns a non-nil +// *StandardJWS even when the token will later fail signature verification. +// Callers can inspect the header (kid, alg) for routing before calling Verify. +func TestDecodeReturnsParsedOnSigFailure(t *testing.T) { + signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, signingKey, "k")}) + + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + // Verifier has wrong public key - sig verification will fail. + iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &wrongKey.PublicKey, KID: "k"}}) + + // Decode always succeeds for well-formed tokens. + result, err := jwt.Decode(token) + if err != nil { + t.Fatalf("Decode failed: %v", err) + } + if result == nil { + t.Fatal("Decode should return non-nil StandardJWS") + } + if result.GetHeader().KID != "k" { + t.Errorf("expected kid %q, got %q", "k", result.GetHeader().KID) + } + + // Verify should fail with the wrong key. + if err := iss.Verify(result); err == nil { + t.Fatal("expected Verify to fail with wrong key") + } +} + +// TestCustomValidation demonstrates that Validator.Validate is called +// explicitly and custom fields are validated alongside the standard ones. +func TestCustomValidation(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Token with empty Email - our custom validator should reject it. + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) + claims := goodClaims() + claims.Email = "" + token, _ := signer.SignToString(&claims) + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) + jws2, err := jwt.Decode(token) + if err != nil { + t.Fatalf("Decode failed: %v", err) + } + if err := iss.Verify(jws2); err != nil { + t.Fatalf("Verify failed unexpectedly: %v", err) + } + + var decoded AppClaims + _ = jws2.UnmarshalClaims(&decoded) + + err = validateAppClaims(decoded, goodValidator(), time.Now()) + if err == nil { + t.Fatal("expected validation to fail: email is empty") + } + if !strings.Contains(err.Error(), "missing email claim") { + t.Fatalf("expected 'missing email claim' in error: %v", err) + } +} + +// TestNBFValidation confirms that a token with nbf in the future is rejected, +// and that a token with nbf in the past (or absent) is accepted. +func TestNBFValidation(t *testing.T) { + now := time.Now() + + base := AppClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://example.com", + Aud: jwt.Listish{"myapp"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + }, + } + + v := &jwt.Validator{ + Checks: jwt.CheckIss | jwt.CheckAud | jwt.CheckExp | jwt.CheckIAt | jwt.CheckNBf, + Iss: []string{"https://example.com"}, + Aud: []string{"myapp"}, + } + + // No nbf: should pass. + if err := v.Validate(nil, &base, now); err != nil { + t.Fatalf("expected no error without nbf: %v", err) + } + + // nbf in the past: should pass. + pastNBF := base + pastNBF.NBf = now.Add(-time.Hour).Unix() + if err := v.Validate(nil, &pastNBF, now); err != nil { + t.Fatalf("expected no error with past nbf: %v", err) + } + + // nbf in the future: must be rejected. + futureNBF := base + futureNBF.NBf = now.Add(time.Hour).Unix() + err := v.Validate(nil, &futureNBF, now) + if err == nil { + t.Fatal("expected error for future nbf") + } + if !errors.Is(err, jwt.ErrBeforeNBf) { + t.Fatalf("expected ErrBeforeNBf, got: %v", err) + } +} + +// TestVerifyWithoutValidation confirms that Verify + UnmarshalClaims succeeds +// independently of claim validation - the caller decides whether to validate. +func TestVerifyWithoutValidation(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) + c := goodClaims() + token, _ := signer.SignToString(&c) + + iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &privKey.PublicKey, KID: "k"}}) + + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + var claims AppClaims + if err := jws2.UnmarshalClaims(&claims); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if claims.Email != c.Email { + t.Errorf("claims not unmarshalled: email got %q, want %q", claims.Email, c.Email) + } +} + +// TestVerifierWrongKey confirms that a different key's public key is rejected. +func TestVerifierWrongKey(t *testing.T) { + signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, signingKey, "k")}) + + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + iss := goodVerifier(jwt.PublicKey{Key: &wrongKey.PublicKey, KID: "k"}) + + parsed, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrSignatureInvalid) { + t.Fatalf("expected ErrSignatureInvalid, got: %v", err) + } +} + +// TestVerifierUnknownKid confirms that an unknown kid is rejected. +func TestVerifierUnknownKid(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "unknown-kid")}) + + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "known-kid"}) + + parsed, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrUnknownKID) { + t.Fatalf("expected ErrUnknownKID, got: %v", err) + } +} + +// TestVerifierIssMismatch confirms that a token with a mismatched iss is caught +// by the Validator, not the Verifier. Signature verification succeeds; the iss +// mismatch appears as a soft validation error. +func TestVerifierIssMismatch(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) + + claims := goodClaims() + claims.Iss = "https://evil.example.com" + token, _ := signer.SignToString(&claims) + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) + + // Decode+Verify succeeds - iss is not checked at the Verifier level. + parsed, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err := iss.Verify(parsed); err != nil { + t.Fatalf("Verify should succeed (no iss check): %v", err) + } + + // VerifyJWT + Validate: signature passes but iss validation catches the mismatch. + jws2, err := iss.VerifyJWT(token) + if err != nil { + t.Fatalf("unexpected hard error from VerifyJWT: %v", err) + } + var decoded AppClaims + if err := jws2.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + err = goodValidator().Validate(nil, &decoded, time.Now()) + if err == nil { + t.Fatal("expected validation errors for iss mismatch") + } + if !errors.Is(err, jwt.ErrInvalidClaim) { + t.Fatalf("expected ErrInvalidClaim for iss mismatch, got: %v", err) + } +} + +// TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected. +// The token is reconstructed with a replaced protected header; the original +// ES256 signature is kept, making the signing input mismatch detectable. +func TestVerifyTamperedAlg(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) + + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) + + // Replace the protected header with one that has alg:"none". + // The original ES256 signature stays - the signing input will mismatch. + noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`)) + parts := strings.SplitN(token, ".", 3) + tamperedToken := noneHeader + "." + parts[1] + "." + parts[2] + + parsed, err := jwt.Decode(tamperedToken) + if err != nil { + t.Fatal(err) + } + if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrUnsupportedAlg) { + t.Fatalf("expected ErrUnsupportedAlg for tampered alg, got: %v", err) + } +} + +// TestSignerRoundTrip verifies the Signer / Sign / Verifier / Verify / Validate flow. +func TestSignerRoundTrip(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")}) + if err != nil { + t.Fatal(err) + } + + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + iss := signer.Verifier() + jws, err := iss.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("VerifyJWT failed: %v", err) + } + var decoded AppClaims + if err := jws.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims failed: %v", err) + } + if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("claim validation failed: %v", err) + } + if decoded.Email != claims.Email { + t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) + } +} + +// TestSignerAutoKID verifies that KID is auto-computed from the key thumbprint +// when PrivateKey.KID is empty. +func TestSignerAutoKID(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "")}) + if err != nil { + t.Fatal(err) + } + + keys := signer.Keys + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KID == "" { + t.Fatal("KID should be auto-computed from thumbprint") + } + + // Token should verify with the auto-KID issuer. + iss := signer.Verifier() + claims := goodClaims() + tokenStr, _ := signer.SignToString(&claims) + + parsed, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatal(err) + } + if err := iss.Verify(parsed); err != nil { + t.Fatalf("Verify failed: %v", err) + } +} + +// TestSignerRoundRobin verifies that signing round-robins across keys and that +// all resulting tokens verify with the combined Verifier. +func TestSignerRoundRobin(t *testing.T) { + key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + key2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ + mustPK(t, key1, "k1"), + mustPK(t, key2, "k2"), + }) + if err != nil { + t.Fatal(err) + } + + iss := signer.Verifier() + v := goodValidator() + + for i := range 4 { + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatalf("Sign[%d] failed: %v", i, err) + } + jws, err := iss.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("VerifyJWT[%d] failed: %v", i, err) + } + var decoded AppClaims + if err := jws.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims[%d] failed: %v", i, err) + } + if err := v.Validate(nil, &decoded, time.Now()); err != nil { + t.Fatalf("Validate[%d] failed: %v", i, err) + } + } +} + +// TestSignJWTSelectsByKID verifies that when the header already has a KID, +// SignJWT uses that specific key instead of round-robin. +func TestSignJWTSelectsByKID(t *testing.T) { + key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + key2, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ + mustPK(t, key1, "ec256"), + mustPK(t, key2, "ec384"), + }) + if err != nil { + t.Fatal(err) + } + + // Request signing with key2 specifically by setting KID in the header. + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user-1", + Aud: jwt.Listish{"app"}, + Exp: time.Now().Add(time.Hour).Unix(), + IAt: time.Now().Unix(), + } + jws, err := jwt.New(claims) + if err != nil { + t.Fatal(err) + } + jws.SetTyp("JWT") + // Pre-set the KID to select key2. + hdr := jws.GetHeader() + hdr.KID = "ec384" + if err := jws.SetHeader(&hdr); err != nil { + t.Fatal(err) + } + + if err := signer.SignJWT(jws); err != nil { + t.Fatal(err) + } + + // Should have used ES384, not ES256. + if got := jws.GetHeader().Alg; got != "ES384" { + t.Fatalf("alg: got %s, want ES384 (should have selected key2)", got) + } + if got := jws.GetHeader().KID; got != "ec384" { + t.Fatalf("kid: got %s, want ec384", got) + } + + // Verify round-trip. + verifier := signer.Verifier() + tokenStr, err := jwt.Encode(jws) + if err != nil { + t.Fatal(err) + } + parsed, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatal(err) + } + if err := verifier.Verify(parsed); err != nil { + t.Fatalf("Verify: %v", err) + } +} + +// TestSignJWTUnknownKID verifies that SignJWT returns ErrUnknownKID when the +// header requests a KID that the signer doesn't have. +func TestSignJWTUnknownKID(t *testing.T) { + key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, key1, "k1")}) + + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user-1", + Aud: jwt.Listish{"app"}, + Exp: time.Now().Add(time.Hour).Unix(), + IAt: time.Now().Unix(), + } + jws, _ := jwt.New(claims) + hdr := jws.GetHeader() + hdr.KID = "nonexistent" + _ = jws.SetHeader(&hdr) + + err := signer.SignJWT(jws) + if !errors.Is(err, jwt.ErrUnknownKID) { + t.Fatalf("expected ErrUnknownKID, got: %v", err) + } +} + +// TestJWKsRoundTrip verifies JWKS serialization and round-trip parsing. +func TestJWKsRoundTrip(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")}) + if err != nil { + t.Fatal(err) + } + + jwksBytes, err := json.Marshal(signer) + if err != nil { + t.Fatal(err) + } + + // Round-trip: parse the JWKS JSON and verify it produces a working Verifier. + var jwks jwt.WellKnownJWKs + if err := json.Unmarshal(jwksBytes, &jwks); err != nil { + t.Fatal(err) + } + keys := jwks.Keys + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KID != "k1" { + t.Errorf("expected kid 'k1', got %q", keys[0].KID) + } + + iss2, _ := jwt.NewVerifier(keys) + claims := goodClaims() + tokenStr, _ := signer.SignToString(&claims) + + parsed, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatal(err) + } + if err := iss2.Verify(parsed); err != nil { + t.Fatalf("Verify on round-tripped JWKS failed: %v", err) + } +} + +// TestKeyType verifies that KeyType returns the correct JWK kty string for each key type. +func TestKeyType(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + edPub, _, _ := ed25519.GenerateKey(rand.Reader) + + tests := []struct { + name string + key jwt.PublicKey + wantKty string + }{ + {"EC P-256", jwt.PublicKey{Key: &ecKey.PublicKey}, "EC"}, + {"RSA 2048", jwt.PublicKey{Key: &rsaKey.PublicKey}, "RSA"}, + {"Ed25519", jwt.PublicKey{Key: edPub}, "OKP"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.key.KeyType(); got != tt.wantKty { + t.Errorf("KeyType() = %q, want %q", got, tt.wantKty) + } + }) + } +} + +// TestPublicKeyOps verifies that PrivateKey.PublicKey() translates key_ops to their +// public-key counterparts ("sign"=>"verify", "decrypt"=>"encrypt", "unwrapKey"=>"wrapKey"). +func TestPublicKeyOps(t *testing.T) { + tests := []struct { + name string + privateOps []string + wantOps []string + }{ + {"sign=>verify", []string{"sign"}, []string{"verify"}}, + {"decrypt=>encrypt", []string{"decrypt"}, []string{"encrypt"}}, + {"unwrapKey=>wrapKey", []string{"unwrapKey"}, []string{"wrapKey"}}, + {"multiple", []string{"sign", "decrypt"}, []string{"verify", "encrypt"}}, + {"public op passthrough", []string{"verify"}, []string{"verify"}}, + {"no public equivalent dropped", []string{"deriveKey"}, nil}, + {"empty", nil, nil}, + } + base, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk := *base + pk.KeyOps = tt.privateOps + pub, err := pk.PublicKey() + if err != nil { + t.Fatal(err) + } + if len(pub.KeyOps) != len(tt.wantOps) { + t.Fatalf("KeyOps = %v, want %v", pub.KeyOps, tt.wantOps) + } + for i, op := range pub.KeyOps { + if op != tt.wantOps[i] { + t.Errorf("KeyOps[%d] = %q, want %q", i, op, tt.wantOps[i]) + } + } + }) + } +} + +// TestDecodePublicJWKJSON verifies JWKS JSON parsing with real base64url-encoded +// key material from RFC 7517 / OIDC examples. +func TestDecodePublicJWKJSON(t *testing.T) { + jwksJSON := []byte(`{"keys":[ + {"kty":"EC","crv":"P-256", + "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "kid":"ec-256","use":"sig"}, + {"kty":"RSA", + "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e":"AQAB","kid":"rsa-2048","use":"sig"} + ]}`) + + var jwks jwt.WellKnownJWKs + if err := json.Unmarshal(jwksJSON, &jwks); err != nil { + t.Fatal(err) + } + keys := jwks.Keys + if len(keys) != 2 { + t.Fatalf("expected 2 keys, got %d", len(keys)) + } + + var ecCount, rsaCount int + for _, k := range keys { + switch k.KeyType() { + case "EC": + ecCount++ + if k.KID != "ec-256" { + t.Errorf("unexpected EC kid: %s", k.KID) + } + case "RSA": + rsaCount++ + if k.KID != "rsa-2048" { + t.Errorf("unexpected RSA kid: %s", k.KID) + } + } + } + if ecCount != 1 { + t.Errorf("expected 1 EC key, got %d", ecCount) + } + if rsaCount != 1 { + t.Errorf("expected 1 RSA key, got %d", rsaCount) + } +} + +// TestThumbprint verifies that Thumbprint returns a non-empty base64url string +// for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint. +func TestThumbprint(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + edPub, _, _ := ed25519.GenerateKey(rand.Reader) + + tests := []struct { + name string + pub jwt.PublicKey + }{ + {"EC P-256", jwt.PublicKey{Key: &ecKey.PublicKey}}, + {"RSA 2048", jwt.PublicKey{Key: &rsaKey.PublicKey}}, + {"Ed25519", jwt.PublicKey{Key: edPub}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + thumb, err := tt.pub.Thumbprint() + if err != nil { + t.Fatalf("Thumbprint() error: %v", err) + } + if thumb == "" { + t.Fatal("Thumbprint() returned empty string") + } + // Must be valid base64url (no padding, no +/) + if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") { + t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb) + } + // Same key, same thumbprint + thumb2, _ := tt.pub.Thumbprint() + if thumb != thumb2 { + t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2) + } + }) + } +} + +// TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets +// its KID auto-populated from the RFC 7638 thumbprint. +func TestNoKidAutoThumbprint(t *testing.T) { + // EC key with no "kid" field in the JWKS + jwksJSON := []byte(`{"keys":[ + {"kty":"EC","crv":"P-256", + "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", + "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", + "use":"sig"} + ]}`) + + var jwks jwt.WellKnownJWKs + if err := json.Unmarshal(jwksJSON, &jwks); err != nil { + t.Fatal(err) + } + keys := jwks.Keys + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KID == "" { + t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS") + } + + // The auto-KID should be a valid base64url string. + kid := keys[0].KID + if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") { + t.Errorf("auto-KID contains non-base64url characters: %s", kid) + } + + // Round-trip: compute Thumbprint directly and compare. + thumb, err := keys[0].Thumbprint() + if err != nil { + t.Fatalf("Thumbprint() error: %v", err) + } + if kid != thumb { + t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb) + } +} + +// TestNewPrivateKey verifies that jwt.NewPrivateKey generates an Ed25519 key +// with a non-empty KID auto-derived from the thumbprint, and that the key +// works end-to-end for signing and verification. +func TestNewPrivateKey(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatalf("NewPrivateKey() error: %v", err) + } + if pk.KID == "" { + t.Fatal("NewPrivateKey() returned empty KID") + } + // KID must be base64url (no +, /, or =). + if strings.Contains(pk.KID, "+") || strings.Contains(pk.KID, "/") || strings.Contains(pk.KID, "=") { + t.Errorf("KID contains non-base64url characters: %s", pk.KID) + } + // Two calls must produce different keys but always produce valid base64url KIDs. + pk2, _ := jwt.NewPrivateKey() + if pk.KID == pk2.KID { + t.Error("NewPrivateKey() produced identical KIDs for two different keys") + } + + // Full sign+verify round-trip with the generated key. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatalf("NewSigner() error: %v", err) + } + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatalf("SignToString() error: %v", err) + } + iss := signer.Verifier() + jws, err := iss.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("VerifyJWT() error: %v", err) + } + var decoded AppClaims + if err := jws.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims() error: %v", err) + } + if decoded.Sub != claims.Sub { + t.Errorf("sub: got %q, want %q", decoded.Sub, claims.Sub) + } +} + +// --- DecodeRaw + UnmarshalHeader tests --- + +func TestDecodeRaw(t *testing.T) { + // Sign a real token to get a valid compact string. + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatalf("NewPrivateKey() error: %v", err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatalf("NewSigner() error: %v", err) + } + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatalf("SignToString() error: %v", err) + } + + raw, err := jwt.DecodeRaw(tokenStr) + if err != nil { + t.Fatalf("DecodeRaw() error: %v", err) + } + + // protected and payload should be non-empty base64url segments. + if len(raw.GetProtected()) == 0 { + t.Error("GetProtected() is empty") + } + if len(raw.GetPayload()) == 0 { + t.Error("GetPayload() is empty") + } + if len(raw.GetSignature()) == 0 { + t.Error("GetSignature() is empty") + } +} + +func TestDecodeRawErrors(t *testing.T) { + tests := []struct { + name string + input string + sentinel error + }{ + {"empty string", "", jwt.ErrMalformedToken}, + {"one segment", "abc", jwt.ErrMalformedToken}, + {"two segments", "abc.def", jwt.ErrMalformedToken}, + {"four segments", "a.b.c.d", jwt.ErrMalformedToken}, + {"bad signature base64", "eyJhbGciOiJFZERTQSJ9.eyJpc3MiOiJ4In0.!!!bad!!!", jwt.ErrSignatureInvalid}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := jwt.DecodeRaw(tc.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, tc.sentinel) { + t.Errorf("expected %v, got: %v", tc.sentinel, err) + } + }) + } +} + +func TestDecodeRawSegmentCount(t *testing.T) { + _, err := jwt.DecodeRaw("") + if err == nil { + t.Fatal("expected error") + } + // Empty string normalizes to 0 segments. + if !strings.Contains(err.Error(), "got 0") { + t.Errorf("expected segment count 0 in error, got: %v", err) + } + + _, err = jwt.DecodeRaw("a.b") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "got 2") { + t.Errorf("expected segment count 2 in error, got: %v", err) + } +} + +func TestUnmarshalHeader(t *testing.T) { + // Sign a token and use DecodeRaw + UnmarshalHeader to recover the header. + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatalf("NewPrivateKey() error: %v", err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatalf("NewSigner() error: %v", err) + } + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatalf("SignToString() error: %v", err) + } + + raw, err := jwt.DecodeRaw(tokenStr) + if err != nil { + t.Fatalf("DecodeRaw() error: %v", err) + } + + var h jwt.RFCHeader + if err := raw.UnmarshalHeader(&h); err != nil { + t.Fatalf("UnmarshalHeader() error: %v", err) + } + + if h.Alg != "EdDSA" { + t.Errorf("alg: got %q, want %q", h.Alg, "EdDSA") + } + if h.KID != pk.KID { + t.Errorf("kid: got %q, want %q", h.KID, pk.KID) + } + if h.Typ != "JWT" { + t.Errorf("typ: got %q, want %q", h.Typ, "JWT") + } +} + +func TestUnmarshalHeaderCustomFields(t *testing.T) { + // Build a token whose header has a custom "nonce" field by constructing + // the compact string manually: base64(header).base64(payload).base64(sig). + type CustomHeader struct { + jwt.RFCHeader + Nonce string `json:"nonce"` + } + + hdr := CustomHeader{ + RFCHeader: jwt.RFCHeader{Alg: "EdDSA", KID: "test-key", Typ: "dpop+jwt"}, + Nonce: "server-nonce-42", + } + hdrJSON, _ := json.Marshal(hdr) + payJSON := []byte(`{"sub":"user"}`) + fakeSig := []byte{0xDE, 0xAD} + + compact := base64.RawURLEncoding.EncodeToString(hdrJSON) + + "." + base64.RawURLEncoding.EncodeToString(payJSON) + + "." + base64.RawURLEncoding.EncodeToString(fakeSig) + + raw, err := jwt.DecodeRaw(compact) + if err != nil { + t.Fatalf("DecodeRaw() error: %v", err) + } + + var got CustomHeader + if err := raw.UnmarshalHeader(&got); err != nil { + t.Fatalf("UnmarshalHeader() error: %v", err) + } + + if got.Nonce != "server-nonce-42" { + t.Errorf("nonce: got %q, want %q", got.Nonce, "server-nonce-42") + } + if got.Alg != "EdDSA" { + t.Errorf("alg: got %q, want %q", got.Alg, "EdDSA") + } + if got.Typ != "dpop+jwt" { + t.Errorf("typ: got %q, want %q", got.Typ, "dpop+jwt") + } +} + +func TestUnmarshalHeaderViaJWS(t *testing.T) { + // Verify that UnmarshalHeader is promoted from RawJWT to *JWT. + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatalf("NewPrivateKey() error: %v", err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatalf("NewSigner() error: %v", err) + } + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatalf("SignToString() error: %v", err) + } + + // Use Decode (not DecodeRaw) - UnmarshalHeader should still work via promotion. + jws, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatalf("Decode() error: %v", err) + } + + var h jwt.RFCHeader + if err := jws.UnmarshalHeader(&h); err != nil { + t.Fatalf("jws.UnmarshalHeader() error: %v", err) + } + if h.Alg != "EdDSA" { + t.Errorf("alg: got %q, want %q", h.Alg, "EdDSA") + } +} + +// --- SpaceDelimited tests --- + +func TestSpaceDelimitedMarshalJSON(t *testing.T) { + tests := []struct { + name string + in jwt.SpaceDelimited + want string + }{ + {"multiple", jwt.SpaceDelimited{"openid", "profile", "email"}, `"openid profile email"`}, + {"single", jwt.SpaceDelimited{"openid"}, `"openid"`}, + {"empty", jwt.SpaceDelimited{}, `""`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.in) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if string(got) != tt.want { + t.Errorf("got %s, want %s", got, tt.want) + } + }) + } +} + +func TestSpaceDelimitedUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + in string + want jwt.SpaceDelimited + wantNil bool + }{ + {"multiple", `"openid profile email"`, jwt.SpaceDelimited{"openid", "profile", "email"}, false}, + {"single", `"openid"`, jwt.SpaceDelimited{"openid"}, false}, + {"empty", `""`, jwt.SpaceDelimited{}, false}, + {"extra whitespace", `"openid profile\temail"`, jwt.SpaceDelimited{"openid", "profile", "email"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got jwt.SpaceDelimited + if err := json.Unmarshal([]byte(tt.in), &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if tt.wantNil { + if got != nil { + t.Errorf("got %v, want nil", got) + } + return + } + if len(got) != len(tt.want) { + t.Fatalf("got %v (len %d), want %v (len %d)", got, len(got), tt.want, len(tt.want)) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("got[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestSpaceDelimitedRoundTrip(t *testing.T) { + type claims struct { + Scope jwt.SpaceDelimited `json:"scope,omitempty"` + } + orig := claims{Scope: jwt.SpaceDelimited{"openid", "profile"}} + data, err := json.Marshal(orig) + if err != nil { + t.Fatal(err) + } + + // Verify wire format is a space-separated string. + if !strings.Contains(string(data), `"openid profile"`) { + t.Fatalf("expected space-separated scope in JSON, got %s", data) + } + + var decoded claims + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatal(err) + } + if len(decoded.Scope) != 2 || decoded.Scope[0] != "openid" || decoded.Scope[1] != "profile" { + t.Errorf("round-trip failed: got %v", decoded.Scope) + } +} + +// --- SetTyp / NewAccessToken tests --- + +func TestSetTyp(t *testing.T) { + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user1", + Aud: jwt.Listish{"api"}, + Exp: time.Now().Add(time.Hour).Unix(), + IAt: time.Now().Unix(), + } + jws, err := jwt.New(claims) + if err != nil { + t.Fatal(err) + } + + // Default typ is "JWT". + if got := jws.GetHeader().Typ; got != "JWT" { + t.Fatalf("default typ: got %q, want %q", got, "JWT") + } + + jws.SetTyp(jwt.AccessTokenTyp) + + if got := jws.GetHeader().Typ; got != jwt.AccessTokenTyp { + t.Fatalf("after SetTyp: got %q, want %q", got, jwt.AccessTokenTyp) + } +} + +func TestSetTypSurvivesSigning(t *testing.T) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "")}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://auth.example.com", + Sub: "user1", + Aud: jwt.Listish{"https://api.example.com"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + JTI: "tok-001", + ClientID: "webapp", + Scope: jwt.SpaceDelimited{"openid", "profile"}, + } + + jws, err := jwt.NewAccessToken(claims) + if err != nil { + t.Fatal(err) + } + if err := signer.SignJWT(jws); err != nil { + t.Fatal(err) + } + + // Verify typ survived signing. + if got := jws.GetHeader().Typ; got != jwt.AccessTokenTyp { + t.Fatalf("typ after signing: got %q, want %q", got, jwt.AccessTokenTyp) + } + + // Decode the token and verify typ is in the wire format. + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed: %v", err) + } + decoded, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if got := decoded.GetHeader().Typ; got != jwt.AccessTokenTyp { + t.Fatalf("typ after decode: got %q, want %q", got, jwt.AccessTokenTyp) + } + + // Verify claims round-trip. + var rt jwt.TokenClaims + if err := decoded.UnmarshalClaims(&rt); err != nil { + t.Fatal(err) + } + if rt.ClientID != "webapp" { + t.Errorf("client_id: got %q, want %q", rt.ClientID, "webapp") + } + if len(rt.Scope) != 2 || rt.Scope[0] != "openid" { + t.Errorf("scope: got %v, want [openid profile]", rt.Scope) + } +} + +// --- Access token Validator tests --- + +func goodAccessTokenClaims() *jwt.TokenClaims { + now := time.Now() + return &jwt.TokenClaims{ + Iss: "https://auth.example.com", + Sub: "user1", + Aud: jwt.Listish{"https://api.example.com"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + JTI: "tok-001", + ClientID: "webapp", + Scope: jwt.SpaceDelimited{"openid", "profile"}, + AMR: []string{"pwd"}, + } +} + +func TestAccessTokenValidatorHappyPath(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + claims := goodAccessTokenClaims() + if err := v.Validate(nil, claims, time.Now()); err != nil { + t.Fatalf("valid access token rejected: %v", err) + } +} + +func TestAccessTokenValidatorRequiresJTI(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + claims := goodAccessTokenClaims() + claims.JTI = "" + err := v.Validate(nil, claims, time.Now()) + if !errors.Is(err, jwt.ErrMissingClaim) { + t.Fatalf("expected ErrMissingClaim for missing jti, got: %v", err) + } +} + +func TestAccessTokenValidatorRequiresClientID(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + claims := goodAccessTokenClaims() + claims.ClientID = "" + err := v.Validate(nil, claims, time.Now()) + if !errors.Is(err, jwt.ErrMissingClaim) { + t.Fatalf("expected ErrMissingClaim for missing client_id, got: %v", err) + } +} + +func TestAccessTokenValidatorDisableClientID(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + v.Checks &^= jwt.CheckClientID + claims := goodAccessTokenClaims() + claims.ClientID = "" + if err := v.Validate(nil, claims, time.Now()); err != nil { + t.Fatalf("disabling CheckClientID should accept empty client_id: %v", err) + } +} + +func TestAccessTokenValidatorExpiredToken(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + claims := goodAccessTokenClaims() + claims.Exp = time.Now().Add(-time.Hour).Unix() + err := v.Validate(nil, claims, time.Now()) + if !errors.Is(err, jwt.ErrAfterExp) { + t.Fatalf("expected ErrAfterExp, got: %v", err) + } +} + +func TestAccessTokenValidatorRequiredScopes(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + v.RequiredScopes = []string{"openid", "admin"} + claims := goodAccessTokenClaims() + // claims has ["openid", "profile"] - missing "admin" + err := v.Validate(nil, claims, time.Now()) + if !errors.Is(err, jwt.ErrInsufficientScope) { + t.Fatalf("expected ErrInsufficientScope for missing scope, got: %v", err) + } +} + +func TestAccessTokenValidatorExpectScope(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + v.Checks |= jwt.CheckScope // enable scope presence check + claims := goodAccessTokenClaims() + claims.Scope = nil + err := v.Validate(nil, claims, time.Now()) + if !errors.Is(err, jwt.ErrMissingClaim) { + t.Fatalf("expected ErrMissingClaim for empty scope, got: %v", err) + } +} + +func TestAccessTokenValidatorDisableJTI(t *testing.T) { + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + v.Checks &^= jwt.CheckJTI + claims := goodAccessTokenClaims() + claims.JTI = "" + if err := v.Validate(nil, claims, time.Now()); err != nil { + t.Fatalf("disabling CheckJTI should accept empty jti: %v", err) + } +} + +// --- Encode validation tests --- + +// stubJWT is a minimal VerifiableJWT for testing Encode validation. +type stubJWT struct { + protected []byte + payload []byte + signature []byte + header jwt.RFCHeader +} + +func (s *stubJWT) GetProtected() []byte { return s.protected } +func (s *stubJWT) GetPayload() []byte { return s.payload } +func (s *stubJWT) GetSignature() []byte { return s.signature } +func (s *stubJWT) GetHeader() jwt.RFCHeader { return s.header } + +// TestEncodeRejectsEmptyAlg verifies that Encode returns an error +// when the alg header field is empty (unsigned token). +func TestEncodeRejectsEmptyAlg(t *testing.T) { + // Zero-value stub: no alg set. + jws := &stubJWT{} + _, err := jwt.Encode(jws) + if err == nil { + t.Fatal("expected error for empty alg") + } + if !errors.Is(err, jwt.ErrInvalidHeader) { + t.Fatalf("expected ErrInvalidHeader, got: %v", err) + } + if !strings.Contains(err.Error(), "alg is empty") { + t.Fatalf("unexpected message: %v", err) + } + + // Explicit header with typ but no alg. + jws2 := &stubJWT{ + protected: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT"}`))), + payload: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"user1"}`))), + signature: []byte{0x01, 0x02, 0x03}, + header: jwt.RFCHeader{Typ: "JWT"}, + } + _, err = jwt.Encode(jws2) + if err == nil { + t.Fatal("expected error for empty alg with typ-only header") + } + if !errors.Is(err, jwt.ErrInvalidHeader) { + t.Fatalf("expected ErrInvalidHeader, got: %v", err) + } +} + +// TestEncodeSucceedsAfterSigning verifies the happy path: a signed JWT +// encodes without error. +func TestEncodeSucceedsAfterSigning(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")}) + if err != nil { + t.Fatal(err) + } + claims := goodClaims() + jws, err := signer.Sign(&claims) + if err != nil { + t.Fatal(err) + } + + token, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode failed on signed JWT: %v", err) + } + parts := strings.Split(token, ".") + if len(parts) != 3 { + t.Fatalf("expected 3 segments, got %d", len(parts)) + } + for i, p := range parts { + if p == "" { + t.Fatalf("segment %d is empty", i) + } + } +} + +// --- Full pipeline round-trip tests --- + +// TestRoundTrip_IDToken exercises the full ID token pipeline: +// NewPrivateKey -> NewSigner -> SignToString -> Decode -> Verify -> UnmarshalClaims -> Validate. +func TestRoundTrip_IDToken(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://idp.example.com", + Sub: "user-42", + Aud: jwt.Listish{"my-client"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + AzP: "my-client", + Nonce: "n-0S6_WzA2Mj", + AMR: []string{"pwd", "otp"}, + JTI: "id-tok-001", + } + + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + // Decode + Verify + decoded, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatalf("Decode: %v", err) + } + verifier := signer.Verifier() + if err := verifier.Verify(decoded); err != nil { + t.Fatalf("Verify: %v", err) + } + + // UnmarshalClaims + var got jwt.TokenClaims + if err := decoded.UnmarshalClaims(&got); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + + // Validate with NewIDTokenValidator + v := jwt.NewIDTokenValidator( + []string{"https://idp.example.com"}, + []string{"my-client"}, + []string{"my-client"}, + 0, + ) + if err := v.Validate(nil, &got, time.Now()); err != nil { + t.Fatalf("Validate: %v", err) + } + + // Spot-check round-tripped fields. + if got.Sub != "user-42" { + t.Errorf("sub: got %q, want %q", got.Sub, "user-42") + } + if got.Nonce != "n-0S6_WzA2Mj" { + t.Errorf("nonce: got %q, want %q", got.Nonce, "n-0S6_WzA2Mj") + } + if len(got.AMR) != 2 || got.AMR[0] != "pwd" || got.AMR[1] != "otp" { + t.Errorf("amr: got %v, want [pwd otp]", got.AMR) + } +} + +// TestRoundTrip_AccessToken exercises the full access token pipeline: +// NewAccessToken -> SignJWT -> Encode -> Decode -> Verify -> UnmarshalClaims -> Validate with RequiredScopes. +func TestRoundTrip_AccessToken(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://auth.example.com", + Sub: "svc-account", + Aud: jwt.Listish{"https://api.example.com"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + JTI: "at-001", + ClientID: "backend-svc", + Scope: jwt.SpaceDelimited{"read", "write", "admin"}, + } + + jws, err := jwt.NewAccessToken(claims) + if err != nil { + t.Fatal(err) + } + if err := signer.SignJWT(jws); err != nil { + t.Fatal(err) + } + + tokenStr, err := jwt.Encode(jws) + if err != nil { + t.Fatalf("Encode: %v", err) + } + + // Decode + Verify + decoded, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatalf("Decode: %v", err) + } + + // Verify typ survived the round-trip. + if got := decoded.GetHeader().Typ; got != jwt.AccessTokenTyp { + t.Errorf("typ: got %q, want %q", got, jwt.AccessTokenTyp) + } + + verifier := signer.Verifier() + if err := verifier.Verify(decoded); err != nil { + t.Fatalf("Verify: %v", err) + } + + var got jwt.TokenClaims + if err := decoded.UnmarshalClaims(&got); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + + // Validate with RequiredScopes. + v := jwt.NewAccessTokenValidator( + []string{"https://auth.example.com"}, + []string{"https://api.example.com"}, + 0, + ) + v.RequiredScopes = []string{"read", "write"} + if err := v.Validate(nil, &got, time.Now()); err != nil { + t.Fatalf("Validate: %v", err) + } + + // Spot-check scope round-trip. + if len(got.Scope) != 3 || got.Scope[0] != "read" || got.Scope[2] != "admin" { + t.Errorf("scope: got %v, want [read write admin]", got.Scope) + } + if got.ClientID != "backend-svc" { + t.Errorf("client_id: got %q, want %q", got.ClientID, "backend-svc") + } +} + +// TestRoundTrip_StandardClaims verifies that StandardClaims with NullBool fields +// survive the sign-decode round-trip. +func TestRoundTrip_StandardClaims(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.StandardClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://idp.example.com", + Sub: "user-99", + Aud: jwt.Listish{"app"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + AzP: "app", + }, + Name: "Jane Doe", + Email: "jane@example.com", + EmailVerified: jwt.NullBool{Bool: true, Valid: true}, + PhoneNumber: "+1-555-0100", + PhoneNumberVerified: jwt.NullBool{Bool: false, Valid: true}, + Locale: "en-US", + } + + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + // Decode + Verify + verifier := signer.Verifier() + jws, err := verifier.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("VerifyJWT: %v", err) + } + + var got jwt.StandardClaims + if err := jws.UnmarshalClaims(&got); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + + // Verify NullBool fields survived the round-trip. + if !got.EmailVerified.Valid || !got.EmailVerified.Bool { + t.Errorf("email_verified: got %+v, want {Bool:true Valid:true}", got.EmailVerified) + } + if !got.PhoneNumberVerified.Valid || got.PhoneNumberVerified.Bool { + t.Errorf("phone_number_verified: got %+v, want {Bool:false Valid:true}", got.PhoneNumberVerified) + } + + // Verify other profile fields. + if got.Name != "Jane Doe" { + t.Errorf("name: got %q, want %q", got.Name, "Jane Doe") + } + if got.Email != "jane@example.com" { + t.Errorf("email: got %q, want %q", got.Email, "jane@example.com") + } + if got.Locale != "en-US" { + t.Errorf("locale: got %q, want %q", got.Locale, "en-US") + } + + // Verify that an unset NullBool (not in JSON) comes back as zero value. + // StandardClaims has no field we explicitly omitted that is a NullBool + // other than the two we set, so we create a fresh StandardClaims without + // email_verified to test omission. + claims2 := &jwt.StandardClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://idp.example.com", + Sub: "user-100", + Aud: jwt.Listish{"app"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + AzP: "app", + }, + Email: "bob@example.com", + // EmailVerified left as zero value (NullBool{}) + } + tok2, err := signer.SignToString(claims2) + if err != nil { + t.Fatal(err) + } + jws2, err := verifier.VerifyJWT(tok2) + if err != nil { + t.Fatal(err) + } + var got2 jwt.StandardClaims + if err := jws2.UnmarshalClaims(&got2); err != nil { + t.Fatal(err) + } + if got2.EmailVerified.Valid { + t.Errorf("omitted email_verified should be invalid, got %+v", got2.EmailVerified) + } +} + +// --- DPoPJWT: custom header type used by TestRoundTrip_CustomHeader --- + +// dpopHeader extends the standard JOSE header with a DPoP nonce. +type dpopHeader struct { + jwt.RFCHeader + Nonce string `json:"nonce,omitempty"` +} + +// dpopJWT is a custom JWT that carries a dpopHeader. +type dpopJWT struct { + jwt.RawJWT + Header dpopHeader +} + +func (d *dpopJWT) GetHeader() jwt.RFCHeader { return d.Header.RFCHeader } + +func (d *dpopJWT) SetHeader(hdr jwt.Header) error { + d.Header.RFCHeader = *hdr.GetRFCHeader() + data, err := json.Marshal(d.Header) + if err != nil { + return err + } + d.Protected = []byte(base64.RawURLEncoding.EncodeToString(data)) + return nil +} + +// TestRoundTrip_CustomHeader verifies that custom header fields survive the +// full sign-decode-verify round-trip when using a custom SignableJWT type. +func TestRoundTrip_CustomHeader(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + verifier := signer.Verifier() + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user-1", + Aud: jwt.Listish{"api"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + } + + // Build and sign a DPoP JWT with a custom nonce header. + dpop := &dpopJWT{Header: dpopHeader{ + RFCHeader: jwt.RFCHeader{Typ: "dpop+jwt"}, + Nonce: "server-nonce-abc", + }} + if err := dpop.SetClaims(claims); err != nil { + t.Fatal(err) + } + if err := signer.SignJWT(dpop); err != nil { + t.Fatal(err) + } + tokenStr, err := jwt.Encode(dpop) + if err != nil { + t.Fatal(err) + } + + // Verify the signature with the standard Decode path. + jws, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatalf("Decode: %v", err) + } + if err := verifier.Verify(jws); err != nil { + t.Fatalf("Verify: %v", err) + } + + // Recover the custom header via DecodeRaw + UnmarshalHeader. + raw, err := jwt.DecodeRaw(tokenStr) + if err != nil { + t.Fatalf("DecodeRaw: %v", err) + } + var gotHdr dpopHeader + if err := raw.UnmarshalHeader(&gotHdr); err != nil { + t.Fatalf("UnmarshalHeader: %v", err) + } + + if gotHdr.Nonce != "server-nonce-abc" { + t.Errorf("nonce: got %q, want %q", gotHdr.Nonce, "server-nonce-abc") + } + if gotHdr.Typ != "dpop+jwt" { + t.Errorf("typ: got %q, want %q", gotHdr.Typ, "dpop+jwt") + } + if gotHdr.Alg != "EdDSA" { + t.Errorf("alg: got %q, want %q", gotHdr.Alg, "EdDSA") + } + if gotHdr.KID != pk.KID { + t.Errorf("kid: got %q, want %q", gotHdr.KID, pk.KID) + } +} + +// TestRoundTrip_ExpiredTokenRejection signs a token with a past exp, verifies +// that Decode+Verify succeeds (signature is valid) but Validate rejects it. +func TestRoundTrip_ExpiredTokenRejection(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://idp.example.com", + Sub: "user-1", + Aud: jwt.Listish{"app"}, + Exp: now.Add(-time.Hour).Unix(), // expired 1 hour ago + IAt: now.Add(-2 * time.Hour).Unix(), + AuthTime: now.Add(-2 * time.Hour).Unix(), + AzP: "app", + } + + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + // Decode + Verify should succeed - the signature is valid. + decoded, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatalf("Decode: %v", err) + } + verifier := signer.Verifier() + if err := verifier.Verify(decoded); err != nil { + t.Fatalf("Verify should succeed for expired token: %v", err) + } + + var got jwt.TokenClaims + if err := decoded.UnmarshalClaims(&got); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + + // Validate should fail with ErrAfterExp. + v := jwt.NewIDTokenValidator( + []string{"https://idp.example.com"}, + []string{"app"}, + []string{"app"}, + 0, + ) + err = v.Validate(nil, &got, time.Now()) + if err == nil { + t.Fatal("expected validation error for expired token") + } + if !errors.Is(err, jwt.ErrAfterExp) { + t.Fatalf("expected ErrAfterExp, got: %v", err) + } +} + +// TestRoundTrip_WrongAudienceRejection signs a token with one audience and +// validates against a different audience, expecting rejection. +func TestRoundTrip_WrongAudienceRejection(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://idp.example.com", + Sub: "user-1", + Aud: jwt.Listish{"app-A"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + AzP: "app-A", + } + + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + // Decode + Verify succeeds. + verifier := signer.Verifier() + jws, err := verifier.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("VerifyJWT: %v", err) + } + + var got jwt.TokenClaims + if err := jws.UnmarshalClaims(&got); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + + // Validate with a different audience - should fail. + v := jwt.NewIDTokenValidator( + []string{"https://idp.example.com"}, + []string{"app-B"}, // wrong audience + []string{"app-A"}, + 0, + ) + err = v.Validate(nil, &got, time.Now()) + if err == nil { + t.Fatal("expected validation error for wrong audience") + } + if !errors.Is(err, jwt.ErrInvalidClaim) { + t.Fatalf("expected ErrInvalidClaim for aud mismatch, got: %v", err) + } +} + +// TestDuplicateKIDRotation verifies that when multiple keys share the same KID +// (e.g. during key rotation), the verifier tries all matching keys and succeeds +// if any one of them verifies the signature. +func TestDuplicateKIDRotation(t *testing.T) { + oldKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + newKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Both keys share the same KID (simulating a rotation where the KID is reused). + sharedKID := "rotating-key" + + // Sign a token with the OLD key. + oldSigner, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, oldKey, sharedKID)}) + if err != nil { + t.Fatal(err) + } + claims := goodClaims() + oldToken, err := oldSigner.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Sign a token with the NEW key. + newSigner, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, newKey, sharedKID)}) + if err != nil { + t.Fatal(err) + } + newToken, err := newSigner.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + // Verifier has both keys under the same KID. + verifier, err := jwt.NewVerifier([]jwt.PublicKey{ + {Key: &oldKey.PublicKey, KID: sharedKID}, + {Key: &newKey.PublicKey, KID: sharedKID}, + }) + if err != nil { + t.Fatal(err) + } + + // Both tokens should verify successfully. + for _, tt := range []struct { + name string + token string + }{ + {"old key", oldToken}, + {"new key", newToken}, + } { + t.Run(tt.name, func(t *testing.T) { + parsed, err := jwt.Decode(tt.token) + if err != nil { + t.Fatalf("Decode: %v", err) + } + if err := verifier.Verify(parsed); err != nil { + t.Fatalf("Verify should succeed for %s: %v", tt.name, err) + } + }) + } +} + +// TestNoKIDTriesAllKeys verifies that when a token has no KID, all verifier +// keys are tried and the first successful verification wins. +func TestNoKIDTriesAllKeys(t *testing.T) { + wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Sign with rightKey using SignRaw with a header that has no KID. + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "any")}) + hdr := &jwt.RFCHeader{} // no KID, no typ + payloadJSON := []byte(`{"sub":"user-1"}`) + raw, err := signer.SignRaw(hdr, payloadJSON) + if err != nil { + t.Fatal(err) + } + + // Reconstruct as compact token: protected.payload.signature + token := string(raw.GetProtected()) + "." + string(raw.GetPayload()) + + "." + base64.RawURLEncoding.EncodeToString(raw.GetSignature()) + + // Verifier has wrongKey first, then rightKey. + verifier, err := jwt.NewVerifier([]jwt.PublicKey{ + {Key: &wrongKey.PublicKey, KID: "wrong"}, + {Key: &rightKey.PublicKey, KID: "right"}, + }) + if err != nil { + t.Fatal(err) + } + + parsed, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if parsed.GetHeader().KID != "" { + t.Fatalf("token should have no KID, got %q", parsed.GetHeader().KID) + } + if err := verifier.Verify(parsed); err != nil { + t.Fatalf("Verify should try all keys when token has no KID: %v", err) + } +} + +// TestEmptyKIDTokenEmptyKIDKey verifies that when both the token and a +// verifier key have empty KIDs, verification succeeds (all keys are tried). +func TestEmptyKIDTokenEmptyKIDKey(t *testing.T) { + rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Sign with SignRaw to produce a token with no KID. + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "any")}) + raw, err := signer.SignRaw(&jwt.RFCHeader{}, []byte(`{"sub":"user-1"}`)) + if err != nil { + t.Fatal(err) + } + token := string(raw.GetProtected()) + "." + string(raw.GetPayload()) + + "." + base64.RawURLEncoding.EncodeToString(raw.GetSignature()) + + // Verifier key also has empty KID. + verifier, err := jwt.NewVerifier([]jwt.PublicKey{ + {Key: &rightKey.PublicKey, KID: ""}, + }) + if err != nil { + t.Fatal(err) + } + + parsed, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + if err := verifier.Verify(parsed); err != nil { + t.Fatalf("empty KID token + empty KID key should verify: %v", err) + } +} + +// TestKIDTokenEmptyKIDKey verifies that when the token has a KID but the +// verifier key has an empty KID, the key is not a candidate (no match). +func TestKIDTokenEmptyKIDKey(t *testing.T) { + rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Sign normally -- token gets a KID from the signer. + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "my-kid")}) + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + // Verifier has the same key material but with an empty KID. + verifier, err := jwt.NewVerifier([]jwt.PublicKey{ + {Key: &rightKey.PublicKey, KID: ""}, + }) + if err != nil { + t.Fatal(err) + } + + parsed, err := jwt.Decode(token) + if err != nil { + t.Fatal(err) + } + err = verifier.Verify(parsed) + if !errors.Is(err, jwt.ErrUnknownKID) { + t.Fatalf("token KID %q should not match empty-KID key, expected ErrUnknownKID, got: %v", + parsed.GetHeader().KID, err) + } +} + +// TestMultiKeyVerifier verifies that a Verifier with keys of different algorithms +// correctly selects the right key by KID when verifying tokens. +func TestMultiKeyVerifier(t *testing.T) { + ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) + edPub, edPriv, _ := ed25519.GenerateKey(rand.Reader) + + ecPK := mustPK(t, ecKey, "ec-key") + rsaPK := mustPK(t, rsaKey, "rsa-key") + edPK := mustPK(t, edPriv, "ed-key") + + // Create a verifier with all three public keys. + verifier, err := jwt.NewVerifier([]jwt.PublicKey{ + {Key: &ecKey.PublicKey, KID: "ec-key"}, + {Key: &rsaKey.PublicKey, KID: "rsa-key"}, + {Key: edPub, KID: "ed-key"}, + }) + if err != nil { + t.Fatal(err) + } + + // Sign a token with each key type and verify the multi-key verifier picks the right one. + for _, tt := range []struct { + name string + pk *jwt.PrivateKey + alg string + }{ + {"EC/ES256", ecPK, "ES256"}, + {"RSA/RS256", rsaPK, "RS256"}, + {"Ed25519/EdDSA", edPK, "EdDSA"}, + } { + t.Run(tt.name, func(t *testing.T) { + signer, err := jwt.NewSigner([]*jwt.PrivateKey{tt.pk}) + if err != nil { + t.Fatal(err) + } + claims := goodClaims() + tokenStr, err := signer.SignToString(&claims) + if err != nil { + t.Fatal(err) + } + + parsed, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatalf("Decode: %v", err) + } + if parsed.GetHeader().Alg != tt.alg { + t.Fatalf("alg: got %s, want %s", parsed.GetHeader().Alg, tt.alg) + } + if err := verifier.Verify(parsed); err != nil { + t.Fatalf("Verify failed for %s: %v", tt.alg, err) + } + }) + } +} + +// TestAudienceSingleString verifies that a single-string "aud" claim +// (RFC 7519 ยง4.1.3) is correctly unmarshaled as a single-element Audience. +func TestAudienceSingleString(t *testing.T) { + pk, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user-1", + Aud: jwt.Listish{"single-aud"}, // single element + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + } + + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + // Verify the wire format uses a string (not an array) for single-element aud. + decoded, err := jwt.Decode(tokenStr) + if err != nil { + t.Fatal(err) + } + var rawPayload map[string]json.RawMessage + payloadBytes, _ := base64.RawURLEncoding.DecodeString(string(decoded.GetPayload())) + if err := json.Unmarshal(payloadBytes, &rawPayload); err != nil { + t.Fatal(err) + } + audRaw := string(rawPayload["aud"]) + if audRaw[0] == '[' { + t.Errorf("single-element aud should be a string, got array: %s", audRaw) + } + + // Unmarshal and validate. + var got jwt.TokenClaims + if err := decoded.UnmarshalClaims(&got); err != nil { + t.Fatal(err) + } + if len(got.Aud) != 1 || got.Aud[0] != "single-aud" { + t.Errorf("aud: got %v, want [single-aud]", got.Aud) + } + + // Also test unmarshaling from a manually constructed single-string aud. + singleJSON := []byte(`"just-one"`) + var aud jwt.Listish + if err := json.Unmarshal(singleJSON, &aud); err != nil { + t.Fatalf("Unmarshal single-string aud: %v", err) + } + if len(aud) != 1 || aud[0] != "just-one" { + t.Errorf("single-string unmarshal: got %v, want [just-one]", aud) + } + + // And array form. + arrayJSON := []byte(`["a","b"]`) + var aud2 jwt.Listish + if err := json.Unmarshal(arrayJSON, &aud2); err != nil { + t.Fatalf("Unmarshal array aud: %v", err) + } + if len(aud2) != 2 || aud2[0] != "a" || aud2[1] != "b" { + t.Errorf("array unmarshal: got %v, want [a b]", aud2) + } +} + +// TestVerifyTamperedPayload confirms that a tampered payload (modified after signing) +// is rejected by signature verification. +func TestVerifyTamperedPayload(t *testing.T) { + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) + + claims := goodClaims() + token, _ := signer.SignToString(&claims) + + verifier := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) + + // Tamper with the payload: change the sub claim. + parts := strings.SplitN(token, ".", 3) + payloadBytes, _ := base64.RawURLEncoding.DecodeString(parts[1]) + tampered := strings.Replace(string(payloadBytes), claims.Sub, "evil-user", 1) + parts[1] = base64.RawURLEncoding.EncodeToString([]byte(tampered)) + tamperedToken := strings.Join(parts, ".") + + parsed, err := jwt.Decode(tamperedToken) + if err != nil { + t.Fatalf("Decode should succeed for well-formed tampered token: %v", err) + } + if err := verifier.Verify(parsed); err == nil { + t.Fatal("expected Verify to fail for tampered payload") + } else if !errors.Is(err, jwt.ErrSignatureInvalid) { + t.Fatalf("expected ErrSignatureInvalid, got: %v", err) + } +} + +// TestValidationErrorAnnotation verifies that time-related validation errors +// include the server time annotation. +func TestValidationErrorAnnotation(t *testing.T) { + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user-1", + Aud: jwt.Listish{"app"}, + Exp: now.Add(-time.Hour).Unix(), // expired + IAt: now.Unix(), + AuthTime: now.Unix(), + } + + v := jwt.NewIDTokenValidator( + []string{"https://example.com"}, + []string{"app"}, + nil, + 0, + ) + err := v.Validate(nil, claims, now) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), "server time") { + t.Errorf("time error should include server time annotation: %v", err) + } +} + +// TestValidateThreadsHeaderErrors verifies that errors from header validation +// (IsAllowedTyp) are preserved when threaded into Validate. +func TestValidateThreadsHeaderErrors(t *testing.T) { + now := time.Now() + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user-1", + Aud: jwt.Listish{"app"}, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + AuthTime: now.Unix(), + } + + v := jwt.NewIDTokenValidator( + []string{"https://example.com"}, + []string{"app"}, + nil, + 0, + ) + + // Simulate a header check failure. + hdr := jwt.RFCHeader{Typ: "at+jwt"} // wrong typ for ID token + var errs []error + errs = hdr.IsAllowedTyp(errs, []string{"JWT"}) + if len(errs) == 0 { + t.Fatal("expected IsAllowedTyp to produce an error") + } + + // Thread header errors into Validate - claims are valid, so only header error remains. + err := v.Validate(errs, claims, now) + if err == nil { + t.Fatal("expected error from threaded header validation") + } + if !strings.Contains(err.Error(), "typ") { + t.Errorf("expected typ error in output: %v", err) + } +} diff --git a/auth/jwt/nullbool_test.go b/auth/jwt/nullbool_test.go new file mode 100644 index 0000000..aa85cc0 --- /dev/null +++ b/auth/jwt/nullbool_test.go @@ -0,0 +1,280 @@ +package jwt_test + +import ( + "encoding/json" + "testing" + + "github.com/therootcompany/golib/auth/jwt" +) + +func TestNullBool_MarshalJSON(t *testing.T) { + tests := []struct { + name string + nb jwt.NullBool + want string + }{ + {"true", jwt.NullBool{Bool: true, Valid: true}, "true"}, + {"false", jwt.NullBool{Bool: false, Valid: true}, "false"}, + {"null (zero value)", jwt.NullBool{}, "null"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.nb) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if string(got) != tt.want { + t.Errorf("Marshal = %s, want %s", got, tt.want) + } + }) + } +} + +func TestNullBool_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + wantValue bool + wantValid bool + }{ + {"true", "true", true, true}, + {"false", "false", false, true}, + {"null", "null", false, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var nb jwt.NullBool + if err := json.Unmarshal([]byte(tt.input), &nb); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if nb.Bool != tt.wantValue { + t.Errorf("Value = %v, want %v", nb.Bool, tt.wantValue) + } + if nb.Valid != tt.wantValid { + t.Errorf("Valid = %v, want %v", nb.Valid, tt.wantValid) + } + }) + } +} + +func TestNullBool_UnmarshalJSON_InvalidInput(t *testing.T) { + var nb jwt.NullBool + if err := json.Unmarshal([]byte(`"yes"`), &nb); err == nil { + t.Error("expected error for invalid input, got nil") + } +} + +func TestNullBool_IsZero(t *testing.T) { + tests := []struct { + name string + nb jwt.NullBool + want bool + }{ + {"zero value", jwt.NullBool{}, true}, + {"true", jwt.NullBool{Bool: true, Valid: true}, false}, + {"false", jwt.NullBool{Bool: false, Valid: true}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.nb.IsZero(); got != tt.want { + t.Errorf("IsZero() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNullBool_RoundTrip(t *testing.T) { + values := []jwt.NullBool{ + {Bool: true, Valid: true}, + {Bool: false, Valid: true}, + {Bool: false, Valid: false}, + } + for _, orig := range values { + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + var got jwt.NullBool + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if got.Bool != orig.Bool || got.Valid != orig.Valid { + t.Errorf("round-trip: got {%v, %v}, want {%v, %v}", + got.Bool, got.Valid, orig.Bool, orig.Valid) + } + } +} + +func TestNullBool_ClaimsIntegration(t *testing.T) { + t.Run("marshal with email verified true", func(t *testing.T) { + claims := jwt.StandardClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user123", + Exp: 9999999999, + IAt: 1000000000, + }, + Email: "user@example.com", + EmailVerified: jwt.NullBool{Bool: true, Valid: true}, + } + data, err := json.Marshal(claims) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if string(raw["email_verified"]) != "true" { + t.Errorf("email_verified = %s, want true", raw["email_verified"]) + } + }) + + t.Run("marshal with email verified false", func(t *testing.T) { + claims := jwt.StandardClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user123", + Exp: 9999999999, + IAt: 1000000000, + }, + Email: "user@example.com", + EmailVerified: jwt.NullBool{Bool: false, Valid: true}, + } + data, err := json.Marshal(claims) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if string(raw["email_verified"]) != "false" { + t.Errorf("email_verified = %s, want false", raw["email_verified"]) + } + }) + + t.Run("marshal omits verified when no email", func(t *testing.T) { + claims := jwt.StandardClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user123", + Exp: 9999999999, + IAt: 1000000000, + }, + // No email, no EmailVerified -> field omitted via omitzero + } + data, err := json.Marshal(claims) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if _, ok := raw["email_verified"]; ok { + t.Errorf("email_verified present = %s, want omitted", raw["email_verified"]) + } + if _, ok := raw["phone_number_verified"]; ok { + t.Errorf("phone_number_verified present = %s, want omitted", raw["phone_number_verified"]) + } + }) + + t.Run("unmarshal claims with verified fields", func(t *testing.T) { + input := `{ + "iss": "https://example.com", + "sub": "user123", + "exp": 9999999999, + "iat": 1000000000, + "email": "user@example.com", + "email_verified": true, + "phone_number": "+1555000000", + "phone_number_verified": false + }` + var claims jwt.StandardClaims + if err := json.Unmarshal([]byte(input), &claims); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if !claims.EmailVerified.Valid || !claims.EmailVerified.Bool { + t.Errorf("EmailVerified = {%v, %v}, want {true, true}", + claims.EmailVerified.Bool, claims.EmailVerified.Valid) + } + if !claims.PhoneNumberVerified.Valid || claims.PhoneNumberVerified.Bool { + t.Errorf("PhoneNumberVerified = {%v, %v}, want {false, true}", + claims.PhoneNumberVerified.Bool, claims.PhoneNumberVerified.Valid) + } + }) + + t.Run("unmarshal claims with null verified fields", func(t *testing.T) { + input := `{ + "iss": "https://example.com", + "sub": "user123", + "exp": 9999999999, + "iat": 1000000000, + "email_verified": null, + "phone_number_verified": null + }` + var claims jwt.StandardClaims + if err := json.Unmarshal([]byte(input), &claims); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if claims.EmailVerified.Valid { + t.Error("EmailVerified.Valid = true, want false") + } + if claims.PhoneNumberVerified.Valid { + t.Error("PhoneNumberVerified.Valid = true, want false") + } + }) + + t.Run("unmarshal claims with omitted verified fields", func(t *testing.T) { + input := `{ + "iss": "https://example.com", + "sub": "user123", + "exp": 9999999999, + "iat": 1000000000 + }` + var claims jwt.StandardClaims + if err := json.Unmarshal([]byte(input), &claims); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + // Omitted fields -> zero value: {false, false} + if claims.EmailVerified.Valid { + t.Error("EmailVerified.Valid = true, want false") + } + if claims.PhoneNumberVerified.Valid { + t.Error("PhoneNumberVerified.Valid = true, want false") + } + }) + + t.Run("round-trip claims", func(t *testing.T) { + orig := jwt.StandardClaims{ + TokenClaims: jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "user123", + Exp: 9999999999, + IAt: 1000000000, + }, + Email: "user@example.com", + EmailVerified: jwt.NullBool{Bool: true, Valid: true}, + PhoneNumber: "+1555000000", + PhoneNumberVerified: jwt.NullBool{Bool: false, Valid: true}, + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + var got jwt.StandardClaims + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if got.EmailVerified != orig.EmailVerified { + t.Errorf("EmailVerified = %+v, want %+v", got.EmailVerified, orig.EmailVerified) + } + if got.PhoneNumberVerified != orig.PhoneNumberVerified { + t.Errorf("PhoneNumberVerified = %+v, want %+v", got.PhoneNumberVerified, orig.PhoneNumberVerified) + } + }) +}