// Copyright 2026 AJ ONeal. SPDX-License-Identifier: MPL-2.0 package jwt import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" "encoding/asn1" "encoding/base64" "encoding/json" "errors" "fmt" "io" "math/big" "strings" "testing" "time" ) // ============================================================ // Helpers // ============================================================ var testNow = time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC) func mustECKey(t *testing.T, curve elliptic.Curve) *ecdsa.PrivateKey { t.Helper() k, err := ecdsa.GenerateKey(curve, rand.Reader) if err != nil { t.Fatal(err) } return k } func mustRSAKey(t *testing.T) *rsa.PrivateKey { t.Helper() k, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } return k } func mustEdKey(t *testing.T) ed25519.PrivateKey { t.Helper() _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } return priv } func mustFromPrivate(t *testing.T, signer crypto.Signer) *PrivateKey { t.Helper() pk, err := FromPrivateKey(signer, "") if err != nil { t.Fatal(err) } return pk } func mustSigner(t *testing.T, keys ...*PrivateKey) *Signer { t.Helper() s, err := NewSigner(keys) if err != nil { t.Fatal(err) } return s } func mustSignStr(t *testing.T, s *Signer, tc *TokenClaims) string { t.Helper() tok, err := s.SignToString(tc) if err != nil { t.Fatal(err) } return tok } func goodClaims() *TokenClaims { return &TokenClaims{ Iss: "https://example.com", Sub: "user-123", Aud: Listish{"https://api.example.com"}, Exp: testNow.Add(time.Hour).Unix(), IAt: testNow.Add(-time.Minute).Unix(), JTI: "jti-abc", AuthTime: testNow.Add(-5 * time.Minute).Unix(), AzP: "client-abc", ClientID: "client-abc", Scope: SpaceDelimited{"openid", "profile"}, } } // fakeKey implements CryptoPublicKey but is not EC/RSA/Ed25519. type fakeKey struct{} func (fakeKey) Equal(crypto.PublicKey) bool { return false } // fakeSigner is a crypto.Signer with fakeKey public key. 
type fakeSigner struct{ pub crypto.PublicKey } func (f fakeSigner) Public() crypto.PublicKey { return f.pub } func (fakeSigner) Sign(io.Reader, []byte, crypto.SignerOpts) ([]byte, error) { return nil, fmt.Errorf("not implemented") } // badClaims fails json.Marshal because of channel field. type badClaims struct { TokenClaims Bad chan int `json:"bad"` } // ============================================================ // claims.go // ============================================================ func TestCov_GetTokenClaims(t *testing.T) { tc := &TokenClaims{Iss: "x"} got := tc.GetTokenClaims() if got != tc { t.Fatal("expected same pointer") } } // ============================================================ // types.go - Listish // ============================================================ func TestCov_Listish_UnmarshalJSON(t *testing.T) { t.Run("string", func(t *testing.T) { var l Listish if err := json.Unmarshal([]byte(`"https://ex.com"`), &l); err != nil { t.Fatal(err) } if len(l) != 1 || l[0] != "https://ex.com" { t.Fatalf("got %v", l) } }) t.Run("empty_string", func(t *testing.T) { var l Listish if err := json.Unmarshal([]byte(`""`), &l); err != nil { t.Fatal(err) } if l == nil || len(l) != 0 { t.Fatalf("expected non-nil empty, got %v", l) } }) t.Run("array", func(t *testing.T) { var l Listish if err := json.Unmarshal([]byte(`["a","b"]`), &l); err != nil { t.Fatal(err) } if len(l) != 2 { t.Fatalf("got %v", l) } }) t.Run("invalid", func(t *testing.T) { var l Listish err := json.Unmarshal([]byte(`123`), &l) if err == nil { t.Fatal("expected error") } }) } func TestCov_Listish_IsZero(t *testing.T) { if !Listish(nil).IsZero() { t.Fatal("nil should be zero") } if (Listish{"a"}).IsZero() { t.Fatal("non-empty should not be zero") } } func TestCov_Listish_MarshalJSON(t *testing.T) { t.Run("nil", func(t *testing.T) { b, _ := Listish(nil).MarshalJSON() if string(b) != "null" { t.Fatalf("got %s", b) } }) t.Run("single", func(t *testing.T) { b, _ := 
Listish{"x"}.MarshalJSON() if string(b) != `"x"` { t.Fatalf("got %s", b) } }) t.Run("multiple", func(t *testing.T) { b, _ := Listish{"a", "b"}.MarshalJSON() if string(b) != `["a","b"]` { t.Fatalf("got %s", b) } }) } // ============================================================ // types.go - SpaceDelimited // ============================================================ func TestCov_SpaceDelimited_UnmarshalJSON(t *testing.T) { t.Run("values", func(t *testing.T) { var s SpaceDelimited json.Unmarshal([]byte(`"openid profile"`), &s) if len(s) != 2 || s[0] != "openid" { t.Fatalf("got %v", s) } }) t.Run("empty", func(t *testing.T) { var s SpaceDelimited json.Unmarshal([]byte(`""`), &s) if s == nil { t.Fatal("expected non-nil empty SpaceDelimited, got nil") } if len(s) != 0 { t.Fatalf("expected empty, got %v", s) } }) t.Run("invalid", func(t *testing.T) { var s SpaceDelimited if err := json.Unmarshal([]byte(`123`), &s); err == nil { t.Fatal("expected error") } }) } func TestCov_SpaceDelimited_MarshalJSON(t *testing.T) { t.Run("populated", func(t *testing.T) { b, _ := SpaceDelimited{"a", "b"}.MarshalJSON() if string(b) != `"a b"` { t.Fatalf("got %s", b) } }) t.Run("nil", func(t *testing.T) { b, _ := SpaceDelimited(nil).MarshalJSON() if string(b) != `null` { t.Fatalf("expected null, got %s", b) } }) t.Run("empty_non_nil", func(t *testing.T) { b, _ := (SpaceDelimited{}).MarshalJSON() if string(b) != `""` { t.Fatalf("expected empty string, got %s", b) } }) } func TestCov_SpaceDelimited_IsZero(t *testing.T) { if !SpaceDelimited(nil).IsZero() { t.Fatal("nil should be zero") } if (SpaceDelimited{}).IsZero() { t.Fatal("non-nil empty should not be zero") } if (SpaceDelimited{"a"}).IsZero() { t.Fatal("populated should not be zero") } } func TestCov_SpaceDelimited_Omitzero(t *testing.T) { // Verify struct-level marshaling: nil scope omitted, empty scope present type tc struct { Scope SpaceDelimited `json:"scope,omitzero"` } // nil scope -> field omitted b, _ := 
json.Marshal(tc{Scope: nil}) if strings.Contains(string(b), "scope") { t.Fatalf("nil scope should be omitted, got %s", b) } // non-nil empty scope -> "scope":"" b, _ = json.Marshal(tc{Scope: SpaceDelimited{}}) if !strings.Contains(string(b), `"scope":""`) { t.Fatalf("empty scope should marshal as empty string, got %s", b) } // populated scope -> "scope":"a b" b, _ = json.Marshal(tc{Scope: SpaceDelimited{"a", "b"}}) if !strings.Contains(string(b), `"scope":"a b"`) { t.Fatalf("populated scope should marshal as space-separated, got %s", b) } } // ============================================================ // types.go - NullBool // ============================================================ func TestCov_NullBool(t *testing.T) { t.Run("IsZero", func(t *testing.T) { if !(NullBool{}).IsZero() { t.Fatal("zero value should be zero") } if (NullBool{Bool: true, Valid: true}).IsZero() { //nolint t.Fatal("valid should not be zero") } }) t.Run("MarshalJSON", func(t *testing.T) { b, _ := NullBool{}.MarshalJSON() if string(b) != "null" { t.Fatalf("got %s", b) } b, _ = NullBool{Bool: true, Valid: true}.MarshalJSON() if string(b) != "true" { t.Fatalf("got %s", b) } b, _ = NullBool{Bool: false, Valid: true}.MarshalJSON() if string(b) != "false" { t.Fatalf("got %s", b) } }) t.Run("UnmarshalJSON", func(t *testing.T) { var nb NullBool json.Unmarshal([]byte("null"), &nb) if nb.Valid { t.Fatal("null should not be valid") } json.Unmarshal([]byte("true"), &nb) if !nb.Valid || !nb.Bool { t.Fatal("expected true") } json.Unmarshal([]byte("false"), &nb) if !nb.Valid || nb.Bool { t.Fatal("expected false") } if err := nb.UnmarshalJSON([]byte(`"yes"`)); err == nil { t.Fatal("expected error for string") } }) } // ============================================================ // jwt.go // ============================================================ func TestCov_DecodeRaw(t *testing.T) { t.Run("happy", func(t *testing.T) { // Build a valid token to decode pk := mustFromPrivate(t, mustEdKey(t)) s := 
mustSigner(t, pk) tok := mustSignStr(t, s, goodClaims()) raw, err := DecodeRaw(tok) if err != nil { t.Fatal(err) } if len(raw.Protected) == 0 || len(raw.Payload) == 0 || len(raw.Signature) == 0 { t.Fatal("expected non-empty segments") } }) t.Run("empty", func(t *testing.T) { _, err := DecodeRaw("") if !errors.Is(err, ErrMalformedToken) { t.Fatalf("expected ErrMalformedToken, got %v", err) } }) t.Run("two_parts", func(t *testing.T) { _, err := DecodeRaw("a.b") if !errors.Is(err, ErrMalformedToken) { t.Fatal("expected ErrMalformedToken") } }) t.Run("four_parts", func(t *testing.T) { _, err := DecodeRaw("a.b.c.d") if !errors.Is(err, ErrMalformedToken) { t.Fatal("expected ErrMalformedToken") } }) t.Run("bad_sig_base64", func(t *testing.T) { _, err := DecodeRaw("a.b.!!!invalid!!!") if !errors.Is(err, ErrSignatureInvalid) { t.Fatalf("expected ErrSignatureInvalid, got %v", err) } }) } func TestCov_Decode(t *testing.T) { t.Run("happy", func(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) tok := mustSignStr(t, s, goodClaims()) jws, err := Decode(tok) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "EdDSA" { t.Fatalf("expected EdDSA, got %s", jws.GetHeader().Alg) } }) t.Run("bad_token", func(t *testing.T) { _, err := Decode("bad") if err == nil { t.Fatal("expected error") } }) t.Run("bad_header_json", func(t *testing.T) { // valid base64 but not valid JSON header badHdr := base64.RawURLEncoding.EncodeToString([]byte("not json")) payload := base64.RawURLEncoding.EncodeToString([]byte("{}")) sig := base64.RawURLEncoding.EncodeToString([]byte("sig")) _, err := Decode(badHdr + "." + payload + "." 
+ sig) if !errors.Is(err, ErrInvalidHeader) { t.Fatalf("expected ErrInvalidHeader, got %v", err) } }) } func TestCov_UnmarshalClaims(t *testing.T) { t.Run("happy", func(t *testing.T) { raw := &RawJWT{ Payload: []byte(base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"x"}`))), } var tc TokenClaims if err := raw.UnmarshalClaims(&tc); err != nil { t.Fatal(err) } if tc.Iss != "x" { t.Fatalf("got %q", tc.Iss) } }) t.Run("bad_base64", func(t *testing.T) { raw := &RawJWT{Payload: []byte("!!!")} err := raw.UnmarshalClaims(&TokenClaims{}) if !errors.Is(err, ErrInvalidPayload) { t.Fatalf("expected ErrInvalidPayload, got %v", err) } }) t.Run("bad_json", func(t *testing.T) { raw := &RawJWT{ Payload: []byte(base64.RawURLEncoding.EncodeToString([]byte("not json"))), } err := raw.UnmarshalClaims(&TokenClaims{}) if !errors.Is(err, ErrInvalidPayload) { t.Fatalf("expected ErrInvalidPayload, got %v", err) } }) } func TestCov_UnmarshalHeader(t *testing.T) { t.Run("happy", func(t *testing.T) { hdrJSON := `{"alg":"EdDSA","kid":"k1","typ":"JWT"}` raw := &RawJWT{ Protected: []byte(base64.RawURLEncoding.EncodeToString([]byte(hdrJSON))), } var h RFCHeader if err := raw.UnmarshalHeader(&h); err != nil { t.Fatal(err) } if h.Alg != "EdDSA" || h.KID != "k1" || h.Typ != "JWT" { t.Fatalf("got %+v", h) } }) t.Run("bad_base64", func(t *testing.T) { raw := &RawJWT{Protected: []byte("!!!")} err := raw.UnmarshalHeader(&RFCHeader{}) if !errors.Is(err, ErrInvalidHeader) { t.Fatalf("expected ErrInvalidHeader, got %v", err) } }) t.Run("bad_json", func(t *testing.T) { raw := &RawJWT{ Protected: []byte(base64.RawURLEncoding.EncodeToString([]byte("not json"))), } err := raw.UnmarshalHeader(&RFCHeader{}) if !errors.Is(err, ErrInvalidHeader) { t.Fatalf("expected ErrInvalidHeader, got %v", err) } }) } func TestCov_New(t *testing.T) { tc := goodClaims() jws, err := New(tc) if err != nil { t.Fatal(err) } h := jws.GetHeader() if h.Typ != "JWT" { t.Fatalf("expected JWT typ, got %q", h.Typ) } } func 
TestCov_New_BadClaims(t *testing.T) { _, err := New(&badClaims{Bad: make(chan int)}) if err == nil { t.Fatal("expected marshal error") } } func TestCov_NewAccessToken(t *testing.T) { jws, err := NewAccessToken(goodClaims()) if err != nil { t.Fatal(err) } // SetTyp was called; header isn't re-parsed until SetHeader, but // the internal header.Typ should be "at+jwt" if jws.header.Typ != AccessTokenTyp { t.Fatalf("expected %q, got %q", AccessTokenTyp, jws.header.Typ) } } func TestCov_Encode(t *testing.T) { t.Run("unsigned_rejected", func(t *testing.T) { jws, _ := New(goodClaims()) _, err := Encode(jws) if !errors.Is(err, ErrInvalidHeader) { t.Fatalf("expected ErrInvalidHeader, got %v", err) } }) t.Run("signed_ok", func(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) jws, _ := s.Sign(goodClaims()) str, err := Encode(jws) if err != nil { t.Fatal(err) } if strings.Count(str, ".") != 2 { t.Fatal("expected 3 segments") } }) } func TestCov_JWT_Encode(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) jws, _ := s.Sign(goodClaims()) str, err := jws.Encode() if err != nil { t.Fatal(err) } if strings.Count(str, ".") != 2 { t.Fatal("expected 3 segments") } } func TestCov_JWT_SetTyp(t *testing.T) { jws, _ := New(goodClaims()) jws.SetTyp("at+jwt") if jws.header.Typ != "at+jwt" { t.Fatal("SetTyp did not work") } } func TestCov_JWT_SetHeader(t *testing.T) { jws, _ := New(goodClaims()) h := RFCHeader{Alg: "EdDSA", KID: "k1", Typ: "JWT"} if err := jws.SetHeader(&h); err != nil { t.Fatal(err) } got := jws.GetHeader() if got.Alg != "EdDSA" || got.KID != "k1" { t.Fatalf("got %+v", got) } } func TestCov_RFCHeader_GetRFCHeader(t *testing.T) { h := &RFCHeader{Alg: "x"} if h.GetRFCHeader() != h { t.Fatal("expected same pointer") } } func TestCov_RawJWT_Accessors(t *testing.T) { raw := &RawJWT{ Protected: []byte("p"), Payload: []byte("a"), Signature: []byte("s"), } if string(raw.GetProtected()) != "p" { t.Fatal() } if 
string(raw.GetPayload()) != "a" { t.Fatal() } if string(raw.GetSignature()) != "s" { t.Fatal() } raw.SetSignature([]byte("s2")) if string(raw.Signature) != "s2" { t.Fatal() } } func TestCov_RawJWT_JSON(t *testing.T) { t.Run("round_trip", func(t *testing.T) { orig := &RawJWT{ Protected: []byte("hdr"), Payload: []byte("pay"), Signature: []byte{1, 2, 3}, } data, err := json.Marshal(orig) if err != nil { t.Fatal(err) } var got RawJWT if err := json.Unmarshal(data, &got); err != nil { t.Fatal(err) } if string(got.Protected) != "hdr" || string(got.Payload) != "pay" { t.Fatalf("got %+v", got) } }) t.Run("bad_json", func(t *testing.T) { var r RawJWT if err := r.UnmarshalJSON([]byte("not json")); err == nil { t.Fatal("expected error") } }) t.Run("bad_sig_base64", func(t *testing.T) { var r RawJWT err := r.UnmarshalJSON([]byte(`{"protected":"a","payload":"b","signature":"!!!"}`)) if err == nil { t.Fatal("expected error") } }) } func TestCov_SetClaims(t *testing.T) { raw := &RawJWT{} tc := goodClaims() if err := raw.SetClaims(tc); err != nil { t.Fatal(err) } if len(raw.Payload) == 0 { t.Fatal("expected non-empty payload") } } func TestCov_SetClaims_Bad(t *testing.T) { raw := &RawJWT{} err := raw.SetClaims(&badClaims{Bad: make(chan int)}) if err == nil { t.Fatal("expected error") } } // ============================================================ // jwa.go // ============================================================ func TestCov_ecInfo(t *testing.T) { for _, tc := range []struct { curve elliptic.Curve alg string }{ {elliptic.P256(), "ES256"}, {elliptic.P384(), "ES384"}, {elliptic.P521(), "ES512"}, } { ci, err := ecInfo(tc.curve) if err != nil { t.Fatal(err) } if ci.Alg != tc.alg { t.Fatalf("expected %s, got %s", tc.alg, ci.Alg) } } // unsupported curve - use a custom curve params badCurve := &elliptic.CurveParams{Name: "bad", BitSize: 128} _, err := ecInfo(badCurve) if !errors.Is(err, ErrUnsupportedCurve) { t.Fatalf("expected ErrUnsupportedCurve, got %v", err) } } func 
TestCov_ecInfoByCrv(t *testing.T) { for _, crv := range []string{"P-256", "P-384", "P-521"} { if _, err := ecInfoByCrv(crv); err != nil { t.Fatal(err) } } _, err := ecInfoByCrv("P-192") if !errors.Is(err, ErrUnsupportedCurve) { t.Fatal("expected ErrUnsupportedCurve") } } func TestCov_ecInfoForAlg(t *testing.T) { t.Run("match", func(t *testing.T) { ci, err := ecInfoForAlg(elliptic.P256(), "ES256") if err != nil || ci.Alg != "ES256" { t.Fatal("expected match") } }) t.Run("mismatch", func(t *testing.T) { _, err := ecInfoForAlg(elliptic.P256(), "ES384") if !errors.Is(err, ErrAlgConflict) { t.Fatal("expected ErrAlgConflict") } }) t.Run("bad_curve", func(t *testing.T) { badCurve := &elliptic.CurveParams{Name: "bad", BitSize: 128} _, err := ecInfoForAlg(badCurve, "ES256") if !errors.Is(err, ErrUnsupportedCurve) { t.Fatal("expected ErrUnsupportedCurve") } }) } func TestCov_signingParams(t *testing.T) { t.Run("EC", func(t *testing.T) { k := mustECKey(t, elliptic.P256()) alg, hash, ecKeySize, err := signingParams(k) if err != nil || alg != "ES256" || hash != crypto.SHA256 || ecKeySize != 32 { t.Fatalf("got %s %v %d %v", alg, hash, ecKeySize, err) } }) t.Run("RSA", func(t *testing.T) { k := mustRSAKey(t) alg, hash, ecKeySize, err := signingParams(k) if err != nil || alg != "RS256" || hash != crypto.SHA256 || ecKeySize != 0 { t.Fatalf("got %s %v %d %v", alg, hash, ecKeySize, err) } }) t.Run("Ed25519", func(t *testing.T) { k := mustEdKey(t) alg, hash, ecKeySize, err := signingParams(k) if err != nil || alg != "EdDSA" || hash != 0 || ecKeySize != 0 { t.Fatalf("got %s %v %d %v", alg, hash, ecKeySize, err) } }) t.Run("unsupported", func(t *testing.T) { _, _, _, err := signingParams(fakeSigner{pub: fakeKey{}}) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatal("expected ErrUnsupportedKeyType") } }) } func TestCov_signingInputBytes(t *testing.T) { out := signingInputBytes([]byte("hdr"), []byte("pay")) if string(out) != "hdr.pay" { t.Fatalf("got %q", out) } } func 
TestCov_digestFor(t *testing.T) { t.Run("valid", func(t *testing.T) { d, err := digestFor(crypto.SHA256, []byte("hello")) if err != nil || len(d) != 32 { t.Fatal(err) } }) t.Run("unavailable", func(t *testing.T) { _, err := digestFor(crypto.Hash(99), []byte("hello")) if !errors.Is(err, ErrUnsupportedAlg) { t.Fatal("expected ErrUnsupportedAlg") } }) } func TestCov_ecdsaDERToP1363(t *testing.T) { keySize := 32 t.Run("valid", func(t *testing.T) { type ecSig struct{ R, S *big.Int } der, _ := asn1.Marshal(ecSig{big.NewInt(42), big.NewInt(99)}) out, err := ecdsaDERToP1363(der, keySize) if err != nil || len(out) != 2*keySize { t.Fatalf("err=%v len=%d", err, len(out)) } }) t.Run("bad_asn1", func(t *testing.T) { _, err := ecdsaDERToP1363([]byte{0xff}, keySize) if err == nil { t.Fatal("expected error") } }) t.Run("trailing_bytes", func(t *testing.T) { type ecSig struct{ R, S *big.Int } der, _ := asn1.Marshal(ecSig{big.NewInt(1), big.NewInt(1)}) der = append(der, 0x00) // trailing byte _, err := ecdsaDERToP1363(der, keySize) if !errors.Is(err, ErrSignatureInvalid) { t.Fatalf("expected ErrSignatureInvalid, got %v", err) } }) t.Run("R_too_large", func(t *testing.T) { type ecSig struct{ R, S *big.Int } bigR := new(big.Int).SetBytes(make([]byte, keySize+1)) bigR.SetBit(bigR, (keySize+1)*8-1, 1) // ensure it's keySize+1 bytes der, _ := asn1.Marshal(ecSig{bigR, big.NewInt(1)}) _, err := ecdsaDERToP1363(der, keySize) if !errors.Is(err, ErrSignatureInvalid) { t.Fatalf("expected ErrSignatureInvalid, got %v", err) } }) } // ============================================================ // jwk.go // ============================================================ func TestCov_KeyType(t *testing.T) { ec := mustECKey(t, elliptic.P256()) rs := mustRSAKey(t) ed := mustEdKey(t) tests := []struct { key CryptoPublicKey expect string }{ {&ec.PublicKey, "EC"}, {&rs.PublicKey, "RSA"}, {ed.Public().(ed25519.PublicKey), "OKP"}, {fakeKey{}, ""}, } for _, tt := range tests { pk := PublicKey{Pub: tt.key} if 
got := pk.KeyType(); got != tt.expect { t.Errorf("KeyType(%T)=%q want %q", tt.key, got, tt.expect) } } } func TestCov_PublicKey_JSON_AllTypes(t *testing.T) { for _, name := range []string{"EC", "RSA", "Ed25519"} { t.Run(name, func(t *testing.T) { var pub crypto.PublicKey switch name { case "EC": pub = &mustECKey(t, elliptic.P256()).PublicKey case "RSA": pub = &mustRSAKey(t).PublicKey case "Ed25519": pub = mustEdKey(t).Public() } pk, err := FromPublicKey(pub) if err != nil { t.Fatal(err) } data, err := json.Marshal(pk) if err != nil { t.Fatal(err) } var decoded PublicKey if err := json.Unmarshal(data, &decoded); err != nil { t.Fatal(err) } if decoded.KID != pk.KID { t.Fatalf("KID mismatch: %q vs %q", decoded.KID, pk.KID) } }) } } func TestCov_PublicKey_UnmarshalJSON_Errors(t *testing.T) { t.Run("bad_json", func(t *testing.T) { var pk PublicKey if err := pk.UnmarshalJSON([]byte("not json")); err == nil { t.Fatal("expected error") } }) t.Run("unknown_kty", func(t *testing.T) { var pk PublicKey err := pk.UnmarshalJSON([]byte(`{"kty":"UNKNOWN"}`)) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) } }) } func TestCov_PrivateKey_JSON_AllTypes(t *testing.T) { for _, name := range []string{"EC", "RSA", "Ed25519"} { t.Run(name, func(t *testing.T) { var signer crypto.Signer switch name { case "EC": signer = mustECKey(t, elliptic.P384()) case "RSA": signer = mustRSAKey(t) case "Ed25519": signer = mustEdKey(t) } pk, err := FromPrivateKey(signer, "test-kid") if err != nil { t.Fatal(err) } data, err := json.Marshal(pk) if err != nil { t.Fatal(err) } var decoded PrivateKey if err := json.Unmarshal(data, &decoded); err != nil { t.Fatal(err) } if decoded.KID != pk.KID { t.Fatalf("KID mismatch: %q vs %q", decoded.KID, pk.KID) } }) } } func TestCov_PrivateKey_UnmarshalJSON_Errors(t *testing.T) { t.Run("bad_json", func(t *testing.T) { var pk PrivateKey if err := pk.UnmarshalJSON([]byte("not json")); err == nil { t.Fatal("expected 
error") } }) t.Run("missing_d", func(t *testing.T) { var pk PrivateKey err := pk.UnmarshalJSON([]byte(`{"kty":"EC","crv":"P-256","x":"a","y":"b"}`)) if !errors.Is(err, ErrMissingKeyData) { t.Fatalf("expected ErrMissingKeyData, got %v", err) } }) } func TestCov_Thumbprint(t *testing.T) { for _, name := range []string{"EC", "RSA", "Ed25519"} { t.Run(name, func(t *testing.T) { var pub crypto.PublicKey switch name { case "EC": pub = &mustECKey(t, elliptic.P521()).PublicKey case "RSA": pub = &mustRSAKey(t).PublicKey case "Ed25519": pub = mustEdKey(t).Public() } pk, _ := FromPublicKey(pub) thumb, err := pk.Thumbprint() if err != nil || thumb == "" { t.Fatalf("err=%v thumb=%q", err, thumb) } // deterministic thumb2, _ := pk.Thumbprint() if thumb != thumb2 { t.Fatal("not deterministic") } }) } t.Run("unsupported", func(t *testing.T) { pk := PublicKey{Pub: fakeKey{}} _, err := pk.Thumbprint() if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) } }) } func TestCov_PrivateKey_Thumbprint(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) pub, err := pk.PublicKey() if err != nil { t.Fatal(err) } thumb, err := pub.Thumbprint() if err != nil || thumb == "" { t.Fatal(err) } } func TestCov_PrivateKey_PublicKey(t *testing.T) { t.Run("happy", func(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) pub, err := pk.PublicKey() if err != nil { t.Fatal(err) } if pub.KID != pk.KID { t.Fatal("KID mismatch") } }) t.Run("bad_signer", func(t *testing.T) { // signer whose Public() returns a non-CryptoPublicKey pk := &PrivateKey{Priv: fakeSigner{pub: "not a key"}} _, err := pk.PublicKey() if !errors.Is(err, ErrSanityFail) { t.Fatalf("expected ErrSanityFail, got %v", err) } }) } func TestCov_NewPrivateKey(t *testing.T) { pk, err := NewPrivateKey() if err != nil { t.Fatal(err) } if pk.KID == "" { t.Fatal("expected auto KID") } // Should be Ed25519 if _, ok := pk.Priv.(ed25519.PrivateKey); !ok { t.Fatalf("expected Ed25519, got %T", pk.Priv) } 
} func TestCov_FromPublicKey(t *testing.T) { t.Run("EC", func(t *testing.T) { pk, err := FromPublicKey(&mustECKey(t, elliptic.P256()).PublicKey) if err != nil || pk.Alg != "ES256" { t.Fatalf("err=%v alg=%s", err, pk.Alg) } }) t.Run("RSA", func(t *testing.T) { pk, err := FromPublicKey(&mustRSAKey(t).PublicKey) if err != nil || pk.Alg != "RS256" { t.Fatalf("err=%v alg=%s", err, pk.Alg) } }) t.Run("Ed25519", func(t *testing.T) { pk, err := FromPublicKey(mustEdKey(t).Public()) if err != nil || pk.Alg != "EdDSA" { t.Fatalf("err=%v alg=%s", err, pk.Alg) } }) t.Run("not_crypto_key", func(t *testing.T) { _, err := FromPublicKey("not a key") if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatal("expected ErrUnsupportedKeyType") } }) t.Run("unsupported_crypto_key", func(t *testing.T) { _, err := FromPublicKey(fakeKey{}) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatal("expected ErrUnsupportedKeyType") } }) } func TestCov_FromPrivateKey(t *testing.T) { t.Run("happy", func(t *testing.T) { pk, err := FromPrivateKey(mustEdKey(t), "my-kid") if err != nil || pk.KID != "my-kid" || pk.Alg != "EdDSA" { t.Fatalf("err=%v kid=%s alg=%s", err, pk.KID, pk.Alg) } }) t.Run("unsupported", func(t *testing.T) { _, err := FromPrivateKey(fakeSigner{pub: fakeKey{}}, "") if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatal("expected ErrUnsupportedKeyType") } }) } func TestCov_ParseJWK(t *testing.T) { // Generate a key, marshal it, parse it back ed := mustEdKey(t) pk, _ := FromPublicKey(ed.Public()) data, _ := json.Marshal(pk) t.Run("ParsePublicJWK", func(t *testing.T) { got, err := ParsePublicJWK(data) if err != nil || got.KID != pk.KID { t.Fatalf("err=%v kid=%s", err, got.KID) } }) // Private key priv, _ := FromPrivateKey(ed, "k1") privData, _ := json.Marshal(priv) t.Run("ParsePrivateJWK", func(t *testing.T) { got, err := ParsePrivateJWK(privData) if err != nil || got.KID != "k1" { t.Fatalf("err=%v kid=%s", err, got.KID) } }) t.Run("ParseWellKnownJWKs", func(t *testing.T) { jwksData := 
fmt.Sprintf(`{"keys":[%s]}`, string(data)) got, err := ParseWellKnownJWKs([]byte(jwksData)) if err != nil || len(got.Keys) != 1 { t.Fatalf("err=%v len=%d", err, len(got.Keys)) } }) } func TestCov_decodeRSA_Errors(t *testing.T) { t.Run("bad_n", func(t *testing.T) { _, err := decodeRSA(rawKey{Kty: "RSA", N: "!!!", E: "AQAB"}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("bad_e", func(t *testing.T) { _, err := decodeRSA(rawKey{Kty: "RSA", N: "AAAA", E: "!!!"}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("exponent_too_small", func(t *testing.T) { // e=1 e := base64.RawURLEncoding.EncodeToString(big.NewInt(1).Bytes()) n := base64.RawURLEncoding.EncodeToString(make([]byte, 256)) _, err := decodeRSA(rawKey{Kty: "RSA", N: n, E: e}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("exponent_too_large_32bit", func(t *testing.T) { // e > MaxInt32 but fits in int64 e := base64.RawURLEncoding.EncodeToString(big.NewInt(1<<31 + 1).Bytes()) n := base64.RawURLEncoding.EncodeToString(make([]byte, 256)) _, err := decodeRSA(rawKey{Kty: "RSA", N: n, E: e}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("exponent_too_large_int64", func(t *testing.T) { // e that doesn't fit in int64 bigE := new(big.Int).Lsh(big.NewInt(1), 64) e := base64.RawURLEncoding.EncodeToString(bigE.Bytes()) n := base64.RawURLEncoding.EncodeToString(make([]byte, 256)) _, err := decodeRSA(rawKey{Kty: "RSA", N: n, E: e}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("key_too_small", func(t *testing.T) { // 512-bit key n := base64.RawURLEncoding.EncodeToString(make([]byte, 64)) e := base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()) _, err := decodeRSA(rawKey{Kty: "RSA", N: n, E: e}) if !errors.Is(err, ErrKeyTooSmall) { t.Fatalf("expected ErrKeyTooSmall, got %v", 
err) } }) } func TestCov_decodeEC_Errors(t *testing.T) { zeros32 := base64.RawURLEncoding.EncodeToString(make([]byte, 32)) t.Run("bad_x", func(t *testing.T) { _, err := decodeEC(rawKey{Kty: "EC", Crv: "P-256", X: "!!!", Y: zeros32}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("bad_y", func(t *testing.T) { _, err := decodeEC(rawKey{Kty: "EC", Crv: "P-256", X: zeros32, Y: "!!!"}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("unsupported_crv", func(t *testing.T) { _, err := decodeEC(rawKey{Kty: "EC", Crv: "P-192", X: zeros32, Y: zeros32}) if !errors.Is(err, ErrUnsupportedCurve) { t.Fatalf("expected ErrUnsupportedCurve, got %v", err) } }) t.Run("x_too_long", func(t *testing.T) { longX := base64.RawURLEncoding.EncodeToString(make([]byte, 33)) _, err := decodeEC(rawKey{Kty: "EC", Crv: "P-256", X: longX, Y: zeros32}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("y_too_long", func(t *testing.T) { longY := base64.RawURLEncoding.EncodeToString(make([]byte, 33)) _, err := decodeEC(rawKey{Kty: "EC", Crv: "P-256", X: zeros32, Y: longY}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("not_on_curve", func(t *testing.T) { // (0, 0) is not on P-256 _, err := decodeEC(rawKey{Kty: "EC", Crv: "P-256", X: zeros32, Y: zeros32}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) } func TestCov_decodeOKP_Errors(t *testing.T) { t.Run("wrong_crv", func(t *testing.T) { _, err := decodeOKP(rawKey{Kty: "OKP", Crv: "X25519"}) if !errors.Is(err, ErrUnsupportedCurve) { t.Fatalf("expected ErrUnsupportedCurve, got %v", err) } }) t.Run("bad_x", func(t *testing.T) { _, err := decodeOKP(rawKey{Kty: "OKP", Crv: "Ed25519", X: "!!!"}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("wrong_size", func(t *testing.T) { x := 
base64.RawURLEncoding.EncodeToString(make([]byte, 31)) _, err := decodeOKP(rawKey{Kty: "OKP", Crv: "Ed25519", X: x}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) } func TestCov_decodePrivate_Errors(t *testing.T) { t.Run("missing_d", func(t *testing.T) { _, err := decodePrivate(rawKey{Kty: "EC", Crv: "P-256"}) if !errors.Is(err, ErrMissingKeyData) { t.Fatalf("expected ErrMissingKeyData, got %v", err) } }) t.Run("unknown_kty", func(t *testing.T) { _, err := decodePrivate(rawKey{Kty: "UNKNOWN", D: "AA"}) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) } }) t.Run("OKP_wrong_crv", func(t *testing.T) { _, err := decodePrivate(rawKey{Kty: "OKP", Crv: "X25519", D: "AA"}) if !errors.Is(err, ErrUnsupportedCurve) { t.Fatalf("expected ErrUnsupportedCurve, got %v", err) } }) t.Run("Ed25519_wrong_seed_size", func(t *testing.T) { d := base64.RawURLEncoding.EncodeToString(make([]byte, 31)) _, err := decodePrivate(rawKey{Kty: "OKP", Crv: "Ed25519", D: d}) if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) t.Run("EC_bad_d", func(t *testing.T) { _, err := decodePrivate(rawKey{Kty: "EC", Crv: "P-256", D: "!!!"}) if err == nil { t.Fatal("expected error") } }) t.Run("RSA_bad_d", func(t *testing.T) { n := base64.RawURLEncoding.EncodeToString(make([]byte, 256)) e := base64.RawURLEncoding.EncodeToString(big.NewInt(65537).Bytes()) _, err := decodePrivate(rawKey{Kty: "RSA", N: n, E: e, D: "!!!"}) if err == nil { t.Fatal("expected error") } }) t.Run("OKP_bad_d", func(t *testing.T) { _, err := decodePrivate(rawKey{Kty: "OKP", Crv: "Ed25519", D: "!!!"}) if err == nil { t.Fatal("expected error") } }) t.Run("auto_kid", func(t *testing.T) { // Valid Ed25519 private key with no KID - should auto-compute seed := make([]byte, ed25519.SeedSize) rand.Read(seed) d := base64.RawURLEncoding.EncodeToString(seed) x := 
base64.RawURLEncoding.EncodeToString(ed25519.NewKeyFromSeed(seed).Public().(ed25519.PublicKey)) pk, err := decodePrivate(rawKey{Kty: "OKP", Crv: "Ed25519", D: d, X: x}) if err != nil { t.Fatal(err) } if pk.KID == "" { t.Fatal("expected auto KID") } }) } func TestCov_decodeOne_AutoKID(t *testing.T) { // decodeOne with no kid should auto-compute from thumbprint pub := mustEdKey(t).Public().(ed25519.PublicKey) x := base64.RawURLEncoding.EncodeToString([]byte(pub)) pk, err := decodeOne(rawKey{Kty: "OKP", Crv: "Ed25519", X: x}) if err != nil { t.Fatal(err) } if pk.KID == "" { t.Fatal("expected auto KID") } } func TestCov_decodeB64Field(t *testing.T) { t.Run("valid", func(t *testing.T) { b, err := decodeB64Field("EC", "kid1", "d", base64.RawURLEncoding.EncodeToString([]byte{1, 2, 3})) if err != nil || len(b) != 3 { t.Fatal(err) } }) t.Run("bad_base64", func(t *testing.T) { _, err := decodeB64Field("EC", "kid1", "d", "!!!") if !errors.Is(err, ErrInvalidKey) { t.Fatalf("expected ErrInvalidKey, got %v", err) } }) } func TestCov_toPublicKeyOps(t *testing.T) { tests := []struct { in []string expect []string }{ {nil, nil}, {[]string{"sign"}, []string{"verify"}}, {[]string{"decrypt"}, []string{"encrypt"}}, {[]string{"unwrapKey"}, []string{"wrapKey"}}, {[]string{"verify", "encrypt", "wrapKey"}, []string{"verify", "encrypt", "wrapKey"}}, {[]string{"deriveKey"}, nil}, // unrecognized ops dropped } for _, tt := range tests { got := toPublicKeyOps(tt.in) if len(got) != len(tt.expect) { t.Errorf("toPublicKeyOps(%v)=%v want %v", tt.in, got, tt.expect) } } } func TestCov_encode_Unsupported(t *testing.T) { _, err := encode(PublicKey{Pub: fakeKey{}}) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) } } func TestCov_encodePrivate_Unsupported(t *testing.T) { // Signer whose Public() returns a real ed25519 key but signer type is custom edKey := mustEdKey(t) pk := PrivateKey{ Priv: fakeSigner{pub: edKey.Public()}, KID: "test", } _, err := 
encodePrivate(pk)
	if !errors.Is(err, ErrUnsupportedKeyType) {
		t.Fatalf("expected ErrUnsupportedKeyType, got %v", err)
	}
}

// TestCov_encodePrivate_RSA_NoPrimes checks that an RSA key with its Primes
// cleared still encodes D but leaves P and Q empty.
func TestCov_encodePrivate_RSA_NoPrimes(t *testing.T) {
	rsaKey := mustRSAKey(t)
	rsaKey.Primes = nil // remove primes
	pk := PrivateKey{Priv: rsaKey, KID: "test"}
	rk, err := encodePrivate(pk)
	if err != nil {
		t.Fatal(err)
	}
	// Should have D but no P/Q
	if rk.D == "" {
		t.Fatal("expected D")
	}
	if rk.P != "" || rk.Q != "" {
		t.Fatal("expected no P/Q without primes")
	}
}

// ============================================================
// sign.go
// ============================================================

// TestCov_NewSigner covers NewSigner's rejection of empty, nil and
// mismatched keys, plus the auto-KID, multi-key and retired-key paths.
func TestCov_NewSigner(t *testing.T) {
	t.Run("empty_keys", func(t *testing.T) {
		_, err := NewSigner(nil)
		if !errors.Is(err, ErrNoSigningKey) {
			t.Fatalf("expected ErrNoSigningKey, got %v", err)
		}
	})
	t.Run("nil_key", func(t *testing.T) {
		_, err := NewSigner([]*PrivateKey{nil})
		if !errors.Is(err, ErrNoSigningKey) {
			t.Fatalf("expected ErrNoSigningKey, got %v", err)
		}
	})
	t.Run("nil_privkey", func(t *testing.T) {
		_, err := NewSigner([]*PrivateKey{{}})
		if !errors.Is(err, ErrNoSigningKey) {
			t.Fatalf("expected ErrNoSigningKey, got %v", err)
		}
	})
	t.Run("alg_conflict", func(t *testing.T) {
		pk := mustFromPrivate(t, mustEdKey(t))
		pk.Alg = "RS256" // wrong alg for Ed25519
		_, err := NewSigner([]*PrivateKey{pk})
		if !errors.Is(err, ErrAlgConflict) {
			t.Fatalf("expected ErrAlgConflict, got %v", err)
		}
	})
	t.Run("wrong_use", func(t *testing.T) {
		pk := mustFromPrivate(t, mustEdKey(t))
		pk.Use = "enc"
		_, err := NewSigner([]*PrivateKey{pk})
		if err == nil {
			t.Fatal("expected error for use=enc")
		}
	})
	t.Run("auto_kid", func(t *testing.T) {
		pk := mustFromPrivate(t, mustEdKey(t))
		pk.KID = "" // clear to trigger auto-compute
		s, err := NewSigner([]*PrivateKey{pk})
		if err != nil {
			t.Fatal(err)
		}
		if s.keys[0].KID == "" {
			t.Fatal("expected auto KID")
		}
	})
	t.Run("multiple_keys", func(t *testing.T) {
		pk1 := mustFromPrivate(t, mustEdKey(t))
		pk2 := mustFromPrivate(t, mustECKey(t, elliptic.P256()))
		s, err :=
NewSigner([]*PrivateKey{pk1, pk2})
		if err != nil {
			t.Fatal(err)
		}
		if len(s.keys) != 2 {
			t.Fatal("expected 2 keys")
		}
	})
	t.Run("retired_keys", func(t *testing.T) {
		pk := mustFromPrivate(t, mustEdKey(t))
		// Check the error: retired is dereferenced on the next call.
		retired, err := FromPublicKey(mustEdKey(t).Public())
		if err != nil {
			t.Fatal(err)
		}
		s, err := NewSigner([]*PrivateKey{pk}, *retired)
		if err != nil {
			t.Fatal(err)
		}
		if len(s.Keys) != 2 {
			t.Fatalf("expected 2 JWKS keys, got %d", len(s.Keys))
		}
	})
}

// TestCov_SignJWT signs with a single Ed25519 key and covers the header
// KID/Alg mismatch paths plus a signer whose key has no private part.
func TestCov_SignJWT(t *testing.T) {
	pk := mustFromPrivate(t, mustEdKey(t))
	s := mustSigner(t, pk)
	t.Run("happy", func(t *testing.T) {
		jws, err := New(goodClaims())
		if err != nil {
			t.Fatal(err)
		}
		if err := s.SignJWT(jws); err != nil {
			t.Fatal(err)
		}
		if jws.GetHeader().Alg != "EdDSA" {
			t.Fatal("expected EdDSA")
		}
	})
	t.Run("with_kid", func(t *testing.T) {
		jws, err := New(goodClaims())
		if err != nil {
			t.Fatal(err)
		}
		jws.header.KID = pk.KID
		if err := s.SignJWT(jws); err != nil {
			t.Fatal(err)
		}
	})
	t.Run("unknown_kid", func(t *testing.T) {
		jws, err := New(goodClaims())
		if err != nil {
			t.Fatal(err)
		}
		jws.header.KID = "nonexistent"
		err = s.SignJWT(jws)
		if !errors.Is(err, ErrUnknownKID) {
			t.Fatalf("expected ErrUnknownKID, got %v", err)
		}
	})
	t.Run("alg_conflict", func(t *testing.T) {
		jws, err := New(goodClaims())
		if err != nil {
			t.Fatal(err)
		}
		jws.header.Alg = "RS256"
		err = s.SignJWT(jws)
		if !errors.Is(err, ErrAlgConflict) {
			t.Fatalf("expected ErrAlgConflict, got %v", err)
		}
	})
	t.Run("nil_privkey", func(t *testing.T) {
		// Construct signer directly (bypass NewSigner validation)
		bad := &Signer{keys: []PrivateKey{{KID: "test"}}}
		jws, err := New(goodClaims())
		if err != nil {
			t.Fatal(err)
		}
		err = bad.SignJWT(jws)
		if !errors.Is(err, ErrNoSigningKey) {
			t.Fatalf("expected ErrNoSigningKey, got %v", err)
		}
	})
}

// TestCov_SignRaw covers SignRaw with a payload, with a nil payload, and
// with a header whose Alg conflicts with the signing key.
func TestCov_SignRaw(t *testing.T) {
	pk := mustFromPrivate(t, mustEdKey(t))
	s := mustSigner(t, pk)
	t.Run("happy", func(t *testing.T) {
		hdr := &RFCHeader{Typ: "JWT"}
		raw, err := s.SignRaw(hdr, []byte(`{"sub":"user"}`))
		if err != nil {
			t.Fatal(err)
		}
		if len(raw.Signature) == 0 {
			t.Fatal("expected signature")
		}
	})
	t.Run("nil_payload", func(t *testing.T) {
		hdr := &RFCHeader{Typ: "JWT"}
		raw, err := s.SignRaw(hdr, nil)
		if err != nil {
t.Fatal(err) } if raw.Payload == nil { t.Fatal("expected non-nil payload") } }) t.Run("alg_conflict", func(t *testing.T) { hdr := &RFCHeader{Alg: "RS256"} _, err := s.SignRaw(hdr, nil) if !errors.Is(err, ErrAlgConflict) { t.Fatal("expected ErrAlgConflict") } }) } func TestCov_Sign(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) jws, err := s.Sign(goodClaims()) if err != nil { t.Fatal(err) } if jws.GetHeader().Alg != "EdDSA" { t.Fatal("expected EdDSA") } } func TestCov_SignToString(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) tok, err := s.SignToString(goodClaims()) if err != nil { t.Fatal(err) } if strings.Count(tok, ".") != 2 { t.Fatal("expected 3 segments") } } func TestCov_Signer_Verifier(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) v := s.Verifier() if v == nil { t.Fatal("expected verifier") } // Should work for round-trip tok := mustSignStr(t, s, goodClaims()) if _, err := v.VerifyJWT(tok); err != nil { t.Fatal(err) } } func TestCov_RoundRobin(t *testing.T) { pk1 := mustFromPrivate(t, mustEdKey(t)) pk2 := mustFromPrivate(t, mustECKey(t, elliptic.P256())) s := mustSigner(t, pk1, pk2) // Sign twice, should use different keys tok1 := mustSignStr(t, s, goodClaims()) tok2 := mustSignStr(t, s, goodClaims()) jws1, _ := Decode(tok1) jws2, _ := Decode(tok2) if jws1.GetHeader().KID == jws2.GetHeader().KID { t.Fatal("expected different KIDs from round-robin") } } // ============================================================ // verify.go // ============================================================ func TestCov_NewVerifier(t *testing.T) { ed := mustEdKey(t) pub, _ := FromPublicKey(ed.Public()) t.Run("happy", func(t *testing.T) { v, err := NewVerifier([]PublicKey{*pub}) if err != nil || len(v.pubKeys) != 1 { t.Fatal(err) } }) t.Run("dedup", func(t *testing.T) { v, err := NewVerifier([]PublicKey{*pub, *pub}) if err != nil || len(v.pubKeys) != 1 { t.Fatalf("expected dedup to 
1, got %d", len(v.pubKeys))
		}
	})
	t.Run("nil_rejected", func(t *testing.T) {
		_, err := NewVerifier(nil)
		if !errors.Is(err, ErrNoVerificationKey) {
			t.Fatalf("expected ErrNoVerificationKey, got %v", err)
		}
	})
	t.Run("empty_rejected", func(t *testing.T) {
		_, err := NewVerifier([]PublicKey{})
		if !errors.Is(err, ErrNoVerificationKey) {
			t.Fatalf("expected ErrNoVerificationKey, got %v", err)
		}
	})
	t.Run("same_kid_different_keys", func(t *testing.T) {
		// Two different keys with the same KID should both be kept
		pub1, err := FromPublicKey(mustEdKey(t).Public())
		if err != nil {
			t.Fatal(err)
		}
		pub2, err := FromPublicKey(mustEdKey(t).Public())
		if err != nil {
			t.Fatal(err)
		}
		pub1.KID = "shared"
		pub2.KID = "shared"
		v, err := NewVerifier([]PublicKey{*pub1, *pub2})
		if err != nil {
			t.Fatal(err)
		}
		if len(v.pubKeys) != 2 {
			t.Fatalf("expected 2 keys (same KID, different material), got %d", len(v.pubKeys))
		}
	})
}

// TestCov_PublicKeys checks that PublicKeys returns the single key the
// verifier was built with.
func TestCov_PublicKeys(t *testing.T) {
	ed := mustEdKey(t)
	pub, err := FromPublicKey(ed.Public())
	if err != nil {
		t.Fatal(err)
	}
	v, err := NewVerifier([]PublicKey{*pub})
	if err != nil {
		t.Fatal(err)
	}
	keys := v.PublicKeys()
	if len(keys) != 1 {
		t.Fatal("expected 1 key")
	}
}

// TestCov_Verify covers Verify with the matching key, with a verifier that
// lacks the token's KID, and with a different key registered under that KID.
// Setup errors are checked so a broken fixture fails with its own message.
func TestCov_Verify(t *testing.T) {
	pk := mustFromPrivate(t, mustEdKey(t))
	s := mustSigner(t, pk)
	tok := mustSignStr(t, s, goodClaims())
	t.Run("happy", func(t *testing.T) {
		jws, err := Decode(tok)
		if err != nil {
			t.Fatal(err)
		}
		v := s.Verifier()
		if err := v.Verify(jws); err != nil {
			t.Fatal(err)
		}
	})
	t.Run("unknown_kid", func(t *testing.T) {
		jws, err := Decode(tok)
		if err != nil {
			t.Fatal(err)
		}
		other, err := FromPublicKey(mustEdKey(t).Public())
		if err != nil {
			t.Fatal(err)
		}
		v, err := NewVerifier([]PublicKey{*other})
		if err != nil {
			t.Fatal(err)
		}
		err = v.Verify(jws)
		if !errors.Is(err, ErrUnknownKID) {
			t.Fatalf("expected ErrUnknownKID, got %v", err)
		}
	})
	t.Run("wrong_key_matching_kid", func(t *testing.T) {
		// Verifier has a key with the same KID but different material
		jws, err := Decode(tok)
		if err != nil {
			t.Fatal(err)
		}
		kid := jws.GetHeader().KID
		other, err := FromPublicKey(mustEdKey(t).Public())
		if err != nil {
			t.Fatal(err)
		}
		other.KID = kid // same KID, different key material
		v, err := NewVerifier([]PublicKey{*other})
		if err != nil {
			t.Fatal(err)
		}
		err = v.Verify(jws)
		if !errors.Is(err, ErrSignatureInvalid) {
			t.Fatalf("expected ErrSignatureInvalid, got %v",
err)
		}
	})
}

// TestCov_verifyOneKey_AllAlgs round-trips a token through sign and verify
// for each listed key type, then calls verifyOneKey directly with
// mismatched alg/key combinations and a too-short EC signature.
func TestCov_verifyOneKey_AllAlgs(t *testing.T) {
	for _, tc := range []struct {
		name   string
		signer crypto.Signer
	}{
		{"ES256", mustECKey(t, elliptic.P256())},
		{"ES384", mustECKey(t, elliptic.P384())},
		{"ES512", mustECKey(t, elliptic.P521())},
		{"RS256", mustRSAKey(t)},
		{"EdDSA", mustEdKey(t)},
	} {
		t.Run(tc.name, func(t *testing.T) {
			pk := mustFromPrivate(t, tc.signer)
			s := mustSigner(t, pk)
			tok := mustSignStr(t, s, goodClaims())
			v := s.Verifier()
			// A decode failure must not reach Verify with a nil token.
			jws, err := Decode(tok)
			if err != nil {
				t.Fatal(err)
			}
			if err := v.Verify(jws); err != nil {
				t.Fatal(err)
			}
		})
	}
	t.Run("unsupported_alg", func(t *testing.T) {
		h := RFCHeader{Alg: "HS256", KID: "k"}
		err := verifyOneKey(h, mustEdKey(t).Public().(ed25519.PublicKey), []byte("input"), []byte("sig"))
		if !errors.Is(err, ErrUnsupportedAlg) {
			t.Fatalf("expected ErrUnsupportedAlg, got %v", err)
		}
	})
	t.Run("wrong_key_type_EC", func(t *testing.T) {
		h := RFCHeader{Alg: "ES256", KID: "k"}
		err := verifyOneKey(h, mustEdKey(t).Public().(ed25519.PublicKey), []byte("input"), []byte("sig"))
		if !errors.Is(err, ErrAlgConflict) {
			t.Fatalf("expected ErrAlgConflict, got %v", err)
		}
	})
	t.Run("wrong_key_type_RSA", func(t *testing.T) {
		h := RFCHeader{Alg: "RS256", KID: "k"}
		err := verifyOneKey(h, mustEdKey(t).Public().(ed25519.PublicKey), []byte("input"), []byte("sig"))
		if !errors.Is(err, ErrAlgConflict) {
			t.Fatalf("expected ErrAlgConflict, got %v", err)
		}
	})
	t.Run("wrong_key_type_EdDSA", func(t *testing.T) {
		h := RFCHeader{Alg: "EdDSA", KID: "k"}
		err := verifyOneKey(h, &mustRSAKey(t).PublicKey, []byte("input"), []byte("sig"))
		if !errors.Is(err, ErrAlgConflict) {
			t.Fatalf("expected ErrAlgConflict, got %v", err)
		}
	})
	t.Run("EC_wrong_sig_length", func(t *testing.T) {
		h := RFCHeader{Alg: "ES256", KID: "k"}
		err := verifyOneKey(h, &mustECKey(t, elliptic.P256()).PublicKey, []byte("input"), []byte("short"))
		if !errors.Is(err, ErrSignatureInvalid) {
			t.Fatalf("expected ErrSignatureInvalid, got %v", err)
		}
	})
}

// TestCov_VerifyJWT covers VerifyJWT on a good token, a malformed token, and
// a token whose trailing signature characters were replaced.
func TestCov_VerifyJWT(t *testing.T) {
	pk := mustFromPrivate(t, mustEdKey(t))
	s := mustSigner(t, pk)
	tok := mustSignStr(t, s, goodClaims())
	v
:= s.Verifier()
	t.Run("happy", func(t *testing.T) {
		verified, err := v.VerifyJWT(tok)
		if err != nil {
			t.Fatal(err)
		}
		if alg := verified.GetHeader().Alg; alg != "EdDSA" {
			t.Fatal("wrong alg")
		}
	})
	t.Run("bad_token", func(t *testing.T) {
		if _, err := v.VerifyJWT("bad"); err == nil {
			t.Fatal("expected error")
		}
	})
	t.Run("bad_sig", func(t *testing.T) {
		// Swap the last four characters of the token for "AAAA".
		tampered := tok[:len(tok)-4] + "AAAA"
		if _, err := v.VerifyJWT(tampered); err == nil {
			t.Fatal("expected error")
		}
	})
}

// ============================================================
// validate.go
// ============================================================

// TestCov_ValidationError exercises Error, Unwrap and errors.Is on a
// ValidationError wrapping ErrAfterExp.
func TestCov_ValidationError(t *testing.T) {
	valErr := &ValidationError{
		Code:        "token_expired",
		Description: "exp: expired 5m ago",
		Err:         ErrAfterExp,
	}
	if msg := valErr.Error(); msg != "exp: expired 5m ago" {
		t.Fatalf("Error()=%q", msg)
	}
	if inner := valErr.Unwrap(); inner != ErrAfterExp {
		t.Fatal("Unwrap mismatch")
	}
	if !errors.Is(valErr, ErrAfterExp) {
		t.Fatal("errors.Is should match ErrAfterExp")
	}
	// ErrInvalidClaim is expected to be reachable through ErrAfterExp.
	if !errors.Is(valErr, ErrInvalidClaim) {
		t.Fatal("errors.Is should match ErrInvalidClaim via chain")
	}
}

// TestCov_ValidationErrors checks what ValidationErrors extracts from nil,
// joined, and single errors.
func TestCov_ValidationErrors(t *testing.T) {
	t.Run("nil", func(t *testing.T) {
		found := ValidationErrors(nil)
		if found != nil {
			t.Fatal("expected nil")
		}
	})
	t.Run("joined_with_VEs", func(t *testing.T) {
		first := &ValidationError{Code: "a", Err: ErrAfterExp}
		second := &ValidationError{Code: "b", Err: ErrMissingClaim}
		found := ValidationErrors(errors.Join(first, second))
		if n := len(found); n != 2 {
			t.Fatalf("expected 2, got %d", n)
		}
	})
	t.Run("joined_no_VEs", func(t *testing.T) {
		plain := errors.Join(fmt.Errorf("plain error"))
		found := ValidationErrors(plain)
		if found != nil {
			t.Fatal("expected nil for no VEs")
		}
	})
	t.Run("single_VE", func(t *testing.T) {
		only := &ValidationError{Code: "a", Err: ErrAfterExp}
		found := ValidationErrors(only)
		if n := len(found); n != 1 {
			t.Fatalf("expected 1, got %d", n)
		}
	})
	t.Run("single_non_VE", func(t *testing.T) {
		if ves := ValidationErrors(fmt.Errorf("plain")); ves != nil {
t.Fatal("expected nil")
		}
	})
}

// TestCov_GetOAuth2Error checks the code GetOAuth2Error reports for nil,
// plain, single and joined errors.
func TestCov_GetOAuth2Error(t *testing.T) {
	t.Run("nil", func(t *testing.T) {
		got := GetOAuth2Error(nil)
		if got != "" {
			t.Fatalf("expected empty, got %q", got)
		}
	})
	t.Run("no_VEs", func(t *testing.T) {
		got := GetOAuth2Error(fmt.Errorf("plain"))
		if got != "" {
			t.Fatalf("expected empty, got %q", got)
		}
	})
	t.Run("invalid_token", func(t *testing.T) {
		expired := &ValidationError{Code: "token_expired", Err: ErrAfterExp}
		got := GetOAuth2Error(expired)
		if got != "invalid_token" {
			t.Fatalf("expected invalid_token, got %q", got)
		}
	})
	t.Run("insufficient_scope", func(t *testing.T) {
		scopeErr := &ValidationError{Code: "insufficient_scope", Err: ErrInsufficientScope}
		got := GetOAuth2Error(scopeErr)
		if got != "insufficient_scope" {
			t.Fatalf("expected insufficient_scope, got %q", got)
		}
	})
	t.Run("server_error", func(t *testing.T) {
		misconfigured := &ValidationError{Code: "server_error", Err: ErrMisconfigured}
		got := GetOAuth2Error(misconfigured)
		if got != "server_error" {
			t.Fatalf("expected server_error, got %q", got)
		}
	})
	t.Run("server_error_wins_over_scope", func(t *testing.T) {
		scopeErr := &ValidationError{Err: ErrInsufficientScope}
		misconfigured := &ValidationError{Err: ErrMisconfigured}
		got := GetOAuth2Error(errors.Join(scopeErr, misconfigured))
		if got != "server_error" {
			t.Fatalf("expected server_error, got %q", got)
		}
	})
}

// TestCov_codeFor is a table test pairing each sentinel error with the code
// string codeFor should return, plus an unrecognized error.
func TestCov_codeFor(t *testing.T) {
	cases := []struct {
		err  error
		want string
	}{
		{ErrAfterExp, "token_expired"},
		{ErrBeforeNBf, "token_not_yet_valid"},
		{ErrBeforeIAt, "future_issued_at"},
		{ErrBeforeAuthTime, "future_auth_time"},
		{ErrAfterAuthMaxAge, "auth_time_exceeded"},
		{ErrInsufficientScope, "insufficient_scope"},
		{ErrMissingClaim, "missing_claim"},
		{ErrInvalidTyp, "invalid_typ"},
		{ErrInvalidClaim, "invalid_claim"},
		{ErrMisconfigured, "server_error"},
		{fmt.Errorf("unknown"), "unknown_error"},
	}
	for i := range cases {
		c := cases[i]
		if got := codeFor(c.err); got != c.want {
			t.Errorf("codeFor(%v)=%q want %q", c.err, got, c.want)
		}
	}
}

func
TestCov_isTimeSentinel(t *testing.T) {
	timeErrs := []error{ErrAfterExp, ErrBeforeNBf, ErrBeforeIAt, ErrBeforeAuthTime, ErrAfterAuthMaxAge}
	for _, e := range timeErrs {
		if !isTimeSentinel(e) {
			t.Errorf("expected isTimeSentinel(%v)=true", e)
		}
	}
	otherErrs := []error{ErrMissingClaim, ErrInvalidClaim, ErrMisconfigured, ErrInsufficientScope}
	for _, e := range otherErrs {
		if isTimeSentinel(e) {
			t.Errorf("expected isTimeSentinel(%v)=false", e)
		}
	}
}

// TestCov_formatDuration pins formatDuration's output for a multi-unit,
// a sub-second, a zero and a negative duration.
func TestCov_formatDuration(t *testing.T) {
	t.Run("full", func(t *testing.T) {
		span := 25*time.Hour + 30*time.Minute + 45*time.Second
		if got := formatDuration(span); got != "1d 1h 30m 45s" {
			t.Fatalf("got %q", got)
		}
	})
	t.Run("sub_second", func(t *testing.T) {
		if got := formatDuration(500 * time.Millisecond); got != "500ms" {
			t.Fatalf("got %q", got)
		}
	})
	t.Run("zero", func(t *testing.T) {
		if got := formatDuration(0); got != "0ms" {
			t.Fatalf("got %q", got)
		}
	})
	t.Run("negative", func(t *testing.T) {
		// The expected string carries no sign.
		if got := formatDuration(-5 * time.Second); got != "5s" {
			t.Fatalf("got %q", got)
		}
	})
}

// TestCov_resolveSkew checks resolveSkew for zero, negative and positive
// inputs.
func TestCov_resolveSkew(t *testing.T) {
	fromZero := resolveSkew(0)
	if fromZero != defaultGracePeriod {
		t.Fatalf("zero: got %v want %v", fromZero, defaultGracePeriod)
	}
	fromNegative := resolveSkew(-1)
	if fromNegative != 0 {
		t.Fatalf("negative: got %v want 0", fromNegative)
	}
	fromPositive := resolveSkew(5 * time.Second)
	if fromPositive != 5*time.Second {
		t.Fatalf("positive: got %v want 5s", fromPositive)
	}
}

// TestCov_NewIDTokenValidator checks which Checks bits are set when
// issuer/audience lists are supplied and when they are nil.
func TestCov_NewIDTokenValidator(t *testing.T) {
	configured := NewIDTokenValidator([]string{"iss"}, []string{"aud"}, []string{"azp"})
	if configured.Checks&ChecksConfigured == 0 {
		t.Fatal("expected ChecksConfigured")
	}
	if configured.Checks&CheckIss == 0 || configured.Checks&CheckAud == 0 {
		t.Fatal("expected CheckIss and CheckAud when slices provided")
	}
	if configured.Checks&CheckSub == 0 || configured.Checks&CheckExp == 0 {
		t.Fatal("expected CheckSub and CheckExp")
	}
	// With nil iss/aud the corresponding bits must stay clear.
	bare := NewIDTokenValidator(nil, nil, nil)
	if bare.Checks&CheckIss != 0 {
		t.Fatal("expected no CheckIss for nil iss")
	}
	if bare.Checks&CheckAud != 0 {
		t.Fatal("expected no CheckAud for nil aud")
	}
}

func
TestCov_NewAccessTokenValidator(t *testing.T) {
	configured := NewAccessTokenValidator([]string{"iss"}, []string{"aud"})
	if configured.Checks&ChecksConfigured == 0 {
		t.Fatal("expected ChecksConfigured")
	}
	if configured.Checks&CheckJTI == 0 || configured.Checks&CheckClientID == 0 {
		t.Fatal("expected CheckJTI and CheckClientID for access token")
	}
	bare := NewAccessTokenValidator(nil, nil)
	if bare.Checks&CheckIss != 0 {
		t.Fatal("expected no CheckIss for nil iss")
	}
}

// TestCov_NewAccessTokenValidator_Scopes covers the three scope modes of
// NewAccessTokenValidator: no scope arguments, an empty spread, and a
// specific required scope.
func TestCov_NewAccessTokenValidator_Scopes(t *testing.T) {
	iss := []string{"https://example.com"}
	aud := []string{"https://api.example.com"}
	t.Run("nil_no_scope_check", func(t *testing.T) {
		// Without scope arguments the CheckScope bit stays clear.
		validator := NewAccessTokenValidator(iss, aud)
		if validator.Checks&CheckScope != 0 {
			t.Fatal("expected CheckScope not set for nil scopes")
		}
		if validator.RequiredScopes != nil {
			t.Fatal("expected nil RequiredScopes")
		}
		// A token with no scope claim must still validate.
		unscoped := goodClaims()
		unscoped.Scope = nil
		unscoped.JTI = "jti-x"
		err := validator.Validate(nil, unscoped, testNow)
		if err != nil {
			t.Fatalf("expected no error without scope check, got %v", err)
		}
	})
	t.Run("empty_presence_only", func(t *testing.T) {
		// An empty spread turns CheckScope on with no specific requirement.
		validator := NewAccessTokenValidator(iss, aud, []string{}...)
		if validator.Checks&CheckScope == 0 {
			t.Fatal("expected CheckScope set for empty non-nil scopes")
		}
		if validator.RequiredScopes == nil {
			t.Fatal("expected non-nil RequiredScopes")
		}
		// Scope present: passes.
		err := validator.Validate(nil, goodClaims(), testNow)
		if err != nil {
			t.Fatalf("expected no error with scope present, got %v", err)
		}
		// Scope absent: reported as a missing claim.
		unscoped := goodClaims()
		unscoped.Scope = nil
		if err := validator.Validate(nil, unscoped, testNow); !errors.Is(err, ErrMissingClaim) {
			t.Fatalf("expected ErrMissingClaim for absent scope, got %v", err)
		}
	})
	t.Run("specific_scope", func(t *testing.T) {
		// Specific scope: CheckScope set, token must contain "openid".
v := NewAccessTokenValidator(iss, aud, "openid")
		if v.Checks&CheckScope == 0 {
			t.Fatal("expected CheckScope set")
		}
		// Scope list includes "openid": passes.
		err := v.Validate(nil, goodClaims(), testNow)
		if err != nil {
			t.Fatalf("expected no error, got %v", err)
		}
		// Scope list lacks "openid": insufficient scope.
		profileOnly := goodClaims()
		profileOnly.Scope = SpaceDelimited{"profile"}
		if err := v.Validate(nil, profileOnly, testNow); !errors.Is(err, ErrInsufficientScope) {
			t.Fatalf("expected ErrInsufficientScope, got %v", err)
		}
	})
}

// TestCov_Validate_Unconfigured checks that a zero-value Validator reports
// ErrMisconfigured.
func TestCov_Validate_Unconfigured(t *testing.T) {
	zero := &Validator{} // zero value
	if err := zero.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrMisconfigured) {
		t.Fatalf("expected ErrMisconfigured, got %v", err)
	}
}

// TestCov_Validate_AllPass checks that goodClaims passes an ID-token
// validator configured with matching iss, aud and azp.
func TestCov_Validate_AllPass(t *testing.T) {
	validator := NewIDTokenValidator([]string{"https://example.com"}, []string{"https://api.example.com"}, []string{"client-abc"})
	if err := validator.Validate(nil, goodClaims(), testNow); err != nil {
		t.Fatalf("expected nil, got %v", err)
	}
}

// TestCov_Validate_TimeAnnotation checks that an expiry failure message
// mentions "server time".
func TestCov_Validate_TimeAnnotation(t *testing.T) {
	validator := NewIDTokenValidator([]string{"https://example.com"}, []string{"https://api.example.com"}, nil)
	expired := goodClaims()
	expired.Exp = testNow.Add(-time.Hour).Unix() // expired
	err := validator.Validate(nil, expired, testNow)
	if err == nil {
		t.Fatal("expected error")
	}
	// The message for a time failure should carry the "server time" note.
	if msg := err.Error(); !strings.Contains(msg, "server time") {
		t.Fatalf("expected server time annotation, got: %s", msg)
	}
}

// TestCov_Validate_ExplicitConfigForcesChecks checks that a non-empty Iss
// list is enforced even when no CheckIss bit is set.
func TestCov_Validate_ExplicitConfigForcesChecks(t *testing.T) {
	// Only ChecksConfigured is set; the Iss list alone should trigger the check.
	validator := &Validator{
		Checks: ChecksConfigured,
		Iss:    []string{"https://other.com"},
	}
	if err := validator.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrInvalidClaim) {
		t.Fatalf("expected ErrInvalidClaim from forced iss check, got %v", err)
	}
}

// TestCov_Validate_AllChecks turns on every check bit and expects
// goodClaims (with nbf set) to pass.
func TestCov_Validate_AllChecks(t *testing.T) {
	// Enable every check via the bitmask
	v := &Validator{
Checks: ChecksConfigured | CheckIss | CheckSub | CheckAud | CheckExp | CheckNBf | CheckIAt | CheckJTI | CheckClientID | CheckAuthTime | CheckAzP | CheckScope,
		Iss:            []string{"https://example.com"},
		Aud:            []string{"https://api.example.com"},
		AzP:            []string{"client-abc"},
		RequiredScopes: []string{"openid"},
	}
	withNBf := goodClaims()
	withNBf.NBf = testNow.Add(-time.Minute).Unix()
	if err := v.Validate(nil, withNBf, testNow); err != nil {
		t.Fatalf("expected all pass, got %v", err)
	}
}

// TestCov_Validate_PreExistingErrors checks that errors passed into Validate
// appear in the returned error.
func TestCov_Validate_PreExistingErrors(t *testing.T) {
	validator := NewIDTokenValidator([]string{"https://example.com"}, []string{"https://api.example.com"}, nil)
	earlier := []error{&ValidationError{Code: "invalid_claim", Description: "typ wrong", Err: ErrInvalidClaim}}
	err := validator.Validate(earlier, goodClaims(), testNow)
	// The earlier error alone must make the result non-nil.
	if err == nil {
		t.Fatal("expected error from prior errors")
	}
	if msg := err.Error(); !strings.Contains(msg, "typ wrong") {
		t.Fatalf("expected prior error in result, got: %s", msg)
	}
}

// TestCov_Validate_ExplicitAud checks that a non-empty Aud list is enforced
// without the CheckAud bit.
func TestCov_Validate_ExplicitAud(t *testing.T) {
	validator := &Validator{
		Checks: ChecksConfigured,
		Aud:    []string{"wrong-aud"},
	}
	if err := validator.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrInvalidClaim) {
		t.Fatalf("expected ErrInvalidClaim, got %v", err)
	}
}

// TestCov_Validate_ExplicitAzP checks that a non-empty AzP list is enforced
// without the CheckAzP bit.
func TestCov_Validate_ExplicitAzP(t *testing.T) {
	validator := &Validator{
		Checks: ChecksConfigured,
		AzP:    []string{"wrong-azp"},
	}
	if err := validator.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrInvalidClaim) {
		t.Fatalf("expected ErrInvalidClaim, got %v", err)
	}
}

// TestCov_Validate_ExplicitScopes checks that non-empty RequiredScopes are
// enforced without the CheckScope bit.
func TestCov_Validate_ExplicitScopes(t *testing.T) {
	validator := &Validator{
		Checks:         ChecksConfigured,
		RequiredScopes: []string{"admin"},
	}
	if err := validator.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrInsufficientScope) {
		t.Fatalf("expected ErrInsufficientScope, got %v", err)
	}
}

func TestCov_Validate_ExplicitMaxAge(t
*testing.T) {
	// A positive MaxAge alone should trigger the auth_time check.
	validator := &Validator{
		Checks: ChecksConfigured,
		MaxAge: 1 * time.Second, // very short
	}
	if err := validator.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrAfterAuthMaxAge) {
		t.Fatalf("expected ErrAfterAuthMaxAge, got %v", err)
	}
}

// TestCov_Validate_ErrorsIsChain checks that the error returned for an
// expired token matches both ErrAfterExp and ErrInvalidClaim via errors.Is.
func TestCov_Validate_ErrorsIsChain(t *testing.T) {
	validator := NewIDTokenValidator([]string{"https://example.com"}, []string{"https://api.example.com"}, nil)
	expired := goodClaims()
	expired.Exp = testNow.Add(-time.Hour).Unix() // expired
	err := validator.Validate(nil, expired, testNow)
	switch {
	case err == nil:
		t.Fatal("expected error")
	case !errors.Is(err, ErrAfterExp):
		t.Fatal("expected errors.Is(err, ErrAfterExp)")
	case !errors.Is(err, ErrInvalidClaim):
		t.Fatal("expected errors.Is(err, ErrInvalidClaim) via chain")
	}
}

// TestCov_Validate_NegativeGracePeriod checks that a token expired by one
// second is rejected when GracePeriod is negative.
func TestCov_Validate_NegativeGracePeriod(t *testing.T) {
	// GracePeriod of -1 is expected to switch skew tolerance off, so a
	// one-second-old expiry must fail.
	validator := &Validator{
		Checks:      ChecksConfigured | CheckExp,
		GracePeriod: -1,
	}
	justExpired := goodClaims()
	justExpired.Exp = testNow.Add(-1 * time.Second).Unix()
	if err := validator.Validate(nil, justExpired, testNow); !errors.Is(err, ErrAfterExp) {
		t.Fatalf("expected ErrAfterExp with no skew, got %v", err)
	}
}

// TestCov_Validate_EmptyIss_Misconfigured checks that a non-nil but empty
// Iss list yields ErrMisconfigured.
func TestCov_Validate_EmptyIss_Misconfigured(t *testing.T) {
	validator := &Validator{
		Checks: ChecksConfigured,
		Iss:    []string{},
	}
	if err := validator.Validate(nil, goodClaims(), testNow); !errors.Is(err, ErrMisconfigured) {
		t.Fatalf("expected ErrMisconfigured for empty Iss, got %v", err)
	}
}

// TestCov_Validate_EmptyAud_Misconfigured checks that a non-nil but empty
// Aud list yields ErrMisconfigured.
func TestCov_Validate_EmptyAud_Misconfigured(t *testing.T) {
	// Non-nil empty Aud forces the check and returns misconfigured.
v := &Validator{
		Checks: ChecksConfigured,
		Aud:    []string{},
	}
	err := v.Validate(nil, goodClaims(), testNow)
	if !errors.Is(err, ErrMisconfigured) {
		t.Fatalf("expected ErrMisconfigured for empty Aud, got %v", err)
	}
}

// TestCov_Verify_PrefersSigInvalid registers an RSA key and a wrong Ed25519
// key under the token's KID and expects Verify to report
// ErrSignatureInvalid rather than ErrAlgConflict.
func TestCov_Verify_PrefersSigInvalid(t *testing.T) {
	// When multiple keys are tried, ErrSignatureInvalid should be preferred
	// over ErrAlgConflict. Verifier has an RSA key (wrong type) and an
	// Ed25519 key (right type, wrong material).
	pk := mustFromPrivate(t, mustEdKey(t))
	s := mustSigner(t, pk)
	tok := mustSignStr(t, s, goodClaims())
	// Check every setup error: each result is dereferenced below.
	jws, err := Decode(tok)
	if err != nil {
		t.Fatal(err)
	}
	kid := jws.GetHeader().KID
	// RSA key will give ErrAlgConflict; wrong Ed25519 gives ErrSignatureInvalid
	rsaPub, err := FromPublicKey(&mustRSAKey(t).PublicKey)
	if err != nil {
		t.Fatal(err)
	}
	rsaPub.KID = kid
	edPub, err := FromPublicKey(mustEdKey(t).Public())
	if err != nil {
		t.Fatal(err)
	}
	edPub.KID = kid
	v, err := NewVerifier([]PublicKey{*rsaPub, *edPub})
	if err != nil {
		t.Fatal(err)
	}
	err = v.Verify(jws)
	if !errors.Is(err, ErrSignatureInvalid) {
		t.Fatalf("expected ErrSignatureInvalid (preferred over ErrAlgConflict), got %v", err)
	}
}

// --- Per-claim check methods ---

// TestCov_IsAllowedIss covers IsAllowedIss with nil and empty allow-lists,
// a missing claim, a non-matching list, the "*" wildcard and an exact match.
func TestCov_IsAllowedIss(t *testing.T) {
	tc := goodClaims()
	t.Run("nil_allowed", func(t *testing.T) {
		errs := tc.IsAllowedIss(nil, nil)
		if len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured for nil")
		}
	})
	t.Run("empty_allowed", func(t *testing.T) {
		errs := tc.IsAllowedIss(nil, []string{})
		if len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured for empty")
		}
	})
	t.Run("missing_iss", func(t *testing.T) {
		tc2 := &TokenClaims{}
		errs := tc2.IsAllowedIss(nil, []string{"x"})
		if len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("not_in_list", func(t *testing.T) {
		errs := tc.IsAllowedIss(nil, []string{"https://other.com"})
		if len(errs) != 1 || !errors.Is(errs[0], ErrInvalidClaim) {
			t.Fatal("expected ErrInvalidClaim")
		}
	})
	t.Run("wildcard", func(t *testing.T) {
		errs := tc.IsAllowedIss(nil, []string{"*"})
if len(errs) != 0 {
			t.Fatal("expected pass with wildcard")
		}
	})
	t.Run("match", func(t *testing.T) {
		if errs := tc.IsAllowedIss(nil, []string{"https://example.com"}); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
}

// TestCov_IsPresentSub checks IsPresentSub with and without a sub claim.
func TestCov_IsPresentSub(t *testing.T) {
	t.Run("missing", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsPresentSub(nil); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("present", func(t *testing.T) {
		if errs := goodClaims().IsPresentSub(nil); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
}

// TestCov_HasAllowedAud covers HasAllowedAud with nil and empty allow-lists,
// a missing claim, a non-matching list, the "*" wildcard and an intersecting
// list.
func TestCov_HasAllowedAud(t *testing.T) {
	claims := goodClaims()
	t.Run("nil_allowed", func(t *testing.T) {
		if errs := claims.HasAllowedAud(nil, nil); len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured for nil")
		}
	})
	t.Run("empty_allowed", func(t *testing.T) {
		if errs := claims.HasAllowedAud(nil, []string{}); len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured for empty")
		}
	})
	t.Run("missing_aud", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.HasAllowedAud(nil, []string{"x"}); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("not_in_list", func(t *testing.T) {
		if errs := claims.HasAllowedAud(nil, []string{"wrong"}); len(errs) != 1 || !errors.Is(errs[0], ErrInvalidClaim) {
			t.Fatal("expected ErrInvalidClaim")
		}
	})
	t.Run("wildcard", func(t *testing.T) {
		if errs := claims.HasAllowedAud(nil, []string{"*"}); len(errs) != 0 {
			t.Fatal("expected pass with wildcard")
		}
	})
	t.Run("intersects", func(t *testing.T) {
		if errs := claims.HasAllowedAud(nil, []string{"https://api.example.com"}); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
}

// TestCov_IsBeforeExp covers IsBeforeExp for a missing, expired, valid and
// within-skew exp claim.
func TestCov_IsBeforeExp(t *testing.T) {
	t.Run("missing", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsBeforeExp(nil, testNow, 0); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("expired", func(t
*testing.T) {
		expired := goodClaims()
		expired.Exp = testNow.Add(-time.Hour).Unix()
		if errs := expired.IsBeforeExp(nil, testNow, 0); len(errs) != 1 || !errors.Is(errs[0], ErrAfterExp) {
			t.Fatal("expected ErrAfterExp")
		}
	})
	t.Run("valid", func(t *testing.T) {
		if errs := goodClaims().IsBeforeExp(nil, testNow, 0); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("within_skew", func(t *testing.T) {
		// Expired one second ago, checked with a two-second skew.
		recent := goodClaims()
		recent.Exp = testNow.Add(-1 * time.Second).Unix()
		if errs := recent.IsBeforeExp(nil, testNow, 2*time.Second); len(errs) != 0 {
			t.Fatal("expected pass within skew")
		}
	})
}

// TestCov_IsAfterNBf covers IsAfterNBf for an absent, future, past and
// within-skew nbf claim.
func TestCov_IsAfterNBf(t *testing.T) {
	t.Run("absent", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsAfterNBf(nil, testNow, 0); len(errs) != 0 {
			t.Fatal("expected pass for absent nbf")
		}
	})
	t.Run("future", func(t *testing.T) {
		later := &TokenClaims{NBf: testNow.Add(time.Hour).Unix()}
		if errs := later.IsAfterNBf(nil, testNow, 0); len(errs) != 1 || !errors.Is(errs[0], ErrBeforeNBf) {
			t.Fatal("expected ErrBeforeNBf")
		}
	})
	t.Run("valid", func(t *testing.T) {
		earlier := &TokenClaims{NBf: testNow.Add(-time.Minute).Unix()}
		if errs := earlier.IsAfterNBf(nil, testNow, 0); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("within_skew", func(t *testing.T) {
		// nbf one second ahead, checked with a two-second skew.
		soon := &TokenClaims{NBf: testNow.Add(1 * time.Second).Unix()}
		if errs := soon.IsAfterNBf(nil, testNow, 2*time.Second); len(errs) != 0 {
			t.Fatal("expected pass within skew")
		}
	})
}

// TestCov_IsAfterIAt covers IsAfterIAt for an absent, future, past and
// within-skew iat claim.
func TestCov_IsAfterIAt(t *testing.T) {
	t.Run("absent", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsAfterIAt(nil, testNow, 0); len(errs) != 0 {
			t.Fatal("expected pass for absent iat")
		}
	})
	t.Run("future", func(t *testing.T) {
		later := &TokenClaims{IAt: testNow.Add(time.Hour).Unix()}
		if errs := later.IsAfterIAt(nil, testNow, 0); len(errs) != 1 || !errors.Is(errs[0], ErrBeforeIAt) {
			t.Fatal("expected ErrBeforeIAt")
		}
	})
	t.Run("valid", func(t *testing.T) {
		tc := goodClaims()
		errs := tc.IsAfterIAt(nil, testNow, 0)
		if len(errs) != 0 {
t.Fatal("expected pass")
		}
	})
	t.Run("within_skew", func(t *testing.T) {
		// iat one second ahead, checked with a two-second skew.
		soon := &TokenClaims{IAt: testNow.Add(1 * time.Second).Unix()}
		if errs := soon.IsAfterIAt(nil, testNow, 2*time.Second); len(errs) != 0 {
			t.Fatal("expected pass within skew")
		}
	})
}

// TestCov_IsPresentJTI checks IsPresentJTI with and without a jti claim.
func TestCov_IsPresentJTI(t *testing.T) {
	t.Run("missing", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsPresentJTI(nil); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("present", func(t *testing.T) {
		if errs := goodClaims().IsPresentJTI(nil); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
}

// TestCov_IsValidAuthTime covers IsValidAuthTime for a missing or future
// auth_time, the maxAge limit, and the skew allowance on both bounds.
func TestCov_IsValidAuthTime(t *testing.T) {
	t.Run("missing", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsValidAuthTime(nil, testNow, 0, 0); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("future", func(t *testing.T) {
		later := &TokenClaims{AuthTime: testNow.Add(time.Hour).Unix()}
		if errs := later.IsValidAuthTime(nil, testNow, 0, 0); len(errs) != 1 || !errors.Is(errs[0], ErrBeforeAuthTime) {
			t.Fatal("expected ErrBeforeAuthTime")
		}
	})
	t.Run("maxAge_exceeded", func(t *testing.T) {
		// Authenticated an hour ago against a 30-minute maxAge.
		stale := &TokenClaims{AuthTime: testNow.Add(-time.Hour).Unix()}
		if errs := stale.IsValidAuthTime(nil, testNow, 0, 30*time.Minute); len(errs) != 1 || !errors.Is(errs[0], ErrAfterAuthMaxAge) {
			t.Fatal("expected ErrAfterAuthMaxAge")
		}
	})
	t.Run("valid_with_maxAge", func(t *testing.T) {
		fresh := goodClaims() // auth_time is 5m ago
		if errs := fresh.IsValidAuthTime(nil, testNow, 0, time.Hour); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("valid_without_maxAge", func(t *testing.T) {
		fresh := goodClaims()
		if errs := fresh.IsValidAuthTime(nil, testNow, 0, 0); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("future_within_skew", func(t *testing.T) {
		// auth_time is 1s in the future, but skew of 2s accepts it
		tc := &TokenClaims{AuthTime: testNow.Add(1 * time.Second).Unix()}
		errs :=
tc.IsValidAuthTime(nil, testNow, 2*time.Second, 0)
		if len(errs) != 0 {
			t.Fatal("expected pass within skew")
		}
	})
	t.Run("maxAge_within_skew", func(t *testing.T) {
		// 31 minutes old against a 30-minute maxAge, with a 2-minute skew.
		slightlyStale := &TokenClaims{AuthTime: testNow.Add(-31 * time.Minute).Unix()}
		if errs := slightlyStale.IsValidAuthTime(nil, testNow, 2*time.Minute, 30*time.Minute); len(errs) != 0 {
			t.Fatal("expected pass: maxAge exceeded by 1m but within 2m skew")
		}
	})
}

// TestCov_IsAllowedAzP covers IsAllowedAzP with nil and empty allow-lists,
// a missing claim, a non-matching list, the "*" wildcard and an exact match.
func TestCov_IsAllowedAzP(t *testing.T) {
	claims := goodClaims()
	t.Run("nil_allowed", func(t *testing.T) {
		if errs := claims.IsAllowedAzP(nil, nil); len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured")
		}
	})
	t.Run("empty_allowed", func(t *testing.T) {
		if errs := claims.IsAllowedAzP(nil, []string{}); len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured")
		}
	})
	t.Run("missing", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsAllowedAzP(nil, []string{"x"}); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("not_in_list", func(t *testing.T) {
		if errs := claims.IsAllowedAzP(nil, []string{"wrong"}); len(errs) != 1 || !errors.Is(errs[0], ErrInvalidClaim) {
			t.Fatal("expected ErrInvalidClaim")
		}
	})
	t.Run("wildcard", func(t *testing.T) {
		if errs := claims.IsAllowedAzP(nil, []string{"*"}); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("match", func(t *testing.T) {
		if errs := claims.IsAllowedAzP(nil, []string{"client-abc"}); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
}

// TestCov_IsPresentClientID checks IsPresentClientID with and without a
// client_id claim.
func TestCov_IsPresentClientID(t *testing.T) {
	t.Run("missing", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.IsPresentClientID(nil); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("present", func(t *testing.T) {
		if errs := goodClaims().IsPresentClientID(nil); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
}

// TestCov_ContainsScopes covers ContainsScopes for a missing scope claim, a
// missing required scope, a satisfied requirement and a presence-only check.
func TestCov_ContainsScopes(t *testing.T) {
t.Run("missing_scope", func(t *testing.T) {
		empty := &TokenClaims{}
		if errs := empty.ContainsScopes(nil, nil); len(errs) != 1 || !errors.Is(errs[0], ErrMissingClaim) {
			t.Fatal("expected ErrMissingClaim")
		}
	})
	t.Run("missing_required", func(t *testing.T) {
		claims := goodClaims()
		if errs := claims.ContainsScopes(nil, []string{"admin"}); len(errs) != 1 || !errors.Is(errs[0], ErrInsufficientScope) {
			t.Fatal("expected ErrInsufficientScope")
		}
	})
	t.Run("all_present", func(t *testing.T) {
		claims := goodClaims()
		if errs := claims.ContainsScopes(nil, []string{"openid"}); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("presence_only", func(t *testing.T) {
		claims := goodClaims()
		if errs := claims.ContainsScopes(nil, nil); len(errs) != 0 {
			t.Fatal("expected pass for presence-only")
		}
	})
}

// TestCov_IsAllowedTyp covers IsAllowedTyp with a nil allow-list, a
// non-matching typ, case-insensitive and exact matches, and an empty typ.
func TestCov_IsAllowedTyp(t *testing.T) {
	t.Run("empty_allowed", func(t *testing.T) {
		hdr := &RFCHeader{Typ: "JWT"}
		if errs := hdr.IsAllowedTyp(nil, nil); len(errs) != 1 || !errors.Is(errs[0], ErrMisconfigured) {
			t.Fatal("expected ErrMisconfigured")
		}
	})
	t.Run("not_in_list", func(t *testing.T) {
		hdr := &RFCHeader{Typ: "at+jwt"}
		if errs := hdr.IsAllowedTyp(nil, []string{"JWT"}); len(errs) != 1 || !errors.Is(errs[0], ErrInvalidTyp) {
			t.Fatal("expected ErrInvalidTyp")
		}
	})
	t.Run("case_insensitive_match", func(t *testing.T) {
		hdr := &RFCHeader{Typ: "jwt"}
		if errs := hdr.IsAllowedTyp(nil, []string{"JWT"}); len(errs) != 0 {
			t.Fatal("expected pass (case-insensitive)")
		}
	})
	t.Run("exact_match", func(t *testing.T) {
		hdr := &RFCHeader{Typ: "JWT"}
		if errs := hdr.IsAllowedTyp(nil, []string{"JWT"}); len(errs) != 0 {
			t.Fatal("expected pass")
		}
	})
	t.Run("empty_typ_not_in_list", func(t *testing.T) {
		hdr := &RFCHeader{}
		if errs := hdr.IsAllowedTyp(nil, []string{"JWT"}); len(errs) != 1 || !errors.Is(errs[0], ErrInvalidTyp) {
			t.Fatal("expected ErrInvalidTyp for empty typ")
		}
	})
}

// ============================================================
// Additional coverage: parse helpers, marshal, NewPrivateKey, etc.
// ============================================================ func TestCov_NewAccessToken_BadClaims(t *testing.T) { _, err := NewAccessToken(&badClaims{Bad: make(chan int)}) if err == nil { t.Fatal("expected error") } } func TestCov_ParsePublicJWK(t *testing.T) { // Create a valid JWK JSON via marshal round-trip pub, _ := FromPublicKey(mustEdKey(t).Public()) data, _ := json.Marshal(pub) pk, err := ParsePublicJWK(data) if err != nil { t.Fatal(err) } if pk.KID == "" { t.Fatal("expected KID") } // Bad JSON _, err = ParsePublicJWK([]byte("{bad")) if err == nil { t.Fatal("expected error") } } func TestCov_ParsePrivateJWK(t *testing.T) { edKey := mustEdKey(t) pk := mustFromPrivate(t, edKey) // Need to give it a KID so NewSigner doesn't complain s := mustSigner(t, pk) _ = s // just to exercise key // Marshal the private key data, err := json.Marshal(pk) if err != nil { t.Fatal(err) } parsed, err := ParsePrivateJWK(data) if err != nil { t.Fatal(err) } if parsed.KID == "" { t.Fatal("expected KID") } // Bad JSON _, err = ParsePrivateJWK([]byte("{bad")) if err == nil { t.Fatal("expected error") } } func TestCov_ParseWellKnownJWKs(t *testing.T) { pub, _ := FromPublicKey(mustEdKey(t).Public()) jwks := WellKnownJWKs{Keys: []PublicKey{*pub}} data, _ := json.Marshal(jwks) parsed, err := ParseWellKnownJWKs(data) if err != nil { t.Fatal(err) } if len(parsed.Keys) != 1 { t.Fatalf("expected 1 key, got %d", len(parsed.Keys)) } // Bad JSON _, err = ParseWellKnownJWKs([]byte("{bad")) if err == nil { t.Fatal("expected error") } } func TestCov_PublicKey_MarshalJSON(t *testing.T) { pub, _ := FromPublicKey(mustEdKey(t).Public()) data, err := json.Marshal(pub) if err != nil { t.Fatal(err) } if len(data) == 0 { t.Fatal("expected data") } // Error path: unsupported key bad := PublicKey{Pub: fakeKey{}} _, err = json.Marshal(bad) if err == nil { t.Fatal("expected error for unsupported key") } } func TestCov_PrivateKey_MarshalJSON(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) data, err 
:= json.Marshal(pk) if err != nil { t.Fatal(err) } if len(data) == 0 { t.Fatal("expected data") } } func TestCov_PublicKey_Thumbprint_AllTypes(t *testing.T) { // EC ecPub, _ := FromPublicKey(&mustECKey(t, elliptic.P256()).PublicKey) if _, err := ecPub.Thumbprint(); err != nil { t.Fatal(err) } // RSA rsaPub, _ := FromPublicKey(&mustRSAKey(t).PublicKey) if _, err := rsaPub.Thumbprint(); err != nil { t.Fatal(err) } // OKP (Ed25519) edPub, _ := FromPublicKey(mustEdKey(t).Public()) if _, err := edPub.Thumbprint(); err != nil { t.Fatal(err) } } func TestCov_Sign_BadClaims(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) _, err := s.Sign(&badClaims{Bad: make(chan int)}) if err == nil { t.Fatal("expected error for bad claims") } } func TestCov_SignToString_BadClaims(t *testing.T) { pk := mustFromPrivate(t, mustEdKey(t)) s := mustSigner(t, pk) _, err := s.SignToString(&badClaims{Bad: make(chan int)}) if err == nil { t.Fatal("expected error for bad claims") } } func TestCov_SignRaw_NilPrivKey(t *testing.T) { bad := &Signer{keys: []PrivateKey{{KID: "test"}}} _, err := bad.SignRaw(&RFCHeader{}, nil) if !errors.Is(err, ErrNoSigningKey) { t.Fatalf("expected ErrNoSigningKey, got %v", err) } } func TestCov_SetHeader_OK(t *testing.T) { jws, _ := New(goodClaims()) hdr := jws.GetHeader() err := jws.SetHeader(&hdr) if err != nil { t.Fatal(err) } } func TestCov_FromPublicKey_Unsupported(t *testing.T) { _, err := FromPublicKey(fakeKey{}) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) } } func TestCov_FromPublicKey_AllTypes(t *testing.T) { // EC pk, err := FromPublicKey(&mustECKey(t, elliptic.P384()).PublicKey) if err != nil || pk.Alg != "ES384" { t.Fatalf("EC: err=%v alg=%q", err, pk.Alg) } // RSA pk, err = FromPublicKey(&mustRSAKey(t).PublicKey) if err != nil || pk.Alg != "RS256" { t.Fatalf("RSA: err=%v alg=%q", err, pk.Alg) } } func TestCov_validateSigningKey_AllTypes(t *testing.T) { for _, tc := range 
[]struct { name string signer crypto.Signer }{ {"EC", mustECKey(t, elliptic.P256())}, {"RSA", mustRSAKey(t)}, {"Ed25519", mustEdKey(t)}, } { t.Run(tc.name, func(t *testing.T) { pk := mustFromPrivate(t, tc.signer) pub, err := pk.PublicKey() if err != nil { t.Fatal(err) } if err := validateSigningKey(pk, pub); err != nil { t.Fatal(err) } }) } } func TestCov_encode_AllTypes(t *testing.T) { // EC P-384 ecPub, _ := FromPublicKey(&mustECKey(t, elliptic.P384()).PublicKey) if _, err := encode(*ecPub); err != nil { t.Fatal(err) } // EC P-521 ecPub2, _ := FromPublicKey(&mustECKey(t, elliptic.P521()).PublicKey) if _, err := encode(*ecPub2); err != nil { t.Fatal(err) } // RSA rsaPub, _ := FromPublicKey(&mustRSAKey(t).PublicKey) if _, err := encode(*rsaPub); err != nil { t.Fatal(err) } } func TestCov_encodePrivate_AllTypes(t *testing.T) { // EC ecKey := mustECKey(t, elliptic.P256()) ecPK := mustFromPrivate(t, ecKey) if _, err := encodePrivate(*ecPK); err != nil { t.Fatal(err) } // RSA rsaPK := mustFromPrivate(t, mustRSAKey(t)) if _, err := encodePrivate(*rsaPK); err != nil { t.Fatal(err) } } func TestCov_decodeOne_RSA(t *testing.T) { rsaKey := mustRSAKey(t) pub, _ := FromPublicKey(&rsaKey.PublicKey) data, _ := json.Marshal(pub) var rk rawKey json.Unmarshal(data, &rk) pk, err := decodeOne(rk) if err != nil { t.Fatal(err) } if pk.Pub == nil { t.Fatal("expected key") } } func TestCov_decodeOne_UnknownKty(t *testing.T) { _, err := decodeOne(rawKey{Kty: "UNKNOWN"}) if !errors.Is(err, ErrUnsupportedKeyType) { t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) } } func TestCov_decodePrivate_RSA(t *testing.T) { rsaKey := mustRSAKey(t) pk := mustFromPrivate(t, rsaKey) data, _ := json.Marshal(pk) var rk rawKey json.Unmarshal(data, &rk) priv, err := decodePrivate(rk) if err != nil { t.Fatal(err) } if priv.Priv == nil { t.Fatal("expected private key") } } func TestCov_decodePrivate_EC(t *testing.T) { ecKey := mustECKey(t, elliptic.P256()) pk := mustFromPrivate(t, ecKey) data, _ := 
json.Marshal(pk) var rk rawKey json.Unmarshal(data, &rk) priv, err := decodePrivate(rk) if err != nil { t.Fatal(err) } if priv.Priv == nil { t.Fatal("expected private key") } } func TestCov_signingParams_RSA(t *testing.T) { rsaKey := mustRSAKey(t) alg, hash, ecKeySize, err := signingParams(rsaKey) if err != nil || alg != "RS256" || hash == 0 || ecKeySize != 0 { t.Fatalf("unexpected: alg=%q hash=%v ecKeySize=%d err=%v", alg, hash, ecKeySize, err) } } func TestCov_signBytes_RSA(t *testing.T) { rsaKey := mustRSAKey(t) sig, err := signBytes(rsaKey, "RS256", crypto.SHA256, 0, []byte("test input")) if err != nil || len(sig) == 0 { t.Fatal(err) } } func TestCov_verifyOneKey_EC_CurveMismatch(t *testing.T) { // ES256 key but token says ES384 h := RFCHeader{Alg: "ES384", KID: "k"} err := verifyOneKey(h, &mustECKey(t, elliptic.P256()).PublicKey, []byte("input"), make([]byte, 96)) if !errors.Is(err, ErrAlgConflict) { t.Fatalf("expected ErrAlgConflict, got %v", err) } }