test(auth/jwt): more tests

This commit is contained in:
AJ ONeal 2026-03-17 04:16:41 -06:00
parent 08dd881974
commit 999d7c2615
No known key found for this signature in database
3 changed files with 3453 additions and 0 deletions

714
auth/jwt/edge_test.go Normal file
View File

@ -0,0 +1,714 @@
package jwt_test
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"math/big"
"testing"
"github.com/therootcompany/golib/auth/jwt"
)
// --- JWK parsing edge cases ---
func TestJWKMissingKty(t *testing.T) {
data := []byte(`{"kid":"test"}`)
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for missing kty")
}
if !errors.Is(err, jwt.ErrUnsupportedKeyType) {
t.Fatalf("expected ErrUnsupportedKeyType, got: %v", err)
}
}
func TestJWKUnknownKty(t *testing.T) {
data := []byte(`{"kty":"MAGIC","kid":"test"}`)
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for unknown kty")
}
if !errors.Is(err, jwt.ErrUnsupportedKeyType) {
t.Fatalf("expected ErrUnsupportedKeyType, got: %v", err)
}
}
func TestPrivateKeyMissingD(t *testing.T) {
// Valid EC public key but no "d" field
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
pub := jwt.PublicKey{Key: &ecKey.PublicKey, KID: "test"}
pubJSON, err := json.Marshal(pub)
if err != nil {
t.Fatal(err)
}
var pk jwt.PrivateKey
err = pk.UnmarshalJSON(pubJSON)
if err == nil {
t.Fatal("expected error for missing d field")
}
if !errors.Is(err, jwt.ErrMissingKeyData) {
t.Fatalf("expected ErrMissingKeyData, got: %v", err)
}
}
func TestRSAKeyTooSmall(t *testing.T) {
tests := []struct {
name string
nLen int // modulus byte length
}{
{"all_zeros_1024bit", 128}, // 1024 bits of zeros - Size() returns 0
{"all_zeros_64byte", 64}, // 512 bits of zeros
{"all_zeros_1byte", 1}, // 8 bits of zeros
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := make([]byte, tt.nLen) // all zeros
data, _ := json.Marshal(map[string]string{
"kty": "RSA",
"kid": "small",
"n": base64.RawURLEncoding.EncodeToString(n),
"e": "AQAB",
})
var decoded jwt.PublicKey
err := decoded.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for all-zeros RSA modulus")
}
if !errors.Is(err, jwt.ErrKeyTooSmall) {
t.Fatalf("expected ErrKeyTooSmall, got: %v", err)
}
})
}
}
func TestRSADegenerateExponent(t *testing.T) {
tests := []struct {
name string
e int
}{
{"exponent_0", 0},
{"exponent_1", 1},
{"exponent_2", 2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Build a JWK with a degenerate exponent
n := make([]byte, 256) // 2048-bit modulus
n[0] = 1 // non-zero MSB
eBytes := big.NewInt(int64(tt.e)).Bytes()
data, _ := json.Marshal(map[string]string{
"kty": "RSA",
"kid": "bad-e",
"n": base64.RawURLEncoding.EncodeToString(n),
"e": base64.RawURLEncoding.EncodeToString(eBytes),
})
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for degenerate RSA exponent")
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
func TestRSAEmptyFields(t *testing.T) {
tests := []struct {
name string
jwk map[string]string
}{
{"empty_n", map[string]string{"kty": "RSA", "n": "", "e": "AQAB"}},
{"empty_e", map[string]string{"kty": "RSA", "n": "AQAB", "e": ""}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, _ := json.Marshal(tt.jwk)
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for empty RSA field")
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
func TestEd25519WrongKeySize(t *testing.T) {
tests := []struct {
name string
size int
}{
{"too_short_31", 31},
{"too_long_33", 33},
{"zero_length", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
x := make([]byte, tt.size)
data, _ := json.Marshal(map[string]string{
"kty": "OKP",
"crv": "Ed25519",
"x": base64.RawURLEncoding.EncodeToString(x),
})
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatalf("expected error for Ed25519 key size %d", tt.size)
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
func TestEd25519AllZerosKey(t *testing.T) {
// All-zeros is a valid encoding but represents a low-order point.
// The key should parse (it's 32 bytes), and signing should work
// but verification with the wrong key should fail.
x := make([]byte, ed25519.PublicKeySize) // all zeros
data, _ := json.Marshal(map[string]string{
"kty": "OKP",
"crv": "Ed25519",
"kid": "zero-key",
"x": base64.RawURLEncoding.EncodeToString(x),
})
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err != nil {
t.Fatalf("all-zeros Ed25519 key should parse: %v", err)
}
// Verify with this key should reject any real signature
realKey, err := jwt.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
signer, err := jwt.NewSigner([]*jwt.PrivateKey{realKey})
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
tokenStr, err := signer.SignToString(&claims)
if err != nil {
t.Fatal(err)
}
// Parse and try to verify with the all-zeros key
jws, err := jwt.Decode(tokenStr)
if err != nil {
t.Fatal(err)
}
// Change the kid in the header to match our zero key
zeroVerifier, _ := jwt.NewVerifier([]jwt.PublicKey{pk})
// The KID won't match, but let's verify that the system handles it
err = zeroVerifier.Verify(jws)
if err == nil {
t.Fatal("expected verification to fail with wrong key")
}
}
func TestOKPWrongCrv(t *testing.T) {
data, _ := json.Marshal(map[string]string{
"kty": "OKP",
"crv": "X25519",
"x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
})
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for X25519 crv")
}
if !errors.Is(err, jwt.ErrUnsupportedCurve) {
t.Fatalf("expected ErrUnsupportedCurve, got: %v", err)
}
}
func TestOKPPrivateWrongCrv(t *testing.T) {
data, _ := json.Marshal(map[string]string{
"kty": "OKP",
"crv": "X25519",
"x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
"d": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
})
var pk jwt.PrivateKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for X25519 crv on private key")
}
if !errors.Is(err, jwt.ErrUnsupportedCurve) {
t.Fatalf("expected ErrUnsupportedCurve, got: %v", err)
}
}
func TestECCoordinatesTooLong(t *testing.T) {
ci := struct {
keySize int
crv string
}{32, "P-256"} // P-256 has 32-byte coordinates
tests := []struct {
name string
xSize int
ySize int
}{
{"x_too_long", ci.keySize + 1, ci.keySize},
{"y_too_long", ci.keySize, ci.keySize + 1},
{"both_too_long", ci.keySize + 1, ci.keySize + 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, _ := json.Marshal(map[string]string{
"kty": "EC",
"crv": ci.crv,
"x": base64.RawURLEncoding.EncodeToString(make([]byte, tt.xSize)),
"y": base64.RawURLEncoding.EncodeToString(make([]byte, tt.ySize)),
})
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for oversized EC coordinates")
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
func TestECUnsupportedCurve(t *testing.T) {
data, _ := json.Marshal(map[string]string{
"kty": "EC",
"crv": "P-192",
"x": base64.RawURLEncoding.EncodeToString(make([]byte, 24)),
"y": base64.RawURLEncoding.EncodeToString(make([]byte, 24)),
})
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for P-192 curve")
}
if !errors.Is(err, jwt.ErrUnsupportedCurve) {
t.Fatalf("expected ErrUnsupportedCurve, got: %v", err)
}
}
func TestEd25519PrivateWrongSeedSize(t *testing.T) {
tests := []struct {
name string
size int
}{
{"seed_too_short", 31},
{"seed_too_long", 33},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, _ := json.Marshal(map[string]string{
"kty": "OKP",
"crv": "Ed25519",
"x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
"d": base64.RawURLEncoding.EncodeToString(make([]byte, tt.size)),
})
var pk jwt.PrivateKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatalf("expected error for seed size %d", tt.size)
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
func TestInvalidBase64Fields(t *testing.T) {
tests := []struct {
name string
jwk map[string]string
}{
{"invalid_rsa_n", map[string]string{"kty": "RSA", "n": "!!!invalid!!!", "e": "AQAB"}},
{"invalid_rsa_e", map[string]string{"kty": "RSA", "n": "AQAB", "e": "!!!"}},
{"invalid_ec_x", map[string]string{"kty": "EC", "crv": "P-256", "x": "!!!", "y": "AAAA"}},
{"invalid_ec_y", map[string]string{"kty": "EC", "crv": "P-256", "x": "AAAA", "y": "!!!"}},
{"invalid_okp_x", map[string]string{"kty": "OKP", "crv": "Ed25519", "x": "!!!"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, _ := json.Marshal(tt.jwk)
var pk jwt.PublicKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for invalid base64")
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
func TestInvalidBase64PrivateFields(t *testing.T) {
tests := []struct {
name string
jwk map[string]string
}{
{"invalid_ec_d", map[string]string{
"kty": "EC", "crv": "P-256",
"x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
"y": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
"d": "!!!invalid!!!",
}},
{"invalid_rsa_d", map[string]string{
"kty": "RSA",
"n": base64.RawURLEncoding.EncodeToString(make([]byte, 256)),
"e": "AQAB",
"d": "!!!invalid!!!",
}},
{"invalid_okp_d", map[string]string{
"kty": "OKP", "crv": "Ed25519",
"x": base64.RawURLEncoding.EncodeToString(make([]byte, 32)),
"d": "!!!invalid!!!",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, _ := json.Marshal(tt.jwk)
var pk jwt.PrivateKey
err := pk.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for invalid base64 in private field")
}
if !errors.Is(err, jwt.ErrInvalidKey) {
t.Fatalf("expected ErrInvalidKey, got: %v", err)
}
})
}
}
// --- Signature verification edge cases ---
func TestVerifyWrongKeyTypeForAlg(t *testing.T) {
// Sign with Ed25519, then try to verify with an RSA key
// that has the same KID
edKey, err := jwt.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
signer, err := jwt.NewSigner([]*jwt.PrivateKey{edKey})
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
tokenStr, err := signer.SignToString(&claims)
if err != nil {
t.Fatal(err)
}
// Create an RSA key with the same KID
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
rsaPub := jwt.PublicKey{
Key: &rsaKey.PublicKey,
KID: edKey.KID, // same KID
}
verifier, _ := jwt.NewVerifier([]jwt.PublicKey{rsaPub})
jws, err := jwt.Decode(tokenStr)
if err != nil {
t.Fatal(err)
}
err = verifier.Verify(jws)
if err == nil {
t.Fatal("expected error: wrong key type for EdDSA alg")
}
if !errors.Is(err, jwt.ErrAlgConflict) {
t.Fatalf("expected ErrAlgConflict, got: %v", err)
}
}
func TestVerifyZeroLengthSignature(t *testing.T) {
// Create a valid token then replace the signature with empty
key, err := jwt.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
signer, err := jwt.NewSigner([]*jwt.PrivateKey{key})
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
tokenStr, err := signer.SignToString(&claims)
if err != nil {
t.Fatal(err)
}
// Replace signature with empty
parts := splitToken(tokenStr)
tampered := parts[0] + "." + parts[1] + "."
jws, err := jwt.Decode(tampered)
if err != nil {
t.Fatal(err)
}
verifier := signer.Verifier()
err = verifier.Verify(jws)
if err == nil {
t.Fatal("expected error for zero-length signature")
}
if !errors.Is(err, jwt.ErrSignatureInvalid) {
t.Fatalf("expected ErrSignatureInvalid, got: %v", err)
}
}
func TestVerifyECDSAWrongSigLength(t *testing.T) {
// Sign with P-256, verify with correct key but tampered sig length
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
pk, err := jwt.FromPrivateKey(ecKey, "")
if err != nil {
t.Fatal(err)
}
signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk})
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
tokenStr, err := signer.SignToString(&claims)
if err != nil {
t.Fatal(err)
}
// Replace signature with wrong-length bytes
parts := splitToken(tokenStr)
wrongSig := base64.RawURLEncoding.EncodeToString([]byte("short"))
tampered := parts[0] + "." + parts[1] + "." + wrongSig
jws, err := jwt.Decode(tampered)
if err != nil {
t.Fatal(err)
}
verifier := signer.Verifier()
err = verifier.Verify(jws)
if err == nil {
t.Fatal("expected error for wrong ECDSA signature length")
}
if !errors.Is(err, jwt.ErrSignatureInvalid) {
t.Fatalf("expected ErrSignatureInvalid, got: %v", err)
}
}
func TestVerifyUnsupportedAlg(t *testing.T) {
// Build a token with an unsupported alg header
key, err := jwt.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
signer, err := jwt.NewSigner([]*jwt.PrivateKey{key})
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
tokenStr, err := signer.SignToString(&claims)
if err != nil {
t.Fatal(err)
}
// Tamper the header to use an unsupported alg
header := map[string]string{"alg": "HS256", "kid": key.KID, "typ": "JWT"}
headerJSON, _ := json.Marshal(header)
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
parts := splitToken(tokenStr)
tampered := headerB64 + "." + parts[1] + "." + parts[2]
jws, err := jwt.Decode(tampered)
if err != nil {
t.Fatal(err)
}
verifier := signer.Verifier()
err = verifier.Verify(jws)
if err == nil {
t.Fatal("expected error for unsupported alg")
}
if !errors.Is(err, jwt.ErrUnsupportedAlg) {
t.Fatalf("expected ErrUnsupportedAlg, got: %v", err)
}
}
// TestVerifyMissingKID verifies that tokens without a KID header try all keys
// via fallthrough. A tampered header (different signing input) still fails
// with ErrSignatureInvalid, but the lookup itself does not reject the token.
func TestVerifyMissingKID(t *testing.T) {
key, err := jwt.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
signer, err := jwt.NewSigner([]*jwt.PrivateKey{key})
if err != nil {
t.Fatal(err)
}
claims := goodClaims()
tokenStr, err := signer.SignToString(&claims)
if err != nil {
t.Fatal(err)
}
// Tamper header to remove kid - signing input changes, so sig will be invalid.
header := map[string]string{"alg": "EdDSA", "typ": "JWT"}
headerJSON, _ := json.Marshal(header)
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
parts := splitToken(tokenStr)
tampered := headerB64 + "." + parts[1] + "." + parts[2]
jws, err := jwt.Decode(tampered)
if err != nil {
t.Fatal(err)
}
verifier := signer.Verifier()
err = verifier.Verify(jws)
if err == nil {
t.Fatal("expected error for tampered header")
}
// With no KID, all keys are tried - fails with signature invalid, not missing KID.
if !errors.Is(err, jwt.ErrSignatureInvalid) {
t.Fatalf("expected ErrSignatureInvalid, got: %v", err)
}
}
func TestDecodeEmptyToken(t *testing.T) {
_, err := jwt.Decode("")
if err == nil {
t.Fatal("expected error for empty token")
}
if !errors.Is(err, jwt.ErrMalformedToken) {
t.Fatalf("expected ErrMalformedToken, got: %v", err)
}
}
func TestDecodeOnePart(t *testing.T) {
_, err := jwt.Decode("justonepart")
if err == nil {
t.Fatal("expected error for single-part token")
}
if !errors.Is(err, jwt.ErrMalformedToken) {
t.Fatalf("expected ErrMalformedToken, got: %v", err)
}
}
func TestDecodeFourParts(t *testing.T) {
_, err := jwt.Decode("a.b.c.d")
if err == nil {
t.Fatal("expected error for four-part token")
}
if !errors.Is(err, jwt.ErrMalformedToken) {
t.Fatalf("expected ErrMalformedToken, got: %v", err)
}
}
// --- RSA private key edge cases ---
func TestRSAPrivateKeyTooSmall(t *testing.T) {
tests := []struct {
name string
nLen int
}{
{"all_zeros_1024bit", 128},
{"all_zeros_64byte", 64},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := make([]byte, tt.nLen) // all zeros
d := make([]byte, tt.nLen) // all zeros
data, _ := json.Marshal(map[string]string{
"kty": "RSA",
"kid": "small",
"n": base64.RawURLEncoding.EncodeToString(n),
"e": "AQAB",
"d": base64.RawURLEncoding.EncodeToString(d),
})
var decoded jwt.PrivateKey
err := decoded.UnmarshalJSON(data)
if err == nil {
t.Fatal("expected error for all-zeros RSA private key")
}
if !errors.Is(err, jwt.ErrKeyTooSmall) {
t.Fatalf("expected ErrKeyTooSmall, got: %v", err)
}
})
}
}
// --- Thumbprint edge cases ---
func TestThumbprintNilKey(t *testing.T) {
pk := jwt.PublicKey{} // nil Key
_, err := pk.Thumbprint()
if err == nil {
t.Fatal("expected error for nil key thumbprint")
}
if !errors.Is(err, jwt.ErrUnsupportedKeyType) {
t.Fatalf("expected ErrUnsupportedKeyType, got: %v", err)
}
}
// splitToken splits a compact JWT into its three dot-separated parts.
func splitToken(s string) [3]string {
var parts [3]string
idx1 := 0
for i, c := range s {
if c == '.' {
if idx1 == 0 {
parts[0] = s[:i]
idx1 = i + 1
} else {
parts[1] = s[idx1:i]
parts[2] = s[i+1:]
return parts
}
}
}
return parts
}

2459
auth/jwt/jwt_test.go Normal file

File diff suppressed because it is too large Load Diff

280
auth/jwt/nullbool_test.go Normal file
View File

@ -0,0 +1,280 @@
package jwt_test
import (
"encoding/json"
"testing"
"github.com/therootcompany/golib/auth/jwt"
)
func TestNullBool_MarshalJSON(t *testing.T) {
tests := []struct {
name string
nb jwt.NullBool
want string
}{
{"true", jwt.NullBool{Bool: true, Valid: true}, "true"},
{"false", jwt.NullBool{Bool: false, Valid: true}, "false"},
{"null (zero value)", jwt.NullBool{}, "null"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := json.Marshal(tt.nb)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
if string(got) != tt.want {
t.Errorf("Marshal = %s, want %s", got, tt.want)
}
})
}
}
func TestNullBool_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
wantValue bool
wantValid bool
}{
{"true", "true", true, true},
{"false", "false", false, true},
{"null", "null", false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var nb jwt.NullBool
if err := json.Unmarshal([]byte(tt.input), &nb); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if nb.Bool != tt.wantValue {
t.Errorf("Value = %v, want %v", nb.Bool, tt.wantValue)
}
if nb.Valid != tt.wantValid {
t.Errorf("Valid = %v, want %v", nb.Valid, tt.wantValid)
}
})
}
}
func TestNullBool_UnmarshalJSON_InvalidInput(t *testing.T) {
var nb jwt.NullBool
if err := json.Unmarshal([]byte(`"yes"`), &nb); err == nil {
t.Error("expected error for invalid input, got nil")
}
}
func TestNullBool_IsZero(t *testing.T) {
tests := []struct {
name string
nb jwt.NullBool
want bool
}{
{"zero value", jwt.NullBool{}, true},
{"true", jwt.NullBool{Bool: true, Valid: true}, false},
{"false", jwt.NullBool{Bool: false, Valid: true}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.nb.IsZero(); got != tt.want {
t.Errorf("IsZero() = %v, want %v", got, tt.want)
}
})
}
}
func TestNullBool_RoundTrip(t *testing.T) {
values := []jwt.NullBool{
{Bool: true, Valid: true},
{Bool: false, Valid: true},
{Bool: false, Valid: false},
}
for _, orig := range values {
data, err := json.Marshal(orig)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
var got jwt.NullBool
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if got.Bool != orig.Bool || got.Valid != orig.Valid {
t.Errorf("round-trip: got {%v, %v}, want {%v, %v}",
got.Bool, got.Valid, orig.Bool, orig.Valid)
}
}
}
func TestNullBool_ClaimsIntegration(t *testing.T) {
t.Run("marshal with email verified true", func(t *testing.T) {
claims := jwt.StandardClaims{
TokenClaims: jwt.TokenClaims{
Iss: "https://example.com",
Sub: "user123",
Exp: 9999999999,
IAt: 1000000000,
},
Email: "user@example.com",
EmailVerified: jwt.NullBool{Bool: true, Valid: true},
}
data, err := json.Marshal(claims)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatal(err)
}
if string(raw["email_verified"]) != "true" {
t.Errorf("email_verified = %s, want true", raw["email_verified"])
}
})
t.Run("marshal with email verified false", func(t *testing.T) {
claims := jwt.StandardClaims{
TokenClaims: jwt.TokenClaims{
Iss: "https://example.com",
Sub: "user123",
Exp: 9999999999,
IAt: 1000000000,
},
Email: "user@example.com",
EmailVerified: jwt.NullBool{Bool: false, Valid: true},
}
data, err := json.Marshal(claims)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatal(err)
}
if string(raw["email_verified"]) != "false" {
t.Errorf("email_verified = %s, want false", raw["email_verified"])
}
})
t.Run("marshal omits verified when no email", func(t *testing.T) {
claims := jwt.StandardClaims{
TokenClaims: jwt.TokenClaims{
Iss: "https://example.com",
Sub: "user123",
Exp: 9999999999,
IAt: 1000000000,
},
// No email, no EmailVerified -> field omitted via omitzero
}
data, err := json.Marshal(claims)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatal(err)
}
if _, ok := raw["email_verified"]; ok {
t.Errorf("email_verified present = %s, want omitted", raw["email_verified"])
}
if _, ok := raw["phone_number_verified"]; ok {
t.Errorf("phone_number_verified present = %s, want omitted", raw["phone_number_verified"])
}
})
t.Run("unmarshal claims with verified fields", func(t *testing.T) {
input := `{
"iss": "https://example.com",
"sub": "user123",
"exp": 9999999999,
"iat": 1000000000,
"email": "user@example.com",
"email_verified": true,
"phone_number": "+1555000000",
"phone_number_verified": false
}`
var claims jwt.StandardClaims
if err := json.Unmarshal([]byte(input), &claims); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if !claims.EmailVerified.Valid || !claims.EmailVerified.Bool {
t.Errorf("EmailVerified = {%v, %v}, want {true, true}",
claims.EmailVerified.Bool, claims.EmailVerified.Valid)
}
if !claims.PhoneNumberVerified.Valid || claims.PhoneNumberVerified.Bool {
t.Errorf("PhoneNumberVerified = {%v, %v}, want {false, true}",
claims.PhoneNumberVerified.Bool, claims.PhoneNumberVerified.Valid)
}
})
t.Run("unmarshal claims with null verified fields", func(t *testing.T) {
input := `{
"iss": "https://example.com",
"sub": "user123",
"exp": 9999999999,
"iat": 1000000000,
"email_verified": null,
"phone_number_verified": null
}`
var claims jwt.StandardClaims
if err := json.Unmarshal([]byte(input), &claims); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if claims.EmailVerified.Valid {
t.Error("EmailVerified.Valid = true, want false")
}
if claims.PhoneNumberVerified.Valid {
t.Error("PhoneNumberVerified.Valid = true, want false")
}
})
t.Run("unmarshal claims with omitted verified fields", func(t *testing.T) {
input := `{
"iss": "https://example.com",
"sub": "user123",
"exp": 9999999999,
"iat": 1000000000
}`
var claims jwt.StandardClaims
if err := json.Unmarshal([]byte(input), &claims); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
// Omitted fields -> zero value: {false, false}
if claims.EmailVerified.Valid {
t.Error("EmailVerified.Valid = true, want false")
}
if claims.PhoneNumberVerified.Valid {
t.Error("PhoneNumberVerified.Valid = true, want false")
}
})
t.Run("round-trip claims", func(t *testing.T) {
orig := jwt.StandardClaims{
TokenClaims: jwt.TokenClaims{
Iss: "https://example.com",
Sub: "user123",
Exp: 9999999999,
IAt: 1000000000,
},
Email: "user@example.com",
EmailVerified: jwt.NullBool{Bool: true, Valid: true},
PhoneNumber: "+1555000000",
PhoneNumberVerified: jwt.NullBool{Bool: false, Valid: true},
}
data, err := json.Marshal(orig)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
var got jwt.StandardClaims
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if got.EmailVerified != orig.EmailVerified {
t.Errorf("EmailVerified = %+v, want %+v", got.EmailVerified, orig.EmailVerified)
}
if got.PhoneNumberVerified != orig.PhoneNumberVerified {
t.Errorf("PhoneNumberVerified = %+v, want %+v", got.PhoneNumberVerified, orig.PhoneNumberVerified)
}
})
}