// Copyright 2026 AJ ONeal (https://therootcompany.com) // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. // // SPDX-License-Identifier: MPL-2.0 package jwt_test import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" "encoding/base64" "encoding/json" "errors" "strings" "testing" "time" "crypto" "github.com/therootcompany/golib/auth/jwt" ) func mustPK(t testing.TB, signer crypto.Signer, kid string) *jwt.PrivateKey { t.Helper() pk, err := jwt.FromPrivateKey(signer, kid) if err != nil { t.Fatal(err) } return pk } // AppClaims embeds TokenClaims and adds application-specific fields. // // Because TokenClaims is embedded, AppClaims satisfies Claims // for free via Go's method promotion - no interface to implement. type AppClaims struct { jwt.TokenClaims Email string `json:"email"` Roles []string `json:"roles"` } // validateAppClaims is a plain function - not a method satisfying an interface. // It demonstrates the Decode+Verify pattern: custom validation logic lives here, // calling Validator.Validate and adding app-specific checks. func validateAppClaims(c AppClaims, v *jwt.Validator, now time.Time) error { var errs []error if err := v.Validate(nil, &c, now); err != nil { errs = append(errs, err) } if c.Email == "" { errs = append(errs, errors.New("missing email claim")) } return errors.Join(errs...) } func goodClaims() AppClaims { now := time.Now() return AppClaims{ TokenClaims: jwt.TokenClaims{ Iss: "https://example.com", Sub: "user123", Aud: jwt.Listish{"myapp"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), AMR: []string{"pwd"}, JTI: "abc123", AzP: "myapp", Nonce: "nonce1", }, Email: "user@example.com", Roles: []string{"admin"}, } } // goodValidator configures the ID token validator with iss set to "https://example.com". 
// Iss checking is now the Validator's responsibility, not the Verifier's. func goodValidator() *jwt.Validator { return jwt.NewIDTokenValidator( []string{"https://example.com"}, []string{"myapp"}, []string{"myapp"}, 0, ) } func goodVerifier(pub jwt.PublicKey) *jwt.Verifier { v, err := jwt.NewVerifier([]jwt.PublicKey{pub}) if err != nil { panic(err) } return v } // TestRoundTrip is the primary happy path using ES256. // It demonstrates the full Verify / UnmarshalClaims / Validate flow. func TestRoundTrip(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) if err != nil { t.Fatal(err) } claims := goodClaims() jws, err := signer.Sign(&claims) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "ES256" { t.Fatalf("expected ES256, got %s", jws.GetHeader().Alg) } token, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode failed: %v", err) } iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } if jws2.GetHeader().Alg != "ES256" { t.Errorf("expected ES256 alg in jws, got %s", jws2.GetHeader().Alg) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("claim validation failed: %v", err) } // Direct field access - no type assertion needed. if decoded.Email != claims.Email { t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) } } // TestRoundTripRS256 exercises RSA PKCS#1 v1.5 / RS256. 
func TestRoundTripRS256(t *testing.T) { privKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) if err != nil { t.Fatal(err) } claims := goodClaims() jws, err := signer.Sign(&claims) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "RS256" { t.Fatalf("expected RS256, got %s", jws.GetHeader().Alg) } token, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode failed: %v", err) } iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("claim validation failed: %v", err) } } // TestRoundTripEdDSA exercises Ed25519 / EdDSA (RFC 8037). func TestRoundTripEdDSA(t *testing.T) { pubKeyBytes, privKey, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) if err != nil { t.Fatal(err) } claims := goodClaims() jws, err := signer.Sign(&claims) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "EdDSA" { t.Fatalf("expected EdDSA, got %s", jws.GetHeader().Alg) } token, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode failed: %v", err) } iss := goodVerifier(jwt.PublicKey{Key: pubKeyBytes, KID: "key-1"}) jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("claim validation failed: %v", err) } } // TestRoundTripES384 exercises ECDSA P-384 / ES384. 
func TestRoundTripES384(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) if err != nil { t.Fatal(err) } claims := goodClaims() jws, err := signer.Sign(&claims) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "ES384" { t.Fatalf("expected ES384, got %s", jws.GetHeader().Alg) } token, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode failed: %v", err) } iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("claim validation failed: %v", err) } } // TestRoundTripES512 exercises ECDSA P-521 / ES512. func TestRoundTripES512(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "key-1")}) if err != nil { t.Fatal(err) } claims := goodClaims() jws, err := signer.Sign(&claims) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "ES512" { t.Fatalf("expected ES512, got %s", jws.GetHeader().Alg) } token, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode failed: %v", err) } iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "key-1"}) jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("claim validation failed: %v", err) } } // TestDecodeVerifyFlow demonstrates the Decode + Verify + custom validation pattern. 
// The caller owns the full validation pipeline. func TestDecodeVerifyFlow(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) claims := goodClaims() token, _ := signer.SignToString(&claims) iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &privKey.PublicKey, KID: "k"}}) jws2, err := jwt.Decode(token) if err != nil { t.Fatalf("Decode failed: %v", err) } if err := iss.Verify(jws2); err != nil { t.Fatalf("Verify failed: %v", err) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("Validate failed: %v", err) } } // TestDecodeReturnsParsedOnSigFailure verifies that Decode returns a non-nil // *StandardJWS even when the token will later fail signature verification. // Callers can inspect the header (kid, alg) for routing before calling Verify. func TestDecodeReturnsParsedOnSigFailure(t *testing.T) { signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, signingKey, "k")}) claims := goodClaims() token, _ := signer.SignToString(&claims) // Verifier has wrong public key - sig verification will fail. iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &wrongKey.PublicKey, KID: "k"}}) // Decode always succeeds for well-formed tokens. result, err := jwt.Decode(token) if err != nil { t.Fatalf("Decode failed: %v", err) } if result == nil { t.Fatal("Decode should return non-nil StandardJWS") } if result.GetHeader().KID != "k" { t.Errorf("expected kid %q, got %q", "k", result.GetHeader().KID) } // Verify should fail with the wrong key. 
if err := iss.Verify(result); err == nil { t.Fatal("expected Verify to fail with wrong key") } } // TestCustomValidation demonstrates that Validator.Validate is called // explicitly and custom fields are validated alongside the standard ones. func TestCustomValidation(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) // Token with empty Email - our custom validator should reject it. signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) claims := goodClaims() claims.Email = "" token, _ := signer.SignToString(&claims) iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) jws2, err := jwt.Decode(token) if err != nil { t.Fatalf("Decode failed: %v", err) } if err := iss.Verify(jws2); err != nil { t.Fatalf("Verify failed unexpectedly: %v", err) } var decoded AppClaims _ = jws2.UnmarshalClaims(&decoded) err = validateAppClaims(decoded, goodValidator(), time.Now()) if err == nil { t.Fatal("expected validation to fail: email is empty") } if !strings.Contains(err.Error(), "missing email claim") { t.Fatalf("expected 'missing email claim' in error: %v", err) } } // TestNBFValidation confirms that a token with nbf in the future is rejected, // and that a token with nbf in the past (or absent) is accepted. func TestNBFValidation(t *testing.T) { now := time.Now() base := AppClaims{ TokenClaims: jwt.TokenClaims{ Iss: "https://example.com", Aud: jwt.Listish{"myapp"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), }, } v := &jwt.Validator{ Checks: jwt.CheckIss | jwt.CheckAud | jwt.CheckExp | jwt.CheckIAt | jwt.CheckNBf, Iss: []string{"https://example.com"}, Aud: []string{"myapp"}, } // No nbf: should pass. if err := v.Validate(nil, &base, now); err != nil { t.Fatalf("expected no error without nbf: %v", err) } // nbf in the past: should pass. 
pastNBF := base pastNBF.NBf = now.Add(-time.Hour).Unix() if err := v.Validate(nil, &pastNBF, now); err != nil { t.Fatalf("expected no error with past nbf: %v", err) } // nbf in the future: must be rejected. futureNBF := base futureNBF.NBf = now.Add(time.Hour).Unix() err := v.Validate(nil, &futureNBF, now) if err == nil { t.Fatal("expected error for future nbf") } if !errors.Is(err, jwt.ErrBeforeNBf) { t.Fatalf("expected ErrBeforeNBf, got: %v", err) } } // TestVerifyWithoutValidation confirms that Verify + UnmarshalClaims succeeds // independently of claim validation - the caller decides whether to validate. func TestVerifyWithoutValidation(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) c := goodClaims() token, _ := signer.SignToString(&c) iss, _ := jwt.NewVerifier([]jwt.PublicKey{{Key: &privKey.PublicKey, KID: "k"}}) jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } var claims AppClaims if err := jws2.UnmarshalClaims(&claims); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if claims.Email != c.Email { t.Errorf("claims not unmarshalled: email got %q, want %q", claims.Email, c.Email) } } // TestVerifierWrongKey confirms that a different key's public key is rejected. func TestVerifierWrongKey(t *testing.T) { signingKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, signingKey, "k")}) claims := goodClaims() token, _ := signer.SignToString(&claims) iss := goodVerifier(jwt.PublicKey{Key: &wrongKey.PublicKey, KID: "k"}) parsed, err := jwt.Decode(token) if err != nil { t.Fatal(err) } if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrSignatureInvalid) { t.Fatalf("expected ErrSignatureInvalid, got: %v", err) } } // TestVerifierUnknownKid confirms that an unknown kid is rejected. 
func TestVerifierUnknownKid(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "unknown-kid")}) claims := goodClaims() token, _ := signer.SignToString(&claims) iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "known-kid"}) parsed, err := jwt.Decode(token) if err != nil { t.Fatal(err) } if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrUnknownKID) { t.Fatalf("expected ErrUnknownKID, got: %v", err) } } // TestVerifierIssMismatch confirms that a token with a mismatched iss is caught // by the Validator, not the Verifier. Signature verification succeeds; the iss // mismatch appears as a soft validation error. func TestVerifierIssMismatch(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) claims := goodClaims() claims.Iss = "https://evil.example.com" token, _ := signer.SignToString(&claims) iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) // Decode+Verify succeeds - iss is not checked at the Verifier level. parsed, err := jwt.Decode(token) if err != nil { t.Fatal(err) } if err := iss.Verify(parsed); err != nil { t.Fatalf("Verify should succeed (no iss check): %v", err) } // VerifyJWT + Validate: signature passes but iss validation catches the mismatch. jws2, err := iss.VerifyJWT(token) if err != nil { t.Fatalf("unexpected hard error from VerifyJWT: %v", err) } var decoded AppClaims if err := jws2.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } err = goodValidator().Validate(nil, &decoded, time.Now()) if err == nil { t.Fatal("expected validation errors for iss mismatch") } if !errors.Is(err, jwt.ErrInvalidClaim) { t.Fatalf("expected ErrInvalidClaim for iss mismatch, got: %v", err) } } // TestVerifyTamperedAlg confirms that a tampered alg header ("none") is rejected. 
// The token is reconstructed with a replaced protected header; the original // ES256 signature is kept, making the signing input mismatch detectable. func TestVerifyTamperedAlg(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) claims := goodClaims() token, _ := signer.SignToString(&claims) iss := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) // Replace the protected header with one that has alg:"none". // The original ES256 signature stays - the signing input will mismatch. noneHeader := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","kid":"k","typ":"JWT"}`)) parts := strings.SplitN(token, ".", 3) tamperedToken := noneHeader + "." + parts[1] + "." + parts[2] parsed, err := jwt.Decode(tamperedToken) if err != nil { t.Fatal(err) } if err := iss.Verify(parsed); !errors.Is(err, jwt.ErrUnsupportedAlg) { t.Fatalf("expected ErrUnsupportedAlg for tampered alg, got: %v", err) } } // TestSignerRoundTrip verifies the Signer / Sign / Verifier / Verify / Validate flow. 
func TestSignerRoundTrip(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")}) if err != nil { t.Fatal(err) } claims := goodClaims() tokenStr, err := signer.SignToString(&claims) if err != nil { t.Fatal(err) } iss := signer.Verifier() jws, err := iss.VerifyJWT(tokenStr) if err != nil { t.Fatalf("VerifyJWT failed: %v", err) } var decoded AppClaims if err := jws.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims failed: %v", err) } if err := goodValidator().Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("claim validation failed: %v", err) } if decoded.Email != claims.Email { t.Errorf("email: got %s, want %s", decoded.Email, claims.Email) } } // TestSignerAutoKID verifies that KID is auto-computed from the key thumbprint // when PrivateKey.KID is empty. func TestSignerAutoKID(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "")}) if err != nil { t.Fatal(err) } keys := signer.Keys if len(keys) != 1 { t.Fatalf("expected 1 key, got %d", len(keys)) } if keys[0].KID == "" { t.Fatal("KID should be auto-computed from thumbprint") } // Token should verify with the auto-KID issuer. iss := signer.Verifier() claims := goodClaims() tokenStr, _ := signer.SignToString(&claims) parsed, err := jwt.Decode(tokenStr) if err != nil { t.Fatal(err) } if err := iss.Verify(parsed); err != nil { t.Fatalf("Verify failed: %v", err) } } // TestSignerRoundRobin verifies that signing round-robins across keys and that // all resulting tokens verify with the combined Verifier. 
func TestSignerRoundRobin(t *testing.T) { key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, err := jwt.NewSigner([]*jwt.PrivateKey{ mustPK(t, key1, "k1"), mustPK(t, key2, "k2"), }) if err != nil { t.Fatal(err) } iss := signer.Verifier() v := goodValidator() for i := range 4 { claims := goodClaims() tokenStr, err := signer.SignToString(&claims) if err != nil { t.Fatalf("Sign[%d] failed: %v", i, err) } jws, err := iss.VerifyJWT(tokenStr) if err != nil { t.Fatalf("VerifyJWT[%d] failed: %v", i, err) } var decoded AppClaims if err := jws.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims[%d] failed: %v", i, err) } if err := v.Validate(nil, &decoded, time.Now()); err != nil { t.Fatalf("Validate[%d] failed: %v", i, err) } } } // TestSignJWTSelectsByKID verifies that when the header already has a KID, // SignJWT uses that specific key instead of round-robin. func TestSignJWTSelectsByKID(t *testing.T) { key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key2, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) signer, err := jwt.NewSigner([]*jwt.PrivateKey{ mustPK(t, key1, "ec256"), mustPK(t, key2, "ec384"), }) if err != nil { t.Fatal(err) } // Request signing with key2 specifically by setting KID in the header. claims := &jwt.TokenClaims{ Iss: "https://example.com", Sub: "user-1", Aud: jwt.Listish{"app"}, Exp: time.Now().Add(time.Hour).Unix(), IAt: time.Now().Unix(), } jws, err := jwt.New(claims) if err != nil { t.Fatal(err) } jws.SetTyp("JWT") // Pre-set the KID to select key2. hdr := jws.GetHeader() hdr.KID = "ec384" if err := jws.SetHeader(&hdr); err != nil { t.Fatal(err) } if err := signer.SignJWT(jws); err != nil { t.Fatal(err) } // Should have used ES384, not ES256. 
if got := jws.GetHeader().Alg; got != "ES384" { t.Fatalf("alg: got %s, want ES384 (should have selected key2)", got) } if got := jws.GetHeader().KID; got != "ec384" { t.Fatalf("kid: got %s, want ec384", got) } // Verify round-trip. verifier := signer.Verifier() tokenStr, err := jwt.Encode(jws) if err != nil { t.Fatal(err) } parsed, err := jwt.Decode(tokenStr) if err != nil { t.Fatal(err) } if err := verifier.Verify(parsed); err != nil { t.Fatalf("Verify: %v", err) } } // TestSignJWTUnknownKID verifies that SignJWT returns ErrUnknownKID when the // header requests a KID that the signer doesn't have. func TestSignJWTUnknownKID(t *testing.T) { key1, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, key1, "k1")}) claims := &jwt.TokenClaims{ Iss: "https://example.com", Sub: "user-1", Aud: jwt.Listish{"app"}, Exp: time.Now().Add(time.Hour).Unix(), IAt: time.Now().Unix(), } jws, _ := jwt.New(claims) hdr := jws.GetHeader() hdr.KID = "nonexistent" _ = jws.SetHeader(&hdr) err := signer.SignJWT(jws) if !errors.Is(err, jwt.ErrUnknownKID) { t.Fatalf("expected ErrUnknownKID, got: %v", err) } } // TestJWKsRoundTrip verifies JWKS serialization and round-trip parsing. func TestJWKsRoundTrip(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")}) if err != nil { t.Fatal(err) } jwksBytes, err := json.Marshal(signer) if err != nil { t.Fatal(err) } // Round-trip: parse the JWKS JSON and verify it produces a working Verifier. 
var jwks jwt.WellKnownJWKs if err := json.Unmarshal(jwksBytes, &jwks); err != nil { t.Fatal(err) } keys := jwks.Keys if len(keys) != 1 { t.Fatalf("expected 1 key, got %d", len(keys)) } if keys[0].KID != "k1" { t.Errorf("expected kid 'k1', got %q", keys[0].KID) } iss2, _ := jwt.NewVerifier(keys) claims := goodClaims() tokenStr, _ := signer.SignToString(&claims) parsed, err := jwt.Decode(tokenStr) if err != nil { t.Fatal(err) } if err := iss2.Verify(parsed); err != nil { t.Fatalf("Verify on round-tripped JWKS failed: %v", err) } } // TestKeyType verifies that KeyType returns the correct JWK kty string for each key type. func TestKeyType(t *testing.T) { ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) edPub, _, _ := ed25519.GenerateKey(rand.Reader) tests := []struct { name string key jwt.PublicKey wantKty string }{ {"EC P-256", jwt.PublicKey{Key: &ecKey.PublicKey}, "EC"}, {"RSA 2048", jwt.PublicKey{Key: &rsaKey.PublicKey}, "RSA"}, {"Ed25519", jwt.PublicKey{Key: edPub}, "OKP"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := tt.key.KeyType(); got != tt.wantKty { t.Errorf("KeyType() = %q, want %q", got, tt.wantKty) } }) } } // TestPublicKeyOps verifies that PrivateKey.PublicKey() translates key_ops to their // public-key counterparts ("sign"=>"verify", "decrypt"=>"encrypt", "unwrapKey"=>"wrapKey"). 
func TestPublicKeyOps(t *testing.T) { tests := []struct { name string privateOps []string wantOps []string }{ {"sign=>verify", []string{"sign"}, []string{"verify"}}, {"decrypt=>encrypt", []string{"decrypt"}, []string{"encrypt"}}, {"unwrapKey=>wrapKey", []string{"unwrapKey"}, []string{"wrapKey"}}, {"multiple", []string{"sign", "decrypt"}, []string{"verify", "encrypt"}}, {"public op passthrough", []string{"verify"}, []string{"verify"}}, {"no public equivalent dropped", []string{"deriveKey"}, nil}, {"empty", nil, nil}, } base, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pk := *base pk.KeyOps = tt.privateOps pub, err := pk.PublicKey() if err != nil { t.Fatal(err) } if len(pub.KeyOps) != len(tt.wantOps) { t.Fatalf("KeyOps = %v, want %v", pub.KeyOps, tt.wantOps) } for i, op := range pub.KeyOps { if op != tt.wantOps[i] { t.Errorf("KeyOps[%d] = %q, want %q", i, op, tt.wantOps[i]) } } }) } } // TestDecodePublicJWKJSON verifies JWKS JSON parsing with real base64url-encoded // key material from RFC 7517 / OIDC examples. 
func TestDecodePublicJWKJSON(t *testing.T) { jwksJSON := []byte(`{"keys":[ {"kty":"EC","crv":"P-256", "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", "kid":"ec-256","use":"sig"}, {"kty":"RSA", "n":"0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", "e":"AQAB","kid":"rsa-2048","use":"sig"} ]}`) var jwks jwt.WellKnownJWKs if err := json.Unmarshal(jwksJSON, &jwks); err != nil { t.Fatal(err) } keys := jwks.Keys if len(keys) != 2 { t.Fatalf("expected 2 keys, got %d", len(keys)) } var ecCount, rsaCount int for _, k := range keys { switch k.KeyType() { case "EC": ecCount++ if k.KID != "ec-256" { t.Errorf("unexpected EC kid: %s", k.KID) } case "RSA": rsaCount++ if k.KID != "rsa-2048" { t.Errorf("unexpected RSA kid: %s", k.KID) } } } if ecCount != 1 { t.Errorf("expected 1 EC key, got %d", ecCount) } if rsaCount != 1 { t.Errorf("expected 1 RSA key, got %d", rsaCount) } } // TestThumbprint verifies that Thumbprint returns a non-empty base64url string // for EC, RSA, and Ed25519 keys, and that two equal keys produce the same thumbprint. 
func TestThumbprint(t *testing.T) { ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) edPub, _, _ := ed25519.GenerateKey(rand.Reader) tests := []struct { name string pub jwt.PublicKey }{ {"EC P-256", jwt.PublicKey{Key: &ecKey.PublicKey}}, {"RSA 2048", jwt.PublicKey{Key: &rsaKey.PublicKey}}, {"Ed25519", jwt.PublicKey{Key: edPub}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { thumb, err := tt.pub.Thumbprint() if err != nil { t.Fatalf("Thumbprint() error: %v", err) } if thumb == "" { t.Fatal("Thumbprint() returned empty string") } // Must be valid base64url (no padding, no +/) if strings.Contains(thumb, "+") || strings.Contains(thumb, "/") || strings.Contains(thumb, "=") { t.Errorf("Thumbprint() contains non-base64url characters: %s", thumb) } // Same key, same thumbprint thumb2, _ := tt.pub.Thumbprint() if thumb != thumb2 { t.Errorf("Thumbprint() not deterministic: %s != %s", thumb, thumb2) } }) } } // TestNoKidAutoThumbprint verifies that a JWKS key without a "kid" field gets // its KID auto-populated from the RFC 7638 thumbprint. func TestNoKidAutoThumbprint(t *testing.T) { // EC key with no "kid" field in the JWKS jwksJSON := []byte(`{"keys":[ {"kty":"EC","crv":"P-256", "x":"MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4", "y":"4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM", "use":"sig"} ]}`) var jwks jwt.WellKnownJWKs if err := json.Unmarshal(jwksJSON, &jwks); err != nil { t.Fatal(err) } keys := jwks.Keys if len(keys) != 1 { t.Fatalf("expected 1 key, got %d", len(keys)) } if keys[0].KID == "" { t.Fatal("KID should be auto-populated from Thumbprint when absent in JWKS") } // The auto-KID should be a valid base64url string. kid := keys[0].KID if strings.Contains(kid, "+") || strings.Contains(kid, "/") || strings.Contains(kid, "=") { t.Errorf("auto-KID contains non-base64url characters: %s", kid) } // Round-trip: compute Thumbprint directly and compare. 
thumb, err := keys[0].Thumbprint() if err != nil { t.Fatalf("Thumbprint() error: %v", err) } if kid != thumb { t.Errorf("auto-KID %q != direct Thumbprint %q", kid, thumb) } } // TestNewPrivateKey verifies that jwt.NewPrivateKey generates an Ed25519 key // with a non-empty KID auto-derived from the thumbprint, and that the key // works end-to-end for signing and verification. func TestNewPrivateKey(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatalf("NewPrivateKey() error: %v", err) } if pk.KID == "" { t.Fatal("NewPrivateKey() returned empty KID") } // KID must be base64url (no +, /, or =). if strings.Contains(pk.KID, "+") || strings.Contains(pk.KID, "/") || strings.Contains(pk.KID, "=") { t.Errorf("KID contains non-base64url characters: %s", pk.KID) } // Two calls must produce different keys but always produce valid base64url KIDs. pk2, _ := jwt.NewPrivateKey() if pk.KID == pk2.KID { t.Error("NewPrivateKey() produced identical KIDs for two different keys") } // Full sign+verify round-trip with the generated key. signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatalf("NewSigner() error: %v", err) } claims := goodClaims() tokenStr, err := signer.SignToString(&claims) if err != nil { t.Fatalf("SignToString() error: %v", err) } iss := signer.Verifier() jws, err := iss.VerifyJWT(tokenStr) if err != nil { t.Fatalf("VerifyJWT() error: %v", err) } var decoded AppClaims if err := jws.UnmarshalClaims(&decoded); err != nil { t.Fatalf("UnmarshalClaims() error: %v", err) } if decoded.Sub != claims.Sub { t.Errorf("sub: got %q, want %q", decoded.Sub, claims.Sub) } } // --- DecodeRaw + UnmarshalHeader tests --- func TestDecodeRaw(t *testing.T) { // Sign a real token to get a valid compact string. 
pk, err := jwt.NewPrivateKey()
	if err != nil {
		t.Fatalf("NewPrivateKey() error: %v", err)
	}
	signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
	if err != nil {
		t.Fatalf("NewSigner() error: %v", err)
	}

	claims := goodClaims()
	tokenStr, err := signer.SignToString(&claims)
	if err != nil {
		t.Fatalf("SignToString() error: %v", err)
	}

	raw, err := jwt.DecodeRaw(tokenStr)
	if err != nil {
		t.Fatalf("DecodeRaw() error: %v", err)
	}

	// protected and payload should be non-empty base64url segments.
	if len(raw.GetProtected()) == 0 {
		t.Error("GetProtected() is empty")
	}
	if len(raw.GetPayload()) == 0 {
		t.Error("GetPayload() is empty")
	}
	if len(raw.GetSignature()) == 0 {
		t.Error("GetSignature() is empty")
	}
}

// TestDecodeRawErrors checks that DecodeRaw rejects malformed compact tokens
// and that each failure matches the expected sentinel via errors.Is.
func TestDecodeRawErrors(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		sentinel error
	}{
		{"empty string", "", jwt.ErrMalformedToken},
		{"one segment", "abc", jwt.ErrMalformedToken},
		{"two segments", "abc.def", jwt.ErrMalformedToken},
		{"four segments", "a.b.c.d", jwt.ErrMalformedToken},
		{"bad signature base64", "eyJhbGciOiJFZERTQSJ9.eyJpc3MiOiJ4In0.!!!bad!!!", jwt.ErrSignatureInvalid},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := jwt.DecodeRaw(tc.input)
			if err == nil {
				t.Fatal("expected error, got nil")
			}
			if !errors.Is(err, tc.sentinel) {
				t.Errorf("expected %v, got: %v", tc.sentinel, err)
			}
		})
	}
}

// TestDecodeRawSegmentCount checks that the DecodeRaw error message reports
// how many segments were actually found.
func TestDecodeRawSegmentCount(t *testing.T) {
	_, err := jwt.DecodeRaw("")
	if err == nil {
		t.Fatal("expected error")
	}
	// Empty string normalizes to 0 segments.
	if !strings.Contains(err.Error(), "got 0") {
		t.Errorf("expected segment count 0 in error, got: %v", err)
	}

	_, err = jwt.DecodeRaw("a.b")
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "got 2") {
		t.Errorf("expected segment count 2 in error, got: %v", err)
	}
}

// TestUnmarshalHeader checks that the header of a freshly signed token can be
// recovered through DecodeRaw + UnmarshalHeader.
func TestUnmarshalHeader(t *testing.T) {
	// Sign a token and use DecodeRaw + UnmarshalHeader to recover the header.
	pk, err := jwt.NewPrivateKey()
	if err != nil {
		t.Fatalf("NewPrivateKey() error: %v", err)
	}
	signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
	if err != nil {
		t.Fatalf("NewSigner() error: %v", err)
	}

	claims := goodClaims()
	tokenStr, err := signer.SignToString(&claims)
	if err != nil {
		t.Fatalf("SignToString() error: %v", err)
	}

	raw, err := jwt.DecodeRaw(tokenStr)
	if err != nil {
		t.Fatalf("DecodeRaw() error: %v", err)
	}

	var h jwt.RFCHeader
	if err := raw.UnmarshalHeader(&h); err != nil {
		t.Fatalf("UnmarshalHeader() error: %v", err)
	}
	if h.Alg != "EdDSA" {
		t.Errorf("alg: got %q, want %q", h.Alg, "EdDSA")
	}
	if h.KID != pk.KID {
		t.Errorf("kid: got %q, want %q", h.KID, pk.KID)
	}
	if h.Typ != "JWT" {
		t.Errorf("typ: got %q, want %q", h.Typ, "JWT")
	}
}

// TestUnmarshalHeaderCustomFields checks that UnmarshalHeader fills a
// caller-defined header struct, including fields beyond RFCHeader.
func TestUnmarshalHeaderCustomFields(t *testing.T) {
	// Build a token whose header has a custom "nonce" field by constructing
	// the compact string manually: base64(header).base64(payload).base64(sig).
	type CustomHeader struct {
		jwt.RFCHeader
		Nonce string `json:"nonce"`
	}

	hdr := CustomHeader{
		RFCHeader: jwt.RFCHeader{Alg: "EdDSA", KID: "test-key", Typ: "dpop+jwt"},
		Nonce:     "server-nonce-42",
	}
	// A marshal failure here would otherwise surface later as a confusing
	// decode error, so fail fast.
	hdrJSON, err := json.Marshal(hdr)
	if err != nil {
		t.Fatalf("Marshal header: %v", err)
	}
	payJSON := []byte(`{"sub":"user"}`)
	fakeSig := []byte{0xDE, 0xAD}

	compact := base64.RawURLEncoding.EncodeToString(hdrJSON) + "." +
		base64.RawURLEncoding.EncodeToString(payJSON) + "." +
		base64.RawURLEncoding.EncodeToString(fakeSig)

	raw, err := jwt.DecodeRaw(compact)
	if err != nil {
		t.Fatalf("DecodeRaw() error: %v", err)
	}

	var got CustomHeader
	if err := raw.UnmarshalHeader(&got); err != nil {
		t.Fatalf("UnmarshalHeader() error: %v", err)
	}
	if got.Nonce != "server-nonce-42" {
		t.Errorf("nonce: got %q, want %q", got.Nonce, "server-nonce-42")
	}
	if got.Alg != "EdDSA" {
		t.Errorf("alg: got %q, want %q", got.Alg, "EdDSA")
	}
	if got.Typ != "dpop+jwt" {
		t.Errorf("typ: got %q, want %q", got.Typ, "dpop+jwt")
	}
}

// TestUnmarshalHeaderViaJWS checks that UnmarshalHeader is also callable on
// the value returned by Decode.
func TestUnmarshalHeaderViaJWS(t *testing.T) {
	// Verify that UnmarshalHeader is promoted from RawJWT to *JWT.
	pk, err := jwt.NewPrivateKey()
	if err != nil {
		t.Fatalf("NewPrivateKey() error: %v", err)
	}
	signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
	if err != nil {
		t.Fatalf("NewSigner() error: %v", err)
	}

	claims := goodClaims()
	tokenStr, err := signer.SignToString(&claims)
	if err != nil {
		t.Fatalf("SignToString() error: %v", err)
	}

	// Use Decode (not DecodeRaw) - UnmarshalHeader should still work via promotion.
	jws, err := jwt.Decode(tokenStr)
	if err != nil {
		t.Fatalf("Decode() error: %v", err)
	}

	var h jwt.RFCHeader
	if err := jws.UnmarshalHeader(&h); err != nil {
		t.Fatalf("jws.UnmarshalHeader() error: %v", err)
	}
	if h.Alg != "EdDSA" {
		t.Errorf("alg: got %q, want %q", h.Alg, "EdDSA")
	}
}

// --- SpaceDelimited tests ---

// TestSpaceDelimitedMarshalJSON checks that SpaceDelimited marshals to a
// single JSON string with entries joined by spaces.
func TestSpaceDelimitedMarshalJSON(t *testing.T) {
	tests := []struct {
		name string
		in   jwt.SpaceDelimited
		want string
	}{
		{"multiple", jwt.SpaceDelimited{"openid", "profile", "email"}, `"openid profile email"`},
		{"single", jwt.SpaceDelimited{"openid"}, `"openid"`},
		{"empty", jwt.SpaceDelimited{}, `""`},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := json.Marshal(tt.in)
			if err != nil {
				t.Fatalf("Marshal error: %v", err)
			}
			if string(got) != tt.want {
				t.Errorf("got %s, want %s", got, tt.want)
			}
		})
	}
}

// TestSpaceDelimitedUnmarshalJSON checks that a JSON string is split on
// whitespace into the expected SpaceDelimited entries.
func TestSpaceDelimitedUnmarshalJSON(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want jwt.SpaceDelimited
	}{
		{"multiple", `"openid profile email"`, jwt.SpaceDelimited{"openid", "profile", "email"}},
		{"single", `"openid"`, jwt.SpaceDelimited{"openid"}},
		{"empty", `""`, jwt.SpaceDelimited{}},
		{"extra whitespace", `"openid profile\temail"`, jwt.SpaceDelimited{"openid", "profile", "email"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var got jwt.SpaceDelimited
			if err := json.Unmarshal([]byte(tt.in), &got); err != nil {
				t.Fatalf("Unmarshal error: %v", err)
			}
			// Compare by length first so the element loop cannot index past
			// the end of tt.want; nil and empty both have length 0.
			if len(got) != len(tt.want) {
				t.Fatalf("got %v (len %d), want %v (len %d)", got, len(got), tt.want, len(tt.want))
			}
			for i := range got {
				if got[i] != tt.want[i] {
					t.Errorf("got[%d] = %q, want %q", i, got[i], tt.want[i])
				}
			}
		})
	}
}

// TestSpaceDelimitedRoundTrip checks that a SpaceDelimited struct field
// survives a Marshal/Unmarshal cycle and uses the string wire format.
func TestSpaceDelimitedRoundTrip(t *testing.T) {
	type claims struct {
		Scope jwt.SpaceDelimited `json:"scope,omitempty"`
	}

	orig := claims{Scope: jwt.SpaceDelimited{"openid", "profile"}}
	data, err := json.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}

	// Verify wire format is a space-separated string.
	if !strings.Contains(string(data), `"openid profile"`) {
		t.Fatalf("expected space-separated scope in JSON, got %s", data)
	}

	var decoded claims
	if err := json.Unmarshal(data, &decoded); err != nil {
		t.Fatal(err)
	}
	if len(decoded.Scope) != 2 || decoded.Scope[0] != "openid" || decoded.Scope[1] != "profile" {
		t.Errorf("round-trip failed: got %v", decoded.Scope)
	}
}

// --- SetTyp / NewAccessToken tests ---

// TestSetTyp checks the default typ of a new JWT and that SetTyp replaces it.
func TestSetTyp(t *testing.T) {
	claims := &jwt.TokenClaims{
		Iss: "https://example.com",
		Sub: "user1",
		Aud: jwt.Listish{"api"},
		Exp: time.Now().Add(time.Hour).Unix(),
		IAt: time.Now().Unix(),
	}

	jws, err := jwt.New(claims)
	if err != nil {
		t.Fatal(err)
	}

	// Default typ is "JWT".
	if got := jws.GetHeader().Typ; got != "JWT" {
		t.Fatalf("default typ: got %q, want %q", got, "JWT")
	}

	jws.SetTyp(jwt.AccessTokenTyp)
	if got := jws.GetHeader().Typ; got != jwt.AccessTokenTyp {
		t.Fatalf("after SetTyp: got %q, want %q", got, jwt.AccessTokenTyp)
	}
}

// TestSetTypSurvivesSigning checks that the access-token typ set by
// NewAccessToken is still present after SignJWT and after an Encode/Decode
// cycle, and that the claims round-trip alongside it.
func TestSetTypSurvivesSigning(t *testing.T) {
	_, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "")})
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now()
	claims := &jwt.TokenClaims{
		Iss:      "https://auth.example.com",
		Sub:      "user1",
		Aud:      jwt.Listish{"https://api.example.com"},
		Exp:      now.Add(time.Hour).Unix(),
		IAt:      now.Unix(),
		JTI:      "tok-001",
		ClientID: "webapp",
		Scope:    jwt.SpaceDelimited{"openid", "profile"},
	}

	jws, err := jwt.NewAccessToken(claims)
	if err != nil {
		t.Fatal(err)
	}

	if err := signer.SignJWT(jws); err != nil {
		t.Fatal(err)
	}

	// Verify typ survived signing.
	if got := jws.GetHeader().Typ; got != jwt.AccessTokenTyp {
		t.Fatalf("typ after signing: got %q, want %q", got, jwt.AccessTokenTyp)
	}

	// Decode the token and verify typ is in the wire format.
	token, err := jwt.Encode(jws)
	if err != nil {
		t.Fatalf("Encode failed: %v", err)
	}
	decoded, err := jwt.Decode(token)
	if err != nil {
		t.Fatal(err)
	}
	if got := decoded.GetHeader().Typ; got != jwt.AccessTokenTyp {
		t.Fatalf("typ after decode: got %q, want %q", got, jwt.AccessTokenTyp)
	}

	// Verify claims round-trip.
	var rt jwt.TokenClaims
	if err := decoded.UnmarshalClaims(&rt); err != nil {
		t.Fatal(err)
	}
	if rt.ClientID != "webapp" {
		t.Errorf("client_id: got %q, want %q", rt.ClientID, "webapp")
	}
	// Check both entries so the assertion matches the "want" in its message.
	if len(rt.Scope) != 2 || rt.Scope[0] != "openid" || rt.Scope[1] != "profile" {
		t.Errorf("scope: got %v, want [openid profile]", rt.Scope)
	}
}

// --- Access token Validator tests ---

// goodAccessTokenClaims returns a fresh set of claims that the access-token
// validator tests below treat as valid; each test mutates its own copy.
func goodAccessTokenClaims() *jwt.TokenClaims {
	now := time.Now()
	return &jwt.TokenClaims{
		Iss:      "https://auth.example.com",
		Sub:      "user1",
		Aud:      jwt.Listish{"https://api.example.com"},
		Exp:      now.Add(time.Hour).Unix(),
		IAt:      now.Unix(),
		JTI:      "tok-001",
		ClientID: "webapp",
		Scope:    jwt.SpaceDelimited{"openid", "profile"},
		AMR:      []string{"pwd"},
	}
}

// TestAccessTokenValidatorHappyPath checks that unmodified good claims pass.
func TestAccessTokenValidatorHappyPath(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	claims := goodAccessTokenClaims()
	if err := v.Validate(nil, claims, time.Now()); err != nil {
		t.Fatalf("valid access token rejected: %v", err)
	}
}

// TestAccessTokenValidatorRequiresJTI checks that an empty jti is reported
// as ErrMissingClaim.
func TestAccessTokenValidatorRequiresJTI(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	claims := goodAccessTokenClaims()
	claims.JTI = ""
	err := v.Validate(nil, claims, time.Now())
	if !errors.Is(err, jwt.ErrMissingClaim) {
		t.Fatalf("expected ErrMissingClaim for missing jti, got: %v", err)
	}
}

// TestAccessTokenValidatorRequiresClientID checks that an empty client_id is
// reported as ErrMissingClaim.
func TestAccessTokenValidatorRequiresClientID(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	claims := goodAccessTokenClaims()
	claims.ClientID = ""
	err := v.Validate(nil, claims, time.Now())
	if !errors.Is(err, jwt.ErrMissingClaim) {
		t.Fatalf("expected ErrMissingClaim for missing client_id, got: %v", err)
	}
}

// TestAccessTokenValidatorDisableClientID checks that clearing CheckClientID
// lets an empty client_id through.
func TestAccessTokenValidatorDisableClientID(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	v.Checks &^= jwt.CheckClientID
	claims := goodAccessTokenClaims()
	claims.ClientID = ""
	if err := v.Validate(nil, claims, time.Now()); err != nil {
		t.Fatalf("disabling CheckClientID should accept empty client_id: %v", err)
	}
}

// TestAccessTokenValidatorExpiredToken checks that an exp in the past is
// reported as ErrAfterExp.
func TestAccessTokenValidatorExpiredToken(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	claims := goodAccessTokenClaims()
	claims.Exp = time.Now().Add(-time.Hour).Unix()
	err := v.Validate(nil, claims, time.Now())
	if !errors.Is(err, jwt.ErrAfterExp) {
		t.Fatalf("expected ErrAfterExp, got: %v", err)
	}
}

// TestAccessTokenValidatorRequiredScopes checks that a token lacking one of
// the RequiredScopes is reported as ErrInsufficientScope.
func TestAccessTokenValidatorRequiredScopes(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	v.RequiredScopes = []string{"openid", "admin"}
	claims := goodAccessTokenClaims()
	// claims has ["openid", "profile"] - missing "admin"
	err := v.Validate(nil, claims, time.Now())
	if !errors.Is(err, jwt.ErrInsufficientScope) {
		t.Fatalf("expected ErrInsufficientScope for missing scope, got: %v", err)
	}
}

// TestAccessTokenValidatorExpectScope checks that with CheckScope enabled a
// nil scope is reported as ErrMissingClaim.
func TestAccessTokenValidatorExpectScope(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	v.Checks |= jwt.CheckScope // enable scope presence check
	claims := goodAccessTokenClaims()
	claims.Scope = nil
	err := v.Validate(nil, claims, time.Now())
	if !errors.Is(err, jwt.ErrMissingClaim) {
		t.Fatalf("expected ErrMissingClaim for empty scope, got: %v", err)
	}
}

// TestAccessTokenValidatorDisableJTI checks that clearing CheckJTI lets an
// empty jti through.
func TestAccessTokenValidatorDisableJTI(t *testing.T) {
	v := jwt.NewAccessTokenValidator(
		[]string{"https://auth.example.com"},
		[]string{"https://api.example.com"},
		0,
	)
	v.Checks &^= jwt.CheckJTI
	claims := goodAccessTokenClaims()
	claims.JTI = ""
	if err := v.Validate(nil, claims, time.Now()); err != nil {
		t.Fatalf("disabling CheckJTI should accept empty jti: %v", err)
	}
}

// --- Encode validation tests ---

// stubJWT is a minimal VerifiableJWT for testing Encode validation.
type stubJWT struct { protected []byte payload []byte signature []byte header jwt.RFCHeader } func (s *stubJWT) GetProtected() []byte { return s.protected } func (s *stubJWT) GetPayload() []byte { return s.payload } func (s *stubJWT) GetSignature() []byte { return s.signature } func (s *stubJWT) GetHeader() jwt.RFCHeader { return s.header } // TestEncodeRejectsEmptyAlg verifies that Encode returns an error // when the alg header field is empty (unsigned token). func TestEncodeRejectsEmptyAlg(t *testing.T) { // Zero-value stub: no alg set. jws := &stubJWT{} _, err := jwt.Encode(jws) if err == nil { t.Fatal("expected error for empty alg") } if !errors.Is(err, jwt.ErrInvalidHeader) { t.Fatalf("expected ErrInvalidHeader, got: %v", err) } if !strings.Contains(err.Error(), "alg is empty") { t.Fatalf("unexpected message: %v", err) } // Explicit header with typ but no alg. jws2 := &stubJWT{ protected: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT"}`))), payload: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"user1"}`))), signature: []byte{0x01, 0x02, 0x03}, header: jwt.RFCHeader{Typ: "JWT"}, } _, err = jwt.Encode(jws2) if err == nil { t.Fatal("expected error for empty alg with typ-only header") } if !errors.Is(err, jwt.ErrInvalidHeader) { t.Fatalf("expected ErrInvalidHeader, got: %v", err) } } // TestEncodeSucceedsAfterSigning verifies the happy path: a signed JWT // encodes without error. 
func TestEncodeSucceedsAfterSigning(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k1")}) if err != nil { t.Fatal(err) } claims := goodClaims() jws, err := signer.Sign(&claims) if err != nil { t.Fatal(err) } token, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode failed on signed JWT: %v", err) } parts := strings.Split(token, ".") if len(parts) != 3 { t.Fatalf("expected 3 segments, got %d", len(parts)) } for i, p := range parts { if p == "" { t.Fatalf("segment %d is empty", i) } } } // --- Full pipeline round-trip tests --- // TestRoundTrip_IDToken exercises the full ID token pipeline: // NewPrivateKey -> NewSigner -> SignToString -> Decode -> Verify -> UnmarshalClaims -> Validate. func TestRoundTrip_IDToken(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://idp.example.com", Sub: "user-42", Aud: jwt.Listish{"my-client"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), AzP: "my-client", Nonce: "n-0S6_WzA2Mj", AMR: []string{"pwd", "otp"}, JTI: "id-tok-001", } tokenStr, err := signer.SignToString(claims) if err != nil { t.Fatal(err) } // Decode + Verify decoded, err := jwt.Decode(tokenStr) if err != nil { t.Fatalf("Decode: %v", err) } verifier := signer.Verifier() if err := verifier.Verify(decoded); err != nil { t.Fatalf("Verify: %v", err) } // UnmarshalClaims var got jwt.TokenClaims if err := decoded.UnmarshalClaims(&got); err != nil { t.Fatalf("UnmarshalClaims: %v", err) } // Validate with NewIDTokenValidator v := jwt.NewIDTokenValidator( []string{"https://idp.example.com"}, []string{"my-client"}, []string{"my-client"}, 0, ) if err := v.Validate(nil, &got, time.Now()); err != nil { t.Fatalf("Validate: %v", err) } // Spot-check round-tripped 
fields. if got.Sub != "user-42" { t.Errorf("sub: got %q, want %q", got.Sub, "user-42") } if got.Nonce != "n-0S6_WzA2Mj" { t.Errorf("nonce: got %q, want %q", got.Nonce, "n-0S6_WzA2Mj") } if len(got.AMR) != 2 || got.AMR[0] != "pwd" || got.AMR[1] != "otp" { t.Errorf("amr: got %v, want [pwd otp]", got.AMR) } } // TestRoundTrip_AccessToken exercises the full access token pipeline: // NewAccessToken -> SignJWT -> Encode -> Decode -> Verify -> UnmarshalClaims -> Validate with RequiredScopes. func TestRoundTrip_AccessToken(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://auth.example.com", Sub: "svc-account", Aud: jwt.Listish{"https://api.example.com"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), JTI: "at-001", ClientID: "backend-svc", Scope: jwt.SpaceDelimited{"read", "write", "admin"}, } jws, err := jwt.NewAccessToken(claims) if err != nil { t.Fatal(err) } if err := signer.SignJWT(jws); err != nil { t.Fatal(err) } tokenStr, err := jwt.Encode(jws) if err != nil { t.Fatalf("Encode: %v", err) } // Decode + Verify decoded, err := jwt.Decode(tokenStr) if err != nil { t.Fatalf("Decode: %v", err) } // Verify typ survived the round-trip. if got := decoded.GetHeader().Typ; got != jwt.AccessTokenTyp { t.Errorf("typ: got %q, want %q", got, jwt.AccessTokenTyp) } verifier := signer.Verifier() if err := verifier.Verify(decoded); err != nil { t.Fatalf("Verify: %v", err) } var got jwt.TokenClaims if err := decoded.UnmarshalClaims(&got); err != nil { t.Fatalf("UnmarshalClaims: %v", err) } // Validate with RequiredScopes. v := jwt.NewAccessTokenValidator( []string{"https://auth.example.com"}, []string{"https://api.example.com"}, 0, ) v.RequiredScopes = []string{"read", "write"} if err := v.Validate(nil, &got, time.Now()); err != nil { t.Fatalf("Validate: %v", err) } // Spot-check scope round-trip. 
if len(got.Scope) != 3 || got.Scope[0] != "read" || got.Scope[2] != "admin" { t.Errorf("scope: got %v, want [read write admin]", got.Scope) } if got.ClientID != "backend-svc" { t.Errorf("client_id: got %q, want %q", got.ClientID, "backend-svc") } } // TestRoundTrip_StandardClaims verifies that StandardClaims with NullBool fields // survive the sign-decode round-trip. func TestRoundTrip_StandardClaims(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } now := time.Now() claims := &jwt.StandardClaims{ TokenClaims: jwt.TokenClaims{ Iss: "https://idp.example.com", Sub: "user-99", Aud: jwt.Listish{"app"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), AzP: "app", }, Name: "Jane Doe", Email: "jane@example.com", EmailVerified: jwt.NullBool{Bool: true, Valid: true}, PhoneNumber: "+1-555-0100", PhoneNumberVerified: jwt.NullBool{Bool: false, Valid: true}, Locale: "en-US", } tokenStr, err := signer.SignToString(claims) if err != nil { t.Fatal(err) } // Decode + Verify verifier := signer.Verifier() jws, err := verifier.VerifyJWT(tokenStr) if err != nil { t.Fatalf("VerifyJWT: %v", err) } var got jwt.StandardClaims if err := jws.UnmarshalClaims(&got); err != nil { t.Fatalf("UnmarshalClaims: %v", err) } // Verify NullBool fields survived the round-trip. if !got.EmailVerified.Valid || !got.EmailVerified.Bool { t.Errorf("email_verified: got %+v, want {Bool:true Valid:true}", got.EmailVerified) } if !got.PhoneNumberVerified.Valid || got.PhoneNumberVerified.Bool { t.Errorf("phone_number_verified: got %+v, want {Bool:false Valid:true}", got.PhoneNumberVerified) } // Verify other profile fields. 
if got.Name != "Jane Doe" { t.Errorf("name: got %q, want %q", got.Name, "Jane Doe") } if got.Email != "jane@example.com" { t.Errorf("email: got %q, want %q", got.Email, "jane@example.com") } if got.Locale != "en-US" { t.Errorf("locale: got %q, want %q", got.Locale, "en-US") } // Verify that an unset NullBool (not in JSON) comes back as zero value. // StandardClaims has no field we explicitly omitted that is a NullBool // other than the two we set, so we create a fresh StandardClaims without // email_verified to test omission. claims2 := &jwt.StandardClaims{ TokenClaims: jwt.TokenClaims{ Iss: "https://idp.example.com", Sub: "user-100", Aud: jwt.Listish{"app"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), AzP: "app", }, Email: "bob@example.com", // EmailVerified left as zero value (NullBool{}) } tok2, err := signer.SignToString(claims2) if err != nil { t.Fatal(err) } jws2, err := verifier.VerifyJWT(tok2) if err != nil { t.Fatal(err) } var got2 jwt.StandardClaims if err := jws2.UnmarshalClaims(&got2); err != nil { t.Fatal(err) } if got2.EmailVerified.Valid { t.Errorf("omitted email_verified should be invalid, got %+v", got2.EmailVerified) } } // --- DPoPJWT: custom header type used by TestRoundTrip_CustomHeader --- // dpopHeader extends the standard JOSE header with a DPoP nonce. type dpopHeader struct { jwt.RFCHeader Nonce string `json:"nonce,omitempty"` } // dpopJWT is a custom JWT that carries a dpopHeader. type dpopJWT struct { jwt.RawJWT Header dpopHeader } func (d *dpopJWT) GetHeader() jwt.RFCHeader { return d.Header.RFCHeader } func (d *dpopJWT) SetHeader(hdr jwt.Header) error { d.Header.RFCHeader = *hdr.GetRFCHeader() data, err := json.Marshal(d.Header) if err != nil { return err } d.Protected = []byte(base64.RawURLEncoding.EncodeToString(data)) return nil } // TestRoundTrip_CustomHeader verifies that custom header fields survive the // full sign-decode-verify round-trip when using a custom SignableJWT type. 
func TestRoundTrip_CustomHeader(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } verifier := signer.Verifier() now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://example.com", Sub: "user-1", Aud: jwt.Listish{"api"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), } // Build and sign a DPoP JWT with a custom nonce header. dpop := &dpopJWT{Header: dpopHeader{ RFCHeader: jwt.RFCHeader{Typ: "dpop+jwt"}, Nonce: "server-nonce-abc", }} if err := dpop.SetClaims(claims); err != nil { t.Fatal(err) } if err := signer.SignJWT(dpop); err != nil { t.Fatal(err) } tokenStr, err := jwt.Encode(dpop) if err != nil { t.Fatal(err) } // Verify the signature with the standard Decode path. jws, err := jwt.Decode(tokenStr) if err != nil { t.Fatalf("Decode: %v", err) } if err := verifier.Verify(jws); err != nil { t.Fatalf("Verify: %v", err) } // Recover the custom header via DecodeRaw + UnmarshalHeader. raw, err := jwt.DecodeRaw(tokenStr) if err != nil { t.Fatalf("DecodeRaw: %v", err) } var gotHdr dpopHeader if err := raw.UnmarshalHeader(&gotHdr); err != nil { t.Fatalf("UnmarshalHeader: %v", err) } if gotHdr.Nonce != "server-nonce-abc" { t.Errorf("nonce: got %q, want %q", gotHdr.Nonce, "server-nonce-abc") } if gotHdr.Typ != "dpop+jwt" { t.Errorf("typ: got %q, want %q", gotHdr.Typ, "dpop+jwt") } if gotHdr.Alg != "EdDSA" { t.Errorf("alg: got %q, want %q", gotHdr.Alg, "EdDSA") } if gotHdr.KID != pk.KID { t.Errorf("kid: got %q, want %q", gotHdr.KID, pk.KID) } } // TestRoundTrip_ExpiredTokenRejection signs a token with a past exp, verifies // that Decode+Verify succeeds (signature is valid) but Validate rejects it. 
func TestRoundTrip_ExpiredTokenRejection(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://idp.example.com", Sub: "user-1", Aud: jwt.Listish{"app"}, Exp: now.Add(-time.Hour).Unix(), // expired 1 hour ago IAt: now.Add(-2 * time.Hour).Unix(), AuthTime: now.Add(-2 * time.Hour).Unix(), AzP: "app", } tokenStr, err := signer.SignToString(claims) if err != nil { t.Fatal(err) } // Decode + Verify should succeed - the signature is valid. decoded, err := jwt.Decode(tokenStr) if err != nil { t.Fatalf("Decode: %v", err) } verifier := signer.Verifier() if err := verifier.Verify(decoded); err != nil { t.Fatalf("Verify should succeed for expired token: %v", err) } var got jwt.TokenClaims if err := decoded.UnmarshalClaims(&got); err != nil { t.Fatalf("UnmarshalClaims: %v", err) } // Validate should fail with ErrAfterExp. v := jwt.NewIDTokenValidator( []string{"https://idp.example.com"}, []string{"app"}, []string{"app"}, 0, ) err = v.Validate(nil, &got, time.Now()) if err == nil { t.Fatal("expected validation error for expired token") } if !errors.Is(err, jwt.ErrAfterExp) { t.Fatalf("expected ErrAfterExp, got: %v", err) } } // TestRoundTrip_WrongAudienceRejection signs a token with one audience and // validates against a different audience, expecting rejection. func TestRoundTrip_WrongAudienceRejection(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://idp.example.com", Sub: "user-1", Aud: jwt.Listish{"app-A"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), AzP: "app-A", } tokenStr, err := signer.SignToString(claims) if err != nil { t.Fatal(err) } // Decode + Verify succeeds. 
verifier := signer.Verifier() jws, err := verifier.VerifyJWT(tokenStr) if err != nil { t.Fatalf("VerifyJWT: %v", err) } var got jwt.TokenClaims if err := jws.UnmarshalClaims(&got); err != nil { t.Fatalf("UnmarshalClaims: %v", err) } // Validate with a different audience - should fail. v := jwt.NewIDTokenValidator( []string{"https://idp.example.com"}, []string{"app-B"}, // wrong audience []string{"app-A"}, 0, ) err = v.Validate(nil, &got, time.Now()) if err == nil { t.Fatal("expected validation error for wrong audience") } if !errors.Is(err, jwt.ErrInvalidClaim) { t.Fatalf("expected ErrInvalidClaim for aud mismatch, got: %v", err) } } // TestDuplicateKIDRotation verifies that when multiple keys share the same KID // (e.g. during key rotation), the verifier tries all matching keys and succeeds // if any one of them verifies the signature. func TestDuplicateKIDRotation(t *testing.T) { oldKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) newKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) // Both keys share the same KID (simulating a rotation where the KID is reused). sharedKID := "rotating-key" // Sign a token with the OLD key. oldSigner, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, oldKey, sharedKID)}) if err != nil { t.Fatal(err) } claims := goodClaims() oldToken, err := oldSigner.SignToString(&claims) if err != nil { t.Fatal(err) } // Sign a token with the NEW key. newSigner, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, newKey, sharedKID)}) if err != nil { t.Fatal(err) } newToken, err := newSigner.SignToString(&claims) if err != nil { t.Fatal(err) } // Verifier has both keys under the same KID. verifier, err := jwt.NewVerifier([]jwt.PublicKey{ {Key: &oldKey.PublicKey, KID: sharedKID}, {Key: &newKey.PublicKey, KID: sharedKID}, }) if err != nil { t.Fatal(err) } // Both tokens should verify successfully. 
for _, tt := range []struct { name string token string }{ {"old key", oldToken}, {"new key", newToken}, } { t.Run(tt.name, func(t *testing.T) { parsed, err := jwt.Decode(tt.token) if err != nil { t.Fatalf("Decode: %v", err) } if err := verifier.Verify(parsed); err != nil { t.Fatalf("Verify should succeed for %s: %v", tt.name, err) } }) } } // TestNoKIDTriesAllKeys verifies that when a token has no KID, all verifier // keys are tried and the first successful verification wins. func TestNoKIDTriesAllKeys(t *testing.T) { wrongKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) // Sign with rightKey using SignRaw with a header that has no KID. signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "any")}) hdr := &jwt.RFCHeader{} // no KID, no typ payloadJSON := []byte(`{"sub":"user-1"}`) raw, err := signer.SignRaw(hdr, payloadJSON) if err != nil { t.Fatal(err) } // Reconstruct as compact token: protected.payload.signature token := string(raw.GetProtected()) + "." + string(raw.GetPayload()) + "." + base64.RawURLEncoding.EncodeToString(raw.GetSignature()) // Verifier has wrongKey first, then rightKey. verifier, err := jwt.NewVerifier([]jwt.PublicKey{ {Key: &wrongKey.PublicKey, KID: "wrong"}, {Key: &rightKey.PublicKey, KID: "right"}, }) if err != nil { t.Fatal(err) } parsed, err := jwt.Decode(token) if err != nil { t.Fatal(err) } if parsed.GetHeader().KID != "" { t.Fatalf("token should have no KID, got %q", parsed.GetHeader().KID) } if err := verifier.Verify(parsed); err != nil { t.Fatalf("Verify should try all keys when token has no KID: %v", err) } } // TestEmptyKIDTokenEmptyKIDKey verifies that when both the token and a // verifier key have empty KIDs, verification succeeds (all keys are tried). func TestEmptyKIDTokenEmptyKIDKey(t *testing.T) { rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) // Sign with SignRaw to produce a token with no KID. 
signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "any")}) raw, err := signer.SignRaw(&jwt.RFCHeader{}, []byte(`{"sub":"user-1"}`)) if err != nil { t.Fatal(err) } token := string(raw.GetProtected()) + "." + string(raw.GetPayload()) + "." + base64.RawURLEncoding.EncodeToString(raw.GetSignature()) // Verifier key also has empty KID. verifier, err := jwt.NewVerifier([]jwt.PublicKey{ {Key: &rightKey.PublicKey, KID: ""}, }) if err != nil { t.Fatal(err) } parsed, err := jwt.Decode(token) if err != nil { t.Fatal(err) } if err := verifier.Verify(parsed); err != nil { t.Fatalf("empty KID token + empty KID key should verify: %v", err) } } // TestKIDTokenEmptyKIDKey verifies that when the token has a KID but the // verifier key has an empty KID, the key is not a candidate (no match). func TestKIDTokenEmptyKIDKey(t *testing.T) { rightKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) // Sign normally -- token gets a KID from the signer. signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, rightKey, "my-kid")}) claims := goodClaims() token, _ := signer.SignToString(&claims) // Verifier has the same key material but with an empty KID. verifier, err := jwt.NewVerifier([]jwt.PublicKey{ {Key: &rightKey.PublicKey, KID: ""}, }) if err != nil { t.Fatal(err) } parsed, err := jwt.Decode(token) if err != nil { t.Fatal(err) } err = verifier.Verify(parsed) if !errors.Is(err, jwt.ErrUnknownKID) { t.Fatalf("token KID %q should not match empty-KID key, expected ErrUnknownKID, got: %v", parsed.GetHeader().KID, err) } } // TestMultiKeyVerifier verifies that a Verifier with keys of different algorithms // correctly selects the right key by KID when verifying tokens. 
func TestMultiKeyVerifier(t *testing.T) { ecKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048) edPub, edPriv, _ := ed25519.GenerateKey(rand.Reader) ecPK := mustPK(t, ecKey, "ec-key") rsaPK := mustPK(t, rsaKey, "rsa-key") edPK := mustPK(t, edPriv, "ed-key") // Create a verifier with all three public keys. verifier, err := jwt.NewVerifier([]jwt.PublicKey{ {Key: &ecKey.PublicKey, KID: "ec-key"}, {Key: &rsaKey.PublicKey, KID: "rsa-key"}, {Key: edPub, KID: "ed-key"}, }) if err != nil { t.Fatal(err) } // Sign a token with each key type and verify the multi-key verifier picks the right one. for _, tt := range []struct { name string pk *jwt.PrivateKey alg string }{ {"EC/ES256", ecPK, "ES256"}, {"RSA/RS256", rsaPK, "RS256"}, {"Ed25519/EdDSA", edPK, "EdDSA"}, } { t.Run(tt.name, func(t *testing.T) { signer, err := jwt.NewSigner([]*jwt.PrivateKey{tt.pk}) if err != nil { t.Fatal(err) } claims := goodClaims() tokenStr, err := signer.SignToString(&claims) if err != nil { t.Fatal(err) } parsed, err := jwt.Decode(tokenStr) if err != nil { t.Fatalf("Decode: %v", err) } if parsed.GetHeader().Alg != tt.alg { t.Fatalf("alg: got %s, want %s", parsed.GetHeader().Alg, tt.alg) } if err := verifier.Verify(parsed); err != nil { t.Fatalf("Verify failed for %s: %v", tt.alg, err) } }) } } // TestAudienceSingleString verifies that a single-string "aud" claim // (RFC 7519 ยง4.1.3) is correctly unmarshaled as a single-element Audience. 
func TestAudienceSingleString(t *testing.T) { pk, err := jwt.NewPrivateKey() if err != nil { t.Fatal(err) } signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) if err != nil { t.Fatal(err) } now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://example.com", Sub: "user-1", Aud: jwt.Listish{"single-aud"}, // single element Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), } tokenStr, err := signer.SignToString(claims) if err != nil { t.Fatal(err) } // Verify the wire format uses a string (not an array) for single-element aud. decoded, err := jwt.Decode(tokenStr) if err != nil { t.Fatal(err) } var rawPayload map[string]json.RawMessage payloadBytes, _ := base64.RawURLEncoding.DecodeString(string(decoded.GetPayload())) if err := json.Unmarshal(payloadBytes, &rawPayload); err != nil { t.Fatal(err) } audRaw := string(rawPayload["aud"]) if audRaw[0] == '[' { t.Errorf("single-element aud should be a string, got array: %s", audRaw) } // Unmarshal and validate. var got jwt.TokenClaims if err := decoded.UnmarshalClaims(&got); err != nil { t.Fatal(err) } if len(got.Aud) != 1 || got.Aud[0] != "single-aud" { t.Errorf("aud: got %v, want [single-aud]", got.Aud) } // Also test unmarshaling from a manually constructed single-string aud. singleJSON := []byte(`"just-one"`) var aud jwt.Listish if err := json.Unmarshal(singleJSON, &aud); err != nil { t.Fatalf("Unmarshal single-string aud: %v", err) } if len(aud) != 1 || aud[0] != "just-one" { t.Errorf("single-string unmarshal: got %v, want [just-one]", aud) } // And array form. arrayJSON := []byte(`["a","b"]`) var aud2 jwt.Listish if err := json.Unmarshal(arrayJSON, &aud2); err != nil { t.Fatalf("Unmarshal array aud: %v", err) } if len(aud2) != 2 || aud2[0] != "a" || aud2[1] != "b" { t.Errorf("array unmarshal: got %v, want [a b]", aud2) } } // TestVerifyTamperedPayload confirms that a tampered payload (modified after signing) // is rejected by signature verification. 
func TestVerifyTamperedPayload(t *testing.T) { privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) signer, _ := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, privKey, "k")}) claims := goodClaims() token, _ := signer.SignToString(&claims) verifier := goodVerifier(jwt.PublicKey{Key: &privKey.PublicKey, KID: "k"}) // Tamper with the payload: change the sub claim. parts := strings.SplitN(token, ".", 3) payloadBytes, _ := base64.RawURLEncoding.DecodeString(parts[1]) tampered := strings.Replace(string(payloadBytes), claims.Sub, "evil-user", 1) parts[1] = base64.RawURLEncoding.EncodeToString([]byte(tampered)) tamperedToken := strings.Join(parts, ".") parsed, err := jwt.Decode(tamperedToken) if err != nil { t.Fatalf("Decode should succeed for well-formed tampered token: %v", err) } if err := verifier.Verify(parsed); err == nil { t.Fatal("expected Verify to fail for tampered payload") } else if !errors.Is(err, jwt.ErrSignatureInvalid) { t.Fatalf("expected ErrSignatureInvalid, got: %v", err) } } // TestValidationErrorAnnotation verifies that time-related validation errors // include the server time annotation. func TestValidationErrorAnnotation(t *testing.T) { now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://example.com", Sub: "user-1", Aud: jwt.Listish{"app"}, Exp: now.Add(-time.Hour).Unix(), // expired IAt: now.Unix(), AuthTime: now.Unix(), } v := jwt.NewIDTokenValidator( []string{"https://example.com"}, []string{"app"}, nil, 0, ) err := v.Validate(nil, claims, now) if err == nil { t.Fatal("expected validation error") } if !strings.Contains(err.Error(), "server time") { t.Errorf("time error should include server time annotation: %v", err) } } // TestValidateThreadsHeaderErrors verifies that errors from header validation // (IsAllowedTyp) are preserved when threaded into Validate. 
func TestValidateThreadsHeaderErrors(t *testing.T) { now := time.Now() claims := &jwt.TokenClaims{ Iss: "https://example.com", Sub: "user-1", Aud: jwt.Listish{"app"}, Exp: now.Add(time.Hour).Unix(), IAt: now.Unix(), AuthTime: now.Unix(), } v := jwt.NewIDTokenValidator( []string{"https://example.com"}, []string{"app"}, nil, 0, ) // Simulate a header check failure. hdr := jwt.RFCHeader{Typ: "at+jwt"} // wrong typ for ID token var errs []error errs = hdr.IsAllowedTyp(errs, []string{"JWT"}) if len(errs) == 0 { t.Fatal("expected IsAllowedTyp to produce an error") } // Thread header errors into Validate - claims are valid, so only header error remains. err := v.Validate(errs, claims, now) if err == nil { t.Fatal("expected error from threaded header validation") } if !strings.Contains(err.Error(), "typ") { t.Errorf("expected typ error in output: %v", err) } }