From 73e7903c4c94a1e4ab9f806ef59b2e374fac40fb Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 17 Mar 2026 04:16:54 -0600 Subject: [PATCH] test(auth/jwt): even more tests --- auth/jwt/tests/go.mod | 27 + auth/jwt/tests/go.sum | 43 + auth/jwt/tests/nuance/nuance_report_test.go | 525 ++++++++++++ .../round-trip-go-jose/round_trip_test.go | 660 +++++++++++++++ .../round-trip-go-jwt/round_trip_test.go | 775 ++++++++++++++++++ .../tests/round-trip-jwx/round_trip_test.go | 621 ++++++++++++++ auth/jwt/tests/testkeys/testkeys.go | 158 ++++ 7 files changed, 2809 insertions(+) create mode 100644 auth/jwt/tests/go.mod create mode 100644 auth/jwt/tests/go.sum create mode 100644 auth/jwt/tests/nuance/nuance_report_test.go create mode 100644 auth/jwt/tests/round-trip-go-jose/round_trip_test.go create mode 100644 auth/jwt/tests/round-trip-go-jwt/round_trip_test.go create mode 100644 auth/jwt/tests/round-trip-jwx/round_trip_test.go create mode 100644 auth/jwt/tests/testkeys/testkeys.go diff --git a/auth/jwt/tests/go.mod b/auth/jwt/tests/go.mod new file mode 100644 index 0000000..381a384 --- /dev/null +++ b/auth/jwt/tests/go.mod @@ -0,0 +1,27 @@ +module github.com/therootcompany/golib/auth/jwt/tests + +go 1.26.1 + +require ( + github.com/go-jose/go-jose/v4 v4.1.3 + github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/lestrrat-go/jwx/v3 v3.0.13 + github.com/therootcompany/golib/auth/jwt v0.0.0 +) + +require ( + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/lestrrat-go/blackmagic v1.0.4 // indirect + github.com/lestrrat-go/dsig v1.0.0 // indirect + github.com/lestrrat-go/dsig-secp256k1 v1.0.0 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc/v3 v3.0.2 // indirect + github.com/lestrrat-go/option/v2 v2.0.0 // indirect + github.com/segmentio/asm v1.2.1 // indirect + github.com/valyala/fastjson v1.6.7 // indirect + golang.org/x/crypto v0.46.0 // indirect + golang.org/x/sys v0.39.0 // indirect +) + +replace github.com/therootcompany/golib/auth/jwt => ../ diff --git a/auth/jwt/tests/go.sum b/auth/jwt/tests/go.sum new file mode 100644 index 0000000..c185cc7 --- /dev/null +++ b/auth/jwt/tests/go.sum @@ -0,0 +1,43 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/lestrrat-go/blackmagic v1.0.4 h1:IwQibdnf8l2KoO+qC3uT4OaTWsW7tuRQXy9TRN9QanA= +github.com/lestrrat-go/blackmagic v1.0.4/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= +github.com/lestrrat-go/dsig v1.0.0 h1:OE09s2r9Z81kxzJYRn07TFM9XA4akrUdoMwr0L8xj38= +github.com/lestrrat-go/dsig v1.0.0/go.mod h1:dEgoOYYEJvW6XGbLasr8TFcAxoWrKlbQvmJgCR0qkDo= +github.com/lestrrat-go/dsig-secp256k1 v1.0.0 h1:JpDe4Aybfl0soBvoVwjqDbp+9S1Y2OM7gcrVVMFPOzY= +github.com/lestrrat-go/dsig-secp256k1 v1.0.0/go.mod h1:CxUgAhssb8FToqbL8NjSPoGQlnO4w3LG1P0qPWQm/NU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc/v3 v3.0.2 h1:7u4HUaD0NQbf2/n5+fyp+T10hNCsAnwKfqn4A4Baif0= +github.com/lestrrat-go/httprc/v3 v3.0.2/go.mod h1:mSMtkZW92Z98M5YoNNztbRGxbXHql7tSitCvaxvo9l0= +github.com/lestrrat-go/jwx/v3 v3.0.13 h1:AdHKiPIYeCSnOJtvdpipPg/0SuFh9rdkN+HF3O0VdSk= +github.com/lestrrat-go/jwx/v3 v3.0.13/go.mod h1:2m0PV1A9tM4b/jVLMx8rh6rBl7F6WGb3EG2hufN9OQU= +github.com/lestrrat-go/option/v2 v2.0.0 h1:XxrcaJESE1fokHy3FpaQ/cXW8ZsIdWcdFzzLOcID3Ss= +github.com/lestrrat-go/option/v2 v2.0.0/go.mod h1:oSySsmzMoR0iRzCDCaUfsCzxQHUEuhOViQObyy7S6Vg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= +github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM= +github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/auth/jwt/tests/nuance/nuance_report_test.go b/auth/jwt/tests/nuance/nuance_report_test.go new file mode 100644 index 0000000..8e7134a --- /dev/null +++ b/auth/jwt/tests/nuance/nuance_report_test.go @@ -0,0 +1,525 @@ +// Package nuance_test documents behavioral differences between this library, +// go-jose/go-jose v4, and lestrrat-go/jwx v3 that may cause interop surprises. +// +// Each test logs observations via t.Log so that `go test -v ./nuance/` produces +// a readable report. Tests that demonstrate library-specific defaults use +// controlled clock offsets to show exactly where each library draws the line. +// +// Run: +// +// go test ./nuance/ -v +package nuance_test + +import ( + "crypto" + "encoding/base64" + "encoding/json" + "strings" + "testing" + "time" + + jose "github.com/go-jose/go-jose/v4" + josejwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/lestrrat-go/jwx/v3/jwa" + jwxjwk "github.com/lestrrat-go/jwx/v3/jwk" + jwxjwt "github.com/lestrrat-go/jwx/v3/jwt" + + "github.com/therootcompany/golib/auth/jwt" + "github.com/therootcompany/golib/auth/jwt/tests/testkeys" +) + +// signOurs creates a JWT signed with our library using the given claims. +func signOurs(t *testing.T, ks testkeys.KeySet, claims jwt.Claims) string { + t.Helper() + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + tok, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + return tok +} + +// ----------------------------------------------------------------------- +// Clock skew / expiration tolerance +// ----------------------------------------------------------------------- + +func TestNuance_ClockSkew_GoJose(t *testing.T) { + t.Log("=== Nuance: expiration checking - when does it happen? ===") + t.Log("") + t.Log("CRITICAL: Our VerifyJWT only checks the SIGNATURE.") + t.Log("Claims validation (exp, iat) requires a separate Validate() call.") + t.Log("go-jose also separates verification from validation.") + t.Log("jwx bundles both into jwt.Parse by default.") + t.Log("") + t.Log("go-jose ValidateWithLeeway takes an explicit leeway parameter.") + t.Log("Our Validator.Validate uses DefaultGracePeriod (2s).") + t.Log("") + + ks := testkeys.GenerateEdDSA("skew") + + // Token expired 30 seconds ago. + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "skew-test", + Exp: time.Now().Add(-30 * time.Second).Unix(), + IAt: time.Now().Add(-5 * time.Minute).Unix(), + } + tokenStr := signOurs(t, ks, claims) + + // Our VerifyJWT: signature-only, does NOT check exp. + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + jws, ourSigErr := verifier.VerifyJWT(tokenStr) + t.Logf(" our VerifyJWT (sig only): accepts=%v", ourSigErr == nil) + + // Our Validate: checks exp with DefaultGracePeriod (2s) => REJECTS. + if jws != nil { + var decoded jwt.TokenClaims + jws.UnmarshalClaims(&decoded) + v := jwt.Validator{ + Checks: jwt.CheckIss | jwt.CheckExp | jwt.CheckIAt | jwt.CheckNBf, + Iss: []string{"https://example.com"}, + } + valErr := v.Validate(nil, &decoded, time.Now()) + t.Logf(" our Validate (2s grace): rejects=%v (err=%v)", + valErr != nil, valErr) + if valErr == nil { + t.Error("expected our Validate to reject a token expired 30s ago") + } + } + + // go-jose: also separates parse/verify from validation. + tok, _ := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{jose.EdDSA}) + var joseClaims josejwt.Claims + tok.Claims(ks.RawPub, &joseClaims) + + // go-jose with explicit 1-minute leeway => accepts. + err1m := joseClaims.ValidateWithLeeway(josejwt.Expected{Time: time.Now()}, 1*time.Minute) + t.Logf(" go-jose (1m leeway): accepts=%v", err1m == nil) + + // go-jose with 0 leeway => rejects. + err0 := joseClaims.ValidateWithLeeway(josejwt.Expected{Time: time.Now()}, 0) + t.Logf(" go-jose (0 leeway): rejects=%v", err0 != nil) + + if err1m != nil { + t.Error("expected go-jose to accept with 1m leeway") + } + if err0 == nil { + t.Error("expected go-jose to reject with 0 leeway") + } + + t.Log("") + t.Log("ACTION: Our VerifyJWT is signature-only. You MUST call Validate()") + t.Log("after VerifyJWT to enforce exp/iat. go-jose likewise requires an") + t.Log("explicit ValidateWithLeeway call. Choose matching leeway values.") +} + +func TestNuance_ClockSkew_JWX(t *testing.T) { + t.Log("=== Nuance: jwx bundles validation into jwt.Parse ===") + t.Log("") + t.Log("Unlike our lib and go-jose (which separate sig from claims),") + t.Log("jwx v3 validates exp/iat DURING jwt.Parse. Default skew is 0.") + t.Log("Use jwt.WithAcceptableSkew(d) or jwt.WithValidate(false) to adjust.") + t.Log("") + + ks := testkeys.GenerateEdDSA("skew-jwx") + + // Token expired 1 second ago. + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "skew-test", + Exp: time.Now().Add(-1 * time.Second).Unix(), + IAt: time.Now().Add(-5 * time.Minute).Unix(), + } + tokenStr := signOurs(t, ks, claims) + + // jwx: zero skew (default) => rejects at parse time. + _, jwxErr := jwxjwt.Parse([]byte(tokenStr), jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub)) + t.Logf(" jwx Parse (0s skew): rejects=%v", jwxErr != nil) + + // jwx: with 5s skew => accepts. + _, jwxErr5 := jwxjwt.Parse([]byte(tokenStr), + jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub), + jwxjwt.WithAcceptableSkew(5*time.Second), + ) + t.Logf(" jwx Parse (5s skew): accepts=%v", jwxErr5 == nil) + + // jwx: validation disabled => accepts (sig-only, like our VerifyJWT). + _, jwxErrNoval := jwxjwt.Parse([]byte(tokenStr), + jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub), + jwxjwt.WithValidate(false), + ) + t.Logf(" jwx Parse (no validate): accepts=%v", jwxErrNoval == nil) + + // Our VerifyJWT: always accepts (sig-only). + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + _, ourErr := verifier.VerifyJWT(tokenStr) + t.Logf(" our VerifyJWT (sig only): accepts=%v", ourErr == nil) + + if jwxErr == nil { + t.Error("expected jwx to reject with 0 skew") + } + if jwxErr5 != nil { + t.Error("expected jwx to accept with 5s skew") + } + if jwxErrNoval != nil { + t.Error("expected jwx to accept with validation disabled") + } + if ourErr != nil { + t.Errorf("expected our VerifyJWT to accept (sig-only): %v", ourErr) + } + + t.Log("") + t.Log("ACTION: jwx rejects expired tokens at parse time. Use") + t.Log("WithAcceptableSkew(d) to add clock tolerance, or") + t.Log("WithValidate(false) for sig-only (matching our VerifyJWT).") +} + +// ----------------------------------------------------------------------- +// kid header emission +// ----------------------------------------------------------------------- + +func TestNuance_KIDHeader_GoJose(t *testing.T) { + t.Log("=== Nuance: go-jose kid header emission ===") + t.Log("") + t.Log("go-jose omits 'kid' from the JWS header unless:") + t.Log(" 1. The signing key is wrapped in jose.JSONWebKey{KeyID: ...}, or") + t.Log(" 2. opts.WithHeader(jose.HeaderKey(\"kid\"), ...) is used.") + t.Log("Our verifier tries all keys when kid is missing (fallback).") + t.Log("") + + ks := testkeys.GenerateEdDSA("kid-test") + + // Sign with raw key (no JSONWebKey wrapper) - kid is missing. + rawSigKey := jose.SigningKey{ + Algorithm: jose.EdDSA, + Key: ks.RawPriv, // raw key, no JSONWebKey wrapper + } + rawSigner, _ := jose.NewSigner(rawSigKey, nil) + rawClaims := josejwt.Claims{ + Subject: "raw-key", + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + rawToken, _ := josejwt.Signed(rawSigner).Claims(rawClaims).Serialize() + + // Check the header. + parts := strings.SplitN(rawToken, ".", 3) + headerJSON, _ := base64.RawURLEncoding.DecodeString(parts[0]) + var header map[string]any + json.Unmarshal(headerJSON, &header) + _, hasKID := header["kid"] + t.Logf(" raw key signing: kid in header = %v (header: %s)", hasKID, headerJSON) + + // Our verifier accepts via try-all-keys fallback (no kid => try every key). + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + _, ourErr := verifier.VerifyJWT(rawToken) + t.Logf(" our VerifyJWT: err = %v", ourErr) + + // Sign with JSONWebKey wrapper - kid is present. + wrappedSigKey := jose.SigningKey{ + Algorithm: jose.EdDSA, + Key: jose.JSONWebKey{Key: ks.RawPriv, KeyID: ks.KID}, + } + wrappedSigner, _ := jose.NewSigner(wrappedSigKey, nil) + wrappedToken, _ := josejwt.Signed(wrappedSigner).Claims(rawClaims).Serialize() + parts2 := strings.SplitN(wrappedToken, ".", 3) + headerJSON2, _ := base64.RawURLEncoding.DecodeString(parts2[0]) + var header2 map[string]any + json.Unmarshal(headerJSON2, &header2) + _, hasKID2 := header2["kid"] + t.Logf(" JSONWebKey signing: kid in header = %v (header: %s)", hasKID2, headerJSON2) + + _, ourErr2 := verifier.VerifyJWT(wrappedToken) + t.Logf(" our VerifyJWT: err = %v", ourErr2) + + if hasKID { + t.Error("expected raw key signing to NOT have kid in header") + } + if ourErr != nil { + t.Errorf("expected our verifier to accept token without kid (try-all-keys fallback), got: %v", ourErr) + } + if !hasKID2 { + t.Error("expected JSONWebKey signing to have kid in header") + } + if ourErr2 != nil { + t.Errorf("expected our verifier to accept token with kid, got: %v", ourErr2) + } + + t.Log("") + t.Log("NOTE: When kid is missing, our verifier tries all keys (first match wins).") + t.Log("For multi-key verifiers, always set kid for efficient key lookup.") +} + +func TestNuance_KIDHeader_JWX(t *testing.T) { + t.Log("=== Nuance: jwx kid header emission ===") + t.Log("") + t.Log("jwx omits 'kid' unless jwk.KeyIDKey is set on the key before signing.") + t.Log("Our verifier tries all keys when kid is missing (fallback).") + t.Log("") + + ks := testkeys.GenerateEdDSA("kid-jwx") + + // Import key WITHOUT setting kid. + jwxKeyNoKID, _ := jwxjwk.Import(ks.RawPriv) + tok := jwxjwt.New() + tok.Set(jwxjwt.SubjectKey, "no-kid") + tok.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + noKIDToken, _ := jwxjwt.Sign(tok, jwxjwt.WithKey(jwa.EdDSA(), jwxKeyNoKID)) + + parts := strings.SplitN(string(noKIDToken), ".", 3) + headerJSON, _ := base64.RawURLEncoding.DecodeString(parts[0]) + var header map[string]any + json.Unmarshal(headerJSON, &header) + _, hasKID := header["kid"] + t.Logf(" no KeyIDKey set: kid in header = %v (header: %s)", hasKID, headerJSON) + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + _, ourErr := verifier.VerifyJWT(string(noKIDToken)) + t.Logf(" our VerifyJWT: err = %v", ourErr) + + // Import key WITH kid set. + jwxKeyWithKID, _ := jwxjwk.Import(ks.RawPriv) + jwxKeyWithKID.Set(jwxjwk.KeyIDKey, ks.KID) + tok2 := jwxjwt.New() + tok2.Set(jwxjwt.SubjectKey, "with-kid") + tok2.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + withKIDToken, _ := jwxjwt.Sign(tok2, jwxjwt.WithKey(jwa.EdDSA(), jwxKeyWithKID)) + + parts2 := strings.SplitN(string(withKIDToken), ".", 3) + headerJSON2, _ := base64.RawURLEncoding.DecodeString(parts2[0]) + var header2 map[string]any + json.Unmarshal(headerJSON2, &header2) + _, hasKID2 := header2["kid"] + t.Logf(" KeyIDKey set: kid in header = %v (header: %s)", hasKID2, headerJSON2) + + _, ourErr2 := verifier.VerifyJWT(string(withKIDToken)) + t.Logf(" our VerifyJWT: err = %v", ourErr2) + + if hasKID { + t.Error("expected no-kid key to omit kid from header") + } + if ourErr != nil { + t.Errorf("expected our verifier to accept token without kid (try-all-keys fallback), got: %v", ourErr) + } + if !hasKID2 { + t.Error("expected kid-set key to include kid in header") + } + if ourErr2 != nil { + t.Errorf("expected our verifier to accept token with kid, got: %v", ourErr2) + } + + t.Log("") + t.Log("NOTE: When kid is missing, our verifier tries all keys (first match wins).") + t.Log("For multi-key verifiers, always set kid for efficient key lookup.") +} + +// ----------------------------------------------------------------------- +// Audience marshaling +// ----------------------------------------------------------------------- + +func TestNuance_ListishMarshal(t *testing.T) { + t.Log("=== Nuance: audience JSON marshaling ===") + t.Log("") + t.Log("RFC 7519 allows aud as either a string or an array of strings.") + t.Log("Libraries differ in how they marshal a single-value audience:") + t.Log("") + + ks := testkeys.GenerateEdDSA("aud-marshal") + + // Our library: single aud => string, multi aud => array. + singleClaims := testkeys.ListishClaims("aud-test", jwt.Listish{"single"}) + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + ourSingleTok, _ := signer.SignToString(singleClaims) + ourSinglePayload := decodePayload(ourSingleTok) + t.Logf(" our lib (single aud): %s", ourSinglePayload) + + multiClaims := testkeys.ListishClaims("aud-test", jwt.Listish{"a", "b"}) + ourMultiTok, _ := signer.SignToString(multiClaims) + ourMultiPayload := decodePayload(ourMultiTok) + t.Logf(" our lib (multi aud): %s", ourMultiPayload) + + // go-jose: check how it marshals. + sigKey := jose.SigningKey{ + Algorithm: jose.EdDSA, + Key: jose.JSONWebKey{Key: ks.RawPriv, KeyID: ks.KID}, + } + joseSigner, _ := jose.NewSigner(sigKey, nil) + joseSingleClaims := josejwt.Claims{ + Subject: "aud-test", + Audience: []string{"single"}, + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + joseSingleTok, _ := josejwt.Signed(joseSigner).Claims(joseSingleClaims).Serialize() + joseSinglePayload := decodePayload(joseSingleTok) + t.Logf(" go-jose (single aud): %s", joseSinglePayload) + + joseMultiClaims := josejwt.Claims{ + Subject: "aud-test", + Audience: []string{"a", "b"}, + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + joseMultiTok, _ := josejwt.Signed(joseSigner).Claims(joseMultiClaims).Serialize() + joseMultiPayload := decodePayload(joseMultiTok) + t.Logf(" go-jose (multi aud): %s", joseMultiPayload) + + // All parsers should handle both string and array forms. + // Verify our parser handles go-jose's format. + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, err := verifier.VerifyJWT(joseSingleTok) + if err != nil { + t.Fatalf("our verify of go-jose single aud: %v", err) + } + var decoded jwt.TokenClaims + verifiedJWS.UnmarshalClaims(&decoded) + t.Logf(" our parse of go-jose single aud: %v", decoded.Aud) + + if len(decoded.Aud) != 1 || decoded.Aud[0] != "single" { + t.Errorf("expected [single], got %v", decoded.Aud) + } + + t.Log("") + t.Log("Both libraries handle both string and array forms on input.") + t.Log("No action needed - interop is seamless for audience values.") +} + +// ----------------------------------------------------------------------- +// Thumbprint encoding +// ----------------------------------------------------------------------- + +func TestNuance_ThumbprintEncoding(t *testing.T) { + t.Log("=== Nuance: JWK Thumbprint encoding (RFC 7638) ===") + t.Log("") + t.Log("All 3 libraries use unpadded base64url encoding for thumbprints.") + t.Log("Confirming no library adds '=' padding:") + t.Log("") + + for _, ag := range testkeys.AllAlgorithms() { + ks := ag.Generate("thumb-enc-" + ag.Name) + + // Our thumbprint. + ourThumb, _ := ks.PubKey.Thumbprint() + hasPadding := strings.Contains(ourThumb, "=") + t.Logf(" %s - our thumbprint: %s (padding=%v)", ag.Name, ourThumb, hasPadding) + if hasPadding { + t.Errorf("%s: our thumbprint has padding", ag.Name) + } + + // go-jose thumbprint. + joseKey := jose.JSONWebKey{Key: ks.RawPub} + joseRaw, _ := joseKey.Thumbprint(crypto.SHA256) + joseThumb := base64.RawURLEncoding.EncodeToString(joseRaw) + t.Logf(" %s - go-jose thumbprint: %s", ag.Name, joseThumb) + + // jwx thumbprint. + jwxKey, _ := jwxjwk.Import(ks.RawPub) + jwxRaw, _ := jwxKey.Thumbprint(crypto.SHA256) + jwxThumb := base64.RawURLEncoding.EncodeToString(jwxRaw) + t.Logf(" %s - jwx thumbprint: %s", ag.Name, jwxThumb) + + // All three should match. + if ourThumb != joseThumb || ourThumb != jwxThumb { + t.Errorf("%s: thumbprint mismatch: ours=%s go-jose=%s jwx=%s", + ag.Name, ourThumb, joseThumb, jwxThumb) + } + } + + t.Log("") + t.Log("All 3 libraries produce identical unpadded base64url thumbprints.") + t.Log("No action needed.") +} + +// ----------------------------------------------------------------------- +// iat (issued-at) validation +// ----------------------------------------------------------------------- + +func TestNuance_IssuedAtValidation(t *testing.T) { + t.Log("=== Nuance: iat (issued-at) validation ===") + t.Log("") + t.Log("All three libraries reject future iat. Our library checks that iat,") + t.Log("when present, is not in the future - a common-sense sanity check") + t.Log("even though the spec does not require it.") + t.Log("") + + ks := testkeys.GenerateEdDSA("iat-test") + + // Token with iat 10 seconds in the future. + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "iat-future", + Exp: time.Now().Add(time.Hour).Unix(), + IAt: time.Now().Add(10 * time.Second).Unix(), + } + tokenStr := signOurs(t, ks, claims) + + // jwx: rejects at parse time (iat in future, 0 skew). + _, jwxErr := jwxjwt.Parse([]byte(tokenStr), jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub)) + t.Logf(" jwx Parse (0 skew): rejects=%v", jwxErr != nil) + + // go-jose: parse+verify succeeds, ValidateWithLeeway rejects future iat. + tok, _ := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{jose.EdDSA}) + var joseClaims josejwt.Claims + tok.Claims(ks.RawPub, &joseClaims) + joseErr := joseClaims.ValidateWithLeeway(josejwt.Expected{Time: time.Now()}, 0) + t.Logf(" go-jose ValidateWithLeeway(0): rejects=%v", joseErr != nil) + + // Our VerifyJWT: accepts (signature-only, no iat check). + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + jws, ourSigErr := verifier.VerifyJWT(tokenStr) + t.Logf(" our VerifyJWT (sig only): accepts=%v", ourSigErr == nil) + + // Our Validate: rejects future iat (common-sense check, not per spec). + if jws != nil { + var decoded jwt.TokenClaims + jws.UnmarshalClaims(&decoded) + v := jwt.Validator{ + Checks: jwt.CheckIss | jwt.CheckExp | jwt.CheckIAt | jwt.CheckNBf, + Iss: []string{"https://example.com"}, + } + valErr := v.Validate(nil, &decoded, time.Now()) + t.Logf(" our Validate: rejects=%v", valErr != nil) + if valErr == nil { + t.Error("expected our Validate to reject future iat") + } + } + + if jwxErr == nil { + t.Error("expected jwx to reject future iat") + } + if joseErr == nil { + t.Error("expected go-jose to reject future iat") + } + if ourSigErr != nil { + t.Errorf("expected our VerifyJWT to accept (sig-only): %v", ourSigErr) + } + + t.Log("") + t.Log("All three libraries agree: future iat is rejected.") + t.Log("Remove CheckIAt from Checks to opt out of this check if needed.") +} + +// ----------------------------------------------------------------------- +// helpers +// ----------------------------------------------------------------------- + +func decodePayload(tokenStr string) string { + parts := strings.SplitN(tokenStr, ".", 3) + if len(parts) < 2 { + return "(invalid token)" + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "(decode error: " + err.Error() + ")" + } + + // Extract just the aud field for compact display. + var m map[string]any + json.Unmarshal(payload, &m) + aud, ok := m["aud"] + if !ok { + return "(no aud field)" + } + audJSON, _ := json.Marshal(aud) + return "aud=" + string(audJSON) +} diff --git a/auth/jwt/tests/round-trip-go-jose/round_trip_test.go b/auth/jwt/tests/round-trip-go-jose/round_trip_test.go new file mode 100644 index 0000000..a1726b3 --- /dev/null +++ b/auth/jwt/tests/round-trip-go-jose/round_trip_test.go @@ -0,0 +1,660 @@ +// Package josert_test verifies interoperability between this library and +// github.com/go-jose/go-jose/v4 (JWS, JWK, JWT). It covers sign/verify, +// JWK serialization, thumbprint consistency, JWKS, audience, custom claims, +// NumericDate precision, and stress tests. +package josert_test + +import ( + "crypto" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "testing" + "time" + + jose "github.com/go-jose/go-jose/v4" + josejwt "github.com/go-jose/go-jose/v4/jwt" + + "github.com/therootcompany/golib/auth/jwt" + "github.com/therootcompany/golib/auth/jwt/tests/testkeys" +) + +var longTests = flag.Bool("long", false, "run extended stress tests (100 RSA iterations instead of 10)") + +// joseAlg maps our algorithm name to a go-jose SignatureAlgorithm constant. +func joseAlg(name string) jose.SignatureAlgorithm { + switch name { + case "EdDSA": + return jose.EdDSA + case "ES256": + return jose.ES256 + case "ES384": + return jose.ES384 + case "ES512": + return jose.ES512 + case "RS256": + return jose.RS256 + } + panic("unknown alg: " + name) +} + +// --- helpers --- + +func assertOurSignGoJoseVerify(t *testing.T, ks testkeys.KeySet, sub string) { + t.Helper() + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatalf("NewSigner: %v", err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims(sub)) + if err != nil { + t.Fatalf("SignToString: %v", err) + } + + // Parse and verify with go-jose. + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{joseAlg(ks.AlgName)}) + if err != nil { + t.Fatalf("go-jose ParseSigned: %v", err) + } + + var claims josejwt.Claims + if err := tok.Claims(ks.RawPub, &claims); err != nil { + t.Fatalf("go-jose Claims: %v", err) + } + if claims.Subject != sub { + t.Errorf("sub: got %q, want %q", claims.Subject, sub) + } + if claims.Issuer != "https://example.com" { + t.Errorf("iss: got %q, want %q", claims.Issuer, "https://example.com") + } +} + +func assertGoJoseSignOurVerify(t *testing.T, ks testkeys.KeySet, sub string) { + t.Helper() + + // Use JSONWebKey wrapper to get kid in the JWS header. + sigKey := jose.SigningKey{ + Algorithm: joseAlg(ks.AlgName), + Key: jose.JSONWebKey{ + Key: ks.RawPriv, + KeyID: ks.KID, + }, + } + joseSigner, err := jose.NewSigner(sigKey, nil) + if err != nil { + t.Fatalf("go-jose NewSigner: %v", err) + } + + now := time.Now() + claims := josejwt.Claims{ + Issuer: "https://example.com", + Subject: sub, + Expiry: josejwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: josejwt.NewNumericDate(now), + } + tokenStr, err := josejwt.Signed(joseSigner).Claims(claims).Serialize() + if err != nil { + t.Fatalf("go-jose Serialize: %v", err) + } + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, err := verifier.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("our verify: %v", err) + } + + var decoded jwt.TokenClaims + if err := verifiedJWS.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + if decoded.Sub != sub { + t.Errorf("sub: got %q, want %q", decoded.Sub, sub) + } +} + +// --- Our sign, go-jose verify (all algorithms) --- + +func TestOurSignGoJoseVerify_EdDSA(t *testing.T) { + assertOurSignGoJoseVerify(t, testkeys.GenerateEdDSA("k1"), "user-eddsa") +} + +func TestOurSignGoJoseVerify_ES256(t *testing.T) { + assertOurSignGoJoseVerify(t, testkeys.GenerateES256("k1"), "user-es256") +} + +func TestOurSignGoJoseVerify_ES384(t *testing.T) { + assertOurSignGoJoseVerify(t, testkeys.GenerateES384("k1"), "user-es384") +} + +func TestOurSignGoJoseVerify_ES512(t *testing.T) { + assertOurSignGoJoseVerify(t, testkeys.GenerateES512("k1"), "user-es512") +} + +func TestOurSignGoJoseVerify_RS256(t *testing.T) { + assertOurSignGoJoseVerify(t, testkeys.GenerateRS256("k1"), "user-rs256") +} + +// --- go-jose sign, our verify (all algorithms) --- + +func TestGoJoseSignOurVerify_EdDSA(t *testing.T) { + assertGoJoseSignOurVerify(t, testkeys.GenerateEdDSA("k1"), "user-eddsa") +} + +func TestGoJoseSignOurVerify_ES256(t *testing.T) { + assertGoJoseSignOurVerify(t, testkeys.GenerateES256("k1"), "user-es256") +} + +func TestGoJoseSignOurVerify_ES384(t *testing.T) { + assertGoJoseSignOurVerify(t, testkeys.GenerateES384("k1"), "user-es384") +} + +func TestGoJoseSignOurVerify_ES512(t *testing.T) { + assertGoJoseSignOurVerify(t, testkeys.GenerateES512("k1"), "user-es512") +} + +func TestGoJoseSignOurVerify_RS256(t *testing.T) { + assertGoJoseSignOurVerify(t, testkeys.GenerateRS256("k1"), "user-rs256") +} + +// --- JWK serialization interop --- + +func TestJWKInterop_OurJSONToGoJose(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + t.Run(ag.Name+"_Public", func(t *testing.T) { + ks := ag.Generate("jwk-" + ag.Name) + + // Marshal our public key to JSON. + ourJSON, err := json.Marshal(ks.PubKey) + if err != nil { + t.Fatalf("marshal our pubkey: %v", err) + } + + // Parse with go-jose. + var joseKey jose.JSONWebKey + if err := json.Unmarshal(ourJSON, &joseKey); err != nil { + t.Fatalf("go-jose unmarshal from our JSON: %v", err) + } + + // Verify a token signed by us, using the go-jose-parsed key. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims("jwk-interop")) + if err != nil { + t.Fatal(err) + } + + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{joseAlg(ks.AlgName)}) + if err != nil { + t.Fatal(err) + } + var claims josejwt.Claims + if err := tok.Claims(joseKey.Key, &claims); err != nil { + t.Fatalf("go-jose verify with our-JSON-parsed key: %v", err) + } + }) + + t.Run(ag.Name+"_Private", func(t *testing.T) { + ks := ag.Generate("jwk-priv-" + ag.Name) + + // Marshal our private key to JSON. + ourJSON, err := json.Marshal(ks.PrivKey) + if err != nil { + t.Fatalf("marshal our privkey: %v", err) + } + + // Parse with go-jose. + var joseKey jose.JSONWebKey + if err := json.Unmarshal(ourJSON, &joseKey); err != nil { + t.Fatalf("go-jose unmarshal from our private JSON: %v", err) + } + + // Sign with the go-jose-parsed key, verify with our lib. + joseKey.KeyID = ks.KID + sigKey := jose.SigningKey{ + Algorithm: joseAlg(ks.AlgName), + Key: joseKey, + } + joseSigner, err := jose.NewSigner(sigKey, nil) + if err != nil { + t.Fatal(err) + } + claims := josejwt.Claims{ + Subject: "jwk-priv-interop", + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + tokenStr, err := josejwt.Signed(joseSigner).Claims(claims).Serialize() + if err != nil { + t.Fatalf("go-jose sign with our-JSON-parsed key: %v", err) + } + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Fatalf("our verify: %v", err) + } + }) + } +} + +func TestJWKInterop_GoJoseJSONToOur(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + t.Run(ag.Name, func(t *testing.T) { + ks := ag.Generate("jose-to-our-" + ag.Name) + + // Create go-jose JWK and serialize. + joseKey := jose.JSONWebKey{ + Key: ks.RawPub, + KeyID: ks.KID, + Algorithm: ks.AlgName, + Use: "sig", + } + joseJSON, err := json.Marshal(joseKey) + if err != nil { + t.Fatalf("marshal go-jose key: %v", err) + } + + // Parse with our library. + var recovered jwt.PublicKey + if err := json.Unmarshal(joseJSON, &recovered); err != nil { + t.Fatalf("our unmarshal of go-jose JSON: %v", err) + } + + // Sign with our signer, verify with the recovered key. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims("jose-json")) + if err != nil { + t.Fatal(err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{recovered}) + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Fatalf("verify with go-jose-JSON-parsed key: %v", err) + } + }) + } +} + +// --- Thumbprint consistency (RFC 7638) --- + +func TestThumbprintConsistency_GoJose(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + t.Run(ag.Name, func(t *testing.T) { + ks := ag.Generate("thumb-" + ag.Name) + + // Our thumbprint (returns base64url string). + ourThumb, err := ks.PubKey.Thumbprint() + if err != nil { + t.Fatalf("our Thumbprint: %v", err) + } + + // go-jose thumbprint (returns raw bytes). + joseKey := jose.JSONWebKey{Key: ks.RawPub} + joseRaw, err := joseKey.Thumbprint(crypto.SHA256) + if err != nil { + t.Fatalf("go-jose Thumbprint: %v", err) + } + joseThumb := base64.RawURLEncoding.EncodeToString(joseRaw) + + if ourThumb != joseThumb { + t.Errorf("thumbprint mismatch:\n ours: %s\n go-jose: %s", ourThumb, joseThumb) + } + }) + } +} + +// --- JWKS interop --- + +func TestJWKSInterop_OurToGoJose(t *testing.T) { + // Build a signer with all 5 key types. + var keys []*jwt.PrivateKey + var sets []testkeys.KeySet + for _, ag := range testkeys.AllAlgorithms() { + ks := ag.Generate("jwks-" + ag.Name) + keys = append(keys, ks.PrivKey) + sets = append(sets, ks) + } + signer, err := jwt.NewSigner(keys) + if err != nil { + t.Fatal(err) + } + + // Serialize our JWKS. + jwksData, err := json.Marshal(&signer) + if err != nil { + t.Fatal(err) + } + + // Parse with go-jose. + var joseJWKS jose.JSONWebKeySet + if err := json.Unmarshal(jwksData, &joseJWKS); err != nil { + t.Fatalf("go-jose unmarshal JWKS: %v", err) + } + if len(joseJWKS.Keys) != 5 { + t.Fatalf("expected 5 keys, got %d", len(joseJWKS.Keys)) + } + + // Sign tokens with each key and verify with the go-jose-parsed set. + for i, ks := range sets { + tokenStr, err := signer.SignToString(testkeys.TestClaims(fmt.Sprintf("jwks-%d", i))) + if err != nil { + t.Fatalf("sign[%d]: %v", i, err) + } + + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{joseAlg(ks.AlgName)}) + if err != nil { + t.Errorf("parse[%d] (%s): %v", i, ks.AlgName, err) + continue + } + // Find the matching key from the parsed JWKS. + matching := joseJWKS.Key(ks.KID) + if len(matching) == 0 { + t.Errorf("no key found for kid %q", ks.KID) + continue + } + var claims josejwt.Claims + if err := tok.Claims(matching[0].Key, &claims); err != nil { + t.Errorf("go-jose verify[%d] (%s) with parsed JWKS: %v", i, ks.AlgName, err) + } + } +} + +func TestJWKSInterop_GoJoseToOur(t *testing.T) { + // Build a go-jose key set. + var joseJWKS jose.JSONWebKeySet + var sets []testkeys.KeySet + for _, ag := range testkeys.AllAlgorithms() { + ks := ag.Generate("jose-jwks-" + ag.Name) + sets = append(sets, ks) + joseJWKS.Keys = append(joseJWKS.Keys, jose.JSONWebKey{ + Key: ks.RawPub, + KeyID: ks.KID, + Algorithm: ks.AlgName, + Use: "sig", + }) + } + + // Serialize go-jose JWKS. + jwksData, err := json.Marshal(joseJWKS) + if err != nil { + t.Fatal(err) + } + + // Parse with our library. + var ourJWKS jwt.WellKnownJWKs + if err := json.Unmarshal(jwksData, &ourJWKS); err != nil { + t.Fatalf("our unmarshal of go-jose JWKS: %v", err) + } + if len(ourJWKS.Keys) != 5 { + t.Fatalf("expected 5 keys, got %d", len(ourJWKS.Keys)) + } + + verifier, _ := jwt.NewVerifier(ourJWKS.Keys) + + // Sign tokens with go-jose, verify with our library. + for _, ks := range sets { + sigKey := jose.SigningKey{ + Algorithm: joseAlg(ks.AlgName), + Key: jose.JSONWebKey{ + Key: ks.RawPriv, + KeyID: ks.KID, + }, + } + joseSigner, err := jose.NewSigner(sigKey, nil) + if err != nil { + t.Fatalf("go-jose signer %s: %v", ks.AlgName, err) + } + claims := josejwt.Claims{ + Subject: "jose-to-our", + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + tokenStr, err := josejwt.Signed(joseSigner).Claims(claims).Serialize() + if err != nil { + t.Fatalf("go-jose sign %s: %v", ks.AlgName, err) + } + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Errorf("our verify %s from go-jose JWKS: %v", ks.AlgName, err) + } + } +} + +// --- Audience interop --- + +func TestAudienceStringInterop_GoJose(t *testing.T) { + ks := testkeys.GenerateEdDSA("aud-test") + + // Our library: single aud marshals as plain string "single-aud". + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + claims := testkeys.ListishClaims("aud-str", jwt.Listish{"single-aud"}) + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{jose.EdDSA}) + if err != nil { + t.Fatalf("go-jose parse: %v", err) + } + var joseClaims josejwt.Claims + if err := tok.Claims(ks.RawPub, &joseClaims); err != nil { + t.Fatalf("go-jose Claims: %v", err) + } + if len(joseClaims.Audience) != 1 || joseClaims.Audience[0] != "single-aud" { + t.Errorf("aud: got %v, want [single-aud]", joseClaims.Audience) + } + + // Reverse: go-jose signs with single aud, our library parses. + sigKey := jose.SigningKey{ + Algorithm: jose.EdDSA, + Key: jose.JSONWebKey{Key: ks.RawPriv, KeyID: ks.KID}, + } + joseSigner, _ := jose.NewSigner(sigKey, nil) + jClaims := josejwt.Claims{ + Subject: "aud-str-rev", + Audience: josejwt.Listish{"single-aud"}, + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + joseToken, _ := josejwt.Signed(joseSigner).Claims(jClaims).Serialize() + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, err := verifier.VerifyJWT(joseToken) + if err != nil { + t.Fatal(err) + } + var decoded jwt.TokenClaims + verifiedJWS.UnmarshalClaims(&decoded) + if len(decoded.Aud) == 0 || decoded.Aud[0] != "single-aud" { + t.Errorf("reverse aud: got %v, want [single-aud]", decoded.Aud) + } +} + +func TestAudienceArrayInterop_GoJose(t *testing.T) { + ks := testkeys.GenerateEdDSA("aud-arr") + + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + claims := testkeys.ListishClaims("aud-arr", jwt.Listish{"aud1", "aud2"}) + tokenStr, _ := signer.SignToString(claims) + + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{jose.EdDSA}) + if err != nil { + t.Fatalf("go-jose parse: %v", err) + } + var joseClaims josejwt.Claims + if err := tok.Claims(ks.RawPub, &joseClaims); err != nil { + t.Fatal(err) + } + if len(joseClaims.Audience) != 2 || joseClaims.Audience[0] != "aud1" || joseClaims.Audience[1] != "aud2" { + t.Errorf("aud: got %v, want [aud1 aud2]", joseClaims.Audience) + } +} + +// --- Custom claims interop --- + +func TestCustomClaimsInterop_GoJose(t *testing.T) { + ks := testkeys.GenerateEdDSA("custom") + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + claims := &testkeys.CustomClaims{ + TokenClaims: *testkeys.TestClaims("custom-user"), + Email: "user@example.com", + Roles: []string{"admin", "editor"}, + Metadata: map[string]string{"team": "platform"}, + } + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{jose.EdDSA}) + if err != nil { + t.Fatalf("go-jose parse: %v", err) + } + + // go-jose extracts into an arbitrary struct. + var extracted struct { + josejwt.Claims + Email string `json:"email"` + Roles []string `json:"roles"` + Metadata map[string]string `json:"metadata"` + } + if err := tok.Claims(ks.RawPub, &extracted); err != nil { + t.Fatalf("go-jose Claims: %v", err) + } + if extracted.Email != "user@example.com" { + t.Errorf("email: got %q, want %q", extracted.Email, "user@example.com") + } + if len(extracted.Roles) != 2 || extracted.Roles[0] != "admin" { + t.Errorf("roles: got %v, want [admin editor]", extracted.Roles) + } + if extracted.Metadata["team"] != "platform" { + t.Errorf("metadata.team: got %v, want %q", extracted.Metadata["team"], "platform") + } +} + +// --- NumericDate precision --- + +func TestNumericDatePrecision_GoJose(t *testing.T) { + ks := testkeys.GenerateEdDSA("nd") + + // Use fixed future timestamps to test precision without triggering + // expiration validation. 2000000000 = 2033-05-18. + var wantExp int64 = 2000000000 + var wantIat int64 = 1999999000 + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "numdate", + Exp: wantExp, + IAt: wantIat, + } + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + tokenStr, _ := signer.SignToString(claims) + + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{jose.EdDSA}) + if err != nil { + t.Fatal(err) + } + var joseClaims josejwt.Claims + if err := tok.Claims(ks.RawPub, &joseClaims); err != nil { + t.Fatal(err) + } + if joseClaims.Expiry.Time().Unix() != wantExp { + t.Errorf("exp: got %d, want %d", joseClaims.Expiry.Time().Unix(), wantExp) + } + if joseClaims.IssuedAt.Time().Unix() != wantIat { + t.Errorf("iat: got %d, want %d", joseClaims.IssuedAt.Time().Unix(), wantIat) + } + + // Reverse: go-jose signs with specific times, our library reads. + var wantExp2 int64 = 2100000000 + var wantIat2 int64 = 2099999000 + sigKey := jose.SigningKey{ + Algorithm: jose.EdDSA, + Key: jose.JSONWebKey{Key: ks.RawPriv, KeyID: ks.KID}, + } + joseSigner, _ := jose.NewSigner(sigKey, nil) + jClaims := josejwt.Claims{ + Subject: "numdate-rev", + Expiry: josejwt.NewNumericDate(time.Unix(wantExp2, 0)), + IssuedAt: josejwt.NewNumericDate(time.Unix(wantIat2, 0)), + } + joseToken, _ := josejwt.Signed(joseSigner).Claims(jClaims).Serialize() + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, _ := verifier.VerifyJWT(joseToken) + var decoded jwt.TokenClaims + verifiedJWS.UnmarshalClaims(&decoded) + if decoded.Exp != wantExp2 { + t.Errorf("rev exp: got %d, want %d", decoded.Exp, wantExp2) + } + if decoded.IAt != wantIat2 { + t.Errorf("rev iat: got %d, want %d", decoded.IAt, wantIat2) + } +} + +// --- Stress tests --- + +func TestStress_GoJose(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + ag := ag + t.Run(ag.Name, func(t *testing.T) { + t.Parallel() + n := 1000 + if ag.Name == "RS256" { + n = 10 + if *longTests { + n = 100 + } + } + for i := range n { + ks := ag.Generate(fmt.Sprintf("s%d", i)) + sub := fmt.Sprintf("stress-%d", i) + + // Our sign, go-jose verify. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatalf("iter %d: NewSigner: %v", i, err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims(sub)) + if err != nil { + t.Fatalf("iter %d: SignToString: %v", i, err) + } + tok, err := josejwt.ParseSigned(tokenStr, []jose.SignatureAlgorithm{joseAlg(ks.AlgName)}) + if err != nil { + t.Fatalf("iter %d: go-jose parse: %v", i, err) + } + var claims josejwt.Claims + if err := tok.Claims(ks.RawPub, &claims); err != nil { + t.Fatalf("iter %d: go-jose verify: %v", i, err) + } + + // go-jose sign, our verify. + sigKey := jose.SigningKey{ + Algorithm: joseAlg(ks.AlgName), + Key: jose.JSONWebKey{Key: ks.RawPriv, KeyID: ks.KID}, + } + joseSigner, err := jose.NewSigner(sigKey, nil) + if err != nil { + t.Fatalf("iter %d: go-jose NewSigner: %v", i, err) + } + jClaims := josejwt.Claims{ + Subject: sub, + Expiry: josejwt.NewNumericDate(time.Now().Add(time.Hour)), + } + joseToken, err := josejwt.Signed(joseSigner).Claims(jClaims).Serialize() + if err != nil { + t.Fatalf("iter %d: go-jose sign: %v", i, err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + if _, err := verifier.VerifyJWT(joseToken); err != nil { + t.Fatalf("iter %d: our verify: %v", i, err) + } + } + }) + } +} diff --git a/auth/jwt/tests/round-trip-go-jwt/round_trip_test.go b/auth/jwt/tests/round-trip-go-jwt/round_trip_test.go new file mode 100644 index 0000000..237d859 --- /dev/null +++ b/auth/jwt/tests/round-trip-go-jwt/round_trip_test.go @@ -0,0 +1,775 @@ +// Copyright 2026 AJ ONeal (https://therootcompany.com) +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// SPDX-License-Identifier: MPL-2.0 + +// Package roundtrip_test verifies interoperability between this library and +// github.com/golang-jwt/jwt/v5. It lives in a separate module (tests/go.mod) +// so that the golang-jwt dependency does not leak into the main module graph. +// +// Tests cover: +// - Our sign + their verify (Ed25519, EC P-256, P-384, P-521, RSA) +// - Their sign + our verify (Ed25519, EC P-256, P-384, P-521, RSA) +// - Known/fixed keys: deterministic key material for reproducible tests +// - Stress tests: 1,000 keys per algorithm to catch ASN.1/padding edge cases +// - JWK key round-trip: marshal/unmarshal private and public keys, then +// confirm the recovered keys interoperate correctly. +package roundtrip_test + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/json" + "flag" + "fmt" + "io" + "testing" + "time" + + gjwt "github.com/golang-jwt/jwt/v5" + + "github.com/therootcompany/golib/auth/jwt" +) + +var longTests = flag.Bool("long", false, "run extended stress tests (100 RSA iterations instead of 10)") + +// --- helpers --- + +// testClaims returns a fresh set of claims for a test iteration. +func testClaims(sub string) *jwt.TokenClaims { + now := time.Now() + return &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: sub, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + } +} + +// hashReader produces deterministic bytes from a SHA-256 hash chain. +// Not cryptographically secure - used only for reproducible test key generation. +type hashReader struct { + state [32]byte + pos int +} + +func deterministicRand(seed string) io.Reader { + s := sha256.Sum256([]byte(seed)) + return &hashReader{state: s} +} + +func (r *hashReader) Read(p []byte) (int, error) { + n := 0 + for n < len(p) { + if r.pos >= len(r.state) { + r.state = sha256.Sum256(r.state[:]) + r.pos = 0 + } + copied := copy(p[n:], r.state[r.pos:]) + n += copied + r.pos += copied + } + return n, nil +} + +// mustPK wraps jwt.FromPrivateKey and fails the test on error. +func mustPK(t testing.TB, signer crypto.Signer, kid string) *jwt.PrivateKey { + t.Helper() + pk, err := jwt.FromPrivateKey(signer, kid) + if err != nil { + t.Fatal(err) + } + return pk +} + +// assertOurSignTheirVerify signs with our library and verifies with golang-jwt. +func assertOurSignTheirVerify(t *testing.T, pk *jwt.PrivateKey, gjwtMethod gjwt.SigningMethod, gjwtPub any, sub string) { + t.Helper() + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatalf("NewSigner: %v", err) + } + claims := testClaims(sub) + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatalf("SignToString: %v", err) + } + + parsed, err := gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{}, func(tok *gjwt.Token) (any, error) { + if tok.Method.Alg() != gjwtMethod.Alg() { + return nil, fmt.Errorf("unexpected alg: got %q, want %q", tok.Method.Alg(), gjwtMethod.Alg()) + } + return gjwtPub, nil + }) + if err != nil { + t.Fatalf("golang-jwt verify failed: %v", err) + } + rc, ok := parsed.Claims.(*gjwt.RegisteredClaims) + if !ok || !parsed.Valid { + t.Fatal("token invalid or claims unreadable") + } + if rc.Subject != sub { + t.Errorf("sub: got %q, want %q", rc.Subject, sub) + } + if rc.Issuer != claims.Iss { + t.Errorf("iss: got %q, want %q", rc.Issuer, claims.Iss) + } +} + +// assertTheirSignOurVerify signs with golang-jwt and verifies with our library. +func assertTheirSignOurVerify(t *testing.T, gjwtMethod gjwt.SigningMethod, gjwtPriv any, kid string, ourPub jwt.PublicKey, sub string) { + t.Helper() + + now := time.Now() + gClaims := gjwt.RegisteredClaims{ + Issuer: "https://example.com", + Subject: sub, + ExpiresAt: gjwt.NewNumericDate(now.Add(time.Hour)), + IssuedAt: gjwt.NewNumericDate(now), + } + tok := gjwt.NewWithClaims(gjwtMethod, gClaims) + tok.Header["kid"] = kid + + tokenStr, err := tok.SignedString(gjwtPriv) + if err != nil { + t.Fatalf("golang-jwt sign: %v", err) + } + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ourPub}) + jws, err := verifier.VerifyJWT(tokenStr) + if err != nil { + t.Fatalf("our verify failed: %v", err) + } + + var decoded jwt.TokenClaims + if err := jws.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + if decoded.Sub != sub { + t.Errorf("sub: got %q, want %q", decoded.Sub, sub) + } + if decoded.Iss != gClaims.Issuer { + t.Errorf("iss: got %q, want %q", decoded.Iss, gClaims.Issuer) + } +} + +// stressIteration tests one key in both directions: our sign + their verify, +// then their sign + our verify. +func stressIteration(t *testing.T, i int, pk *jwt.PrivateKey, pub jwt.PublicKey, gjwtMethod gjwt.SigningMethod, gjwtPriv any, gjwtPub any) { + t.Helper() + sub := fmt.Sprintf("stress-%d", i) + + // Our sign, their verify. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{pk}) + if err != nil { + t.Fatalf("iter %d: NewSigner: %v", i, err) + } + tokenStr, err := signer.SignToString(testClaims(sub)) + if err != nil { + t.Fatalf("iter %d: SignToString: %v", i, err) + } + _, err = gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{}, func(tok *gjwt.Token) (any, error) { + return gjwtPub, nil + }) + if err != nil { + t.Fatalf("iter %d: golang-jwt verify: %v", i, err) + } + + // Their sign, our verify. + gClaims := gjwt.RegisteredClaims{ + Subject: sub, + ExpiresAt: gjwt.NewNumericDate(time.Now().Add(time.Hour)), + } + tok := gjwt.NewWithClaims(gjwtMethod, gClaims) + tok.Header["kid"] = pk.KID + tokenStr, err = tok.SignedString(gjwtPriv) + if err != nil { + t.Fatalf("iter %d: golang-jwt sign: %v", i, err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{pub}) + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Fatalf("iter %d: our verify: %v", i, err) + } +} + +// --- Our sign, their verify (all algorithms) --- + +func TestOurSignTheirVerify_EdDSA(t *testing.T) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + pub := priv.Public().(ed25519.PublicKey) + assertOurSignTheirVerify(t, + mustPK(t, priv, "k1"), + gjwt.SigningMethodEdDSA, pub, "user-eddsa") +} + +func TestOurSignTheirVerify_ES256(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + assertOurSignTheirVerify(t, + mustPK(t, priv, "k1"), + gjwt.SigningMethodES256, &priv.PublicKey, "user-es256") +} + +func TestOurSignTheirVerify_ES384(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + assertOurSignTheirVerify(t, + mustPK(t, priv, "k1"), + gjwt.SigningMethodES384, &priv.PublicKey, "user-es384") +} + +func TestOurSignTheirVerify_ES512(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + assertOurSignTheirVerify(t, + mustPK(t, priv, "k1"), + gjwt.SigningMethodES512, &priv.PublicKey, "user-es512") +} + +func TestOurSignTheirVerify_RS256(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + assertOurSignTheirVerify(t, + mustPK(t, priv, "k1"), + gjwt.SigningMethodRS256, &priv.PublicKey, "user-rs256") +} + +// --- Their sign, our verify (all algorithms) --- + +func TestTheirSignOurVerify_EdDSA(t *testing.T) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + pub := priv.Public().(ed25519.PublicKey) + assertTheirSignOurVerify(t, + gjwt.SigningMethodEdDSA, priv, "k1", + jwt.PublicKey{Key: pub, KID: "k1"}, "user-eddsa") +} + +func TestTheirSignOurVerify_ES256(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + assertTheirSignOurVerify(t, + gjwt.SigningMethodES256, priv, "k1", + jwt.PublicKey{Key: &priv.PublicKey, KID: "k1"}, "user-es256") +} + +func TestTheirSignOurVerify_ES384(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + assertTheirSignOurVerify(t, + gjwt.SigningMethodES384, priv, "k1", + jwt.PublicKey{Key: &priv.PublicKey, KID: "k1"}, "user-es384") +} + +func TestTheirSignOurVerify_ES512(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + assertTheirSignOurVerify(t, + gjwt.SigningMethodES512, priv, "k1", + jwt.PublicKey{Key: &priv.PublicKey, KID: "k1"}, "user-es512") +} + +func TestTheirSignOurVerify_RS256(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + assertTheirSignOurVerify(t, + gjwt.SigningMethodRS256, priv, "k1", + jwt.PublicKey{Key: &priv.PublicKey, KID: "k1"}, "user-rs256") +} + +// --- Known key tests --- +// +// Each algorithm uses deterministic key material so failures are reproducible +// across runs. Ed25519 uses NewKeyFromSeed; EC and RSA use a SHA-256 hash +// chain seeded from a fixed string. + +func TestKnownKeys(t *testing.T) { + t.Run("EdDSA", func(t *testing.T) { + seed := make([]byte, ed25519.SeedSize) + for i := range seed { + seed[i] = byte(i) + } + priv := ed25519.NewKeyFromSeed(seed) + pub := priv.Public().(ed25519.PublicKey) + kid := "known-ed" + pk := mustPK(t, priv, kid) + pubKey := jwt.PublicKey{Key: pub, KID: kid} + assertOurSignTheirVerify(t, pk, gjwt.SigningMethodEdDSA, pub, "known-ed-ours") + assertTheirSignOurVerify(t, gjwt.SigningMethodEdDSA, priv, kid, pubKey, "known-ed-theirs") + }) + + t.Run("ES256", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), deterministicRand("known-es256")) + if err != nil { + t.Fatal(err) + } + kid := "known-es256" + pk := mustPK(t, priv, kid) + pubKey := jwt.PublicKey{Key: &priv.PublicKey, KID: kid} + assertOurSignTheirVerify(t, pk, gjwt.SigningMethodES256, &priv.PublicKey, "known-es256-ours") + assertTheirSignOurVerify(t, gjwt.SigningMethodES256, priv, kid, pubKey, "known-es256-theirs") + }) + + t.Run("ES384", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), deterministicRand("known-es384")) + if err != nil { + t.Fatal(err) + } + kid := "known-es384" + pk := mustPK(t, priv, kid) + pubKey := jwt.PublicKey{Key: &priv.PublicKey, KID: kid} + assertOurSignTheirVerify(t, pk, gjwt.SigningMethodES384, &priv.PublicKey, "known-es384-ours") + assertTheirSignOurVerify(t, gjwt.SigningMethodES384, priv, kid, pubKey, "known-es384-theirs") + }) + + t.Run("ES512", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), deterministicRand("known-es512")) + if err != nil { + t.Fatal(err) + } + kid := "known-es512" + pk := mustPK(t, priv, kid) + pubKey := jwt.PublicKey{Key: &priv.PublicKey, KID: kid} + assertOurSignTheirVerify(t, pk, gjwt.SigningMethodES512, &priv.PublicKey, "known-es512-ours") + assertTheirSignOurVerify(t, gjwt.SigningMethodES512, priv, kid, pubKey, "known-es512-theirs") + }) + + t.Run("RS256", func(t *testing.T) { + priv, err := rsa.GenerateKey(deterministicRand("known-rs256"), 2048) + if err != nil { + t.Fatal(err) + } + kid := "known-rs256" + pk := mustPK(t, priv, kid) + pubKey := jwt.PublicKey{Key: &priv.PublicKey, KID: kid} + assertOurSignTheirVerify(t, pk, gjwt.SigningMethodRS256, &priv.PublicKey, "known-rs256-ours") + assertTheirSignOurVerify(t, gjwt.SigningMethodRS256, priv, kid, pubKey, "known-rs256-theirs") + }) +} + +// --- Stress tests --- +// +// Each subtest generates 1,000 random keys and signs+verifies in both +// directions per key. This catches edge cases in ASN.1 DER-to-raw signature +// conversion (ECDSA r/s values that are shorter than the field size and +// need left-padding) and any key-dependent encoding quirks. +// +// RSA keygen is inherently slow (~10ms per 2048-bit key); RSA defaults to +// 10 iterations. Use -long to run 100. + +func TestStress(t *testing.T) { + t.Run("EdDSA", func(t *testing.T) { + t.Parallel() + for i := range 1000 { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("iter %d: keygen: %v", i, err) + } + pub := priv.Public().(ed25519.PublicKey) + kid := fmt.Sprintf("s%d", i) + stressIteration(t, i, + mustPK(t, priv, kid), + jwt.PublicKey{Key: pub, KID: kid}, + gjwt.SigningMethodEdDSA, priv, pub) + } + }) + + t.Run("ES256", func(t *testing.T) { + t.Parallel() + for i := range 1000 { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("iter %d: keygen: %v", i, err) + } + kid := fmt.Sprintf("s%d", i) + stressIteration(t, i, + mustPK(t, priv, kid), + jwt.PublicKey{Key: &priv.PublicKey, KID: kid}, + gjwt.SigningMethodES256, priv, &priv.PublicKey) + } + }) + + t.Run("ES384", func(t *testing.T) { + t.Parallel() + for i := range 1000 { + priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatalf("iter %d: keygen: %v", i, err) + } + kid := fmt.Sprintf("s%d", i) + stressIteration(t, i, + mustPK(t, priv, kid), + jwt.PublicKey{Key: &priv.PublicKey, KID: kid}, + gjwt.SigningMethodES384, priv, &priv.PublicKey) + } + }) + + t.Run("ES512", func(t *testing.T) { + t.Parallel() + for i := range 1000 { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatalf("iter %d: keygen: %v", i, err) + } + kid := fmt.Sprintf("s%d", i) + stressIteration(t, i, + mustPK(t, priv, kid), + jwt.PublicKey{Key: &priv.PublicKey, KID: kid}, + gjwt.SigningMethodES512, priv, &priv.PublicKey) + } + }) + + t.Run("RS256", func(t *testing.T) { + t.Parallel() + n := 10 + if *longTests { + n = 100 + } + for i := range n { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("iter %d: keygen: %v", i, err) + } + kid := fmt.Sprintf("s%d", i) + stressIteration(t, i, + mustPK(t, priv, kid), + jwt.PublicKey{Key: &priv.PublicKey, KID: kid}, + gjwt.SigningMethodRS256, priv, &priv.PublicKey) + } + }) +} + +// --- JWK private key round-trip --- +// +// Marshal a private key to JWK JSON, unmarshal it back, and confirm the +// recovered key produces tokens verifiable by both the original public key +// and golang-jwt. + +func TestJWKPrivateKeyRoundTrip(t *testing.T) { + t.Run("Ed25519", func(t *testing.T) { + original, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + pub, err := original.PublicKey() + if err != nil { + t.Fatal(err) + } + assertPrivateKeyRoundTrip(t, original, + gjwt.SigningMethodEdDSA, pub.Key.(ed25519.PublicKey)) + }) + + t.Run("EC_P256", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + original := mustPK(t, priv, "ec256-rt") + assertPrivateKeyRoundTrip(t, original, + gjwt.SigningMethodES256, &priv.PublicKey) + }) + + t.Run("EC_P384", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + original := mustPK(t, priv, "ec384-rt") + assertPrivateKeyRoundTrip(t, original, + gjwt.SigningMethodES384, &priv.PublicKey) + }) + + t.Run("EC_P521", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + original := mustPK(t, priv, "ec521-rt") + assertPrivateKeyRoundTrip(t, original, + gjwt.SigningMethodES512, &priv.PublicKey) + }) + + t.Run("RSA", func(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + original := mustPK(t, priv, "rsa-rt") + assertPrivateKeyRoundTrip(t, original, + gjwt.SigningMethodRS256, &priv.PublicKey) + }) +} + +func assertPrivateKeyRoundTrip(t *testing.T, original *jwt.PrivateKey, gjwtMethod gjwt.SigningMethod, gjwtPub any) { + t.Helper() + + // Marshal to JSON. + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + // Unmarshal back. + var recovered jwt.PrivateKey + if err := json.Unmarshal(data, &recovered); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if recovered.KID != original.KID { + t.Errorf("KID: got %q, want %q", recovered.KID, original.KID) + } + + claims := testClaims("pk-roundtrip") + + // Sign with recovered key, verify with original pubkey. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{&recovered}) + if err != nil { + t.Fatal(err) + } + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + origPub, _ := original.PublicKey() + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{*origPub}) + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Errorf("verify with original pubkey: %v", err) + } + + // Sign with original key, verify with recovered pubkey. + origSigner, err := jwt.NewSigner([]*jwt.PrivateKey{original}) + if err != nil { + t.Fatal(err) + } + tokenStr2, err := origSigner.SignToString(claims) + if err != nil { + t.Fatal(err) + } + recPub, _ := recovered.PublicKey() + verifier2, _ := jwt.NewVerifier([]jwt.PublicKey{*recPub}) + if _, err := verifier2.VerifyJWT(tokenStr2); err != nil { + t.Errorf("verify with recovered pubkey: %v", err) + } + + // Cross-verify with golang-jwt. + _, err = gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{}, func(tok *gjwt.Token) (any, error) { + return gjwtPub, nil + }) + if err != nil { + t.Errorf("golang-jwt cross-verify: %v", err) + } +} + +// --- JWK public key round-trip --- +// +// Marshal a public key to JWK JSON, unmarshal it back, and confirm the +// round-tripped key verifies tokens signed by the original private key. + +func TestJWKPublicKeyRoundTrip(t *testing.T) { + t.Run("Ed25519", func(t *testing.T) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + pub := priv.Public().(ed25519.PublicKey) + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "ed-pub-rt")}) + if err != nil { + t.Fatal(err) + } + assertPublicKeyRoundTrip(t, + jwt.PublicKey{Key: pub, KID: "ed-pub-rt"}, + signer, pub) + }) + + t.Run("EC_P256", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "ec256-pub-rt")}) + if err != nil { + t.Fatal(err) + } + assertPublicKeyRoundTrip(t, + jwt.PublicKey{Key: &priv.PublicKey, KID: "ec256-pub-rt"}, + signer, &priv.PublicKey) + }) + + t.Run("EC_P384", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "ec384-pub-rt")}) + if err != nil { + t.Fatal(err) + } + assertPublicKeyRoundTrip(t, + jwt.PublicKey{Key: &priv.PublicKey, KID: "ec384-pub-rt"}, + signer, &priv.PublicKey) + }) + + t.Run("EC_P521", func(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "ec521-pub-rt")}) + if err != nil { + t.Fatal(err) + } + assertPublicKeyRoundTrip(t, + jwt.PublicKey{Key: &priv.PublicKey, KID: "ec521-pub-rt"}, + signer, &priv.PublicKey) + }) + + t.Run("RSA", func(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + signer, err := jwt.NewSigner([]*jwt.PrivateKey{mustPK(t, priv, "rsa-pub-rt")}) + if err != nil { + t.Fatal(err) + } + assertPublicKeyRoundTrip(t, + jwt.PublicKey{Key: &priv.PublicKey, KID: "rsa-pub-rt"}, + signer, &priv.PublicKey) + }) +} + +func assertPublicKeyRoundTrip(t *testing.T, origPub jwt.PublicKey, signer *jwt.Signer, gjwtPub any) { + t.Helper() + + // Marshal to JSON. + data, err := json.Marshal(origPub) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + // Unmarshal back. + var recovered jwt.PublicKey + if err := json.Unmarshal(data, &recovered); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if recovered.KID != origPub.KID { + t.Errorf("KID: got %q, want %q", recovered.KID, origPub.KID) + } + + // Sign and verify with the round-tripped key. + tokenStr, err := signer.SignToString(testClaims("pub-roundtrip")) + if err != nil { + t.Fatal(err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{recovered}) + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Errorf("verify with round-tripped pubkey: %v", err) + } + + // Cross-verify with golang-jwt. + _, err = gjwt.ParseWithClaims(tokenStr, &gjwt.RegisteredClaims{}, func(tok *gjwt.Token) (any, error) { + return gjwtPub, nil + }) + if err != nil { + t.Errorf("golang-jwt cross-verify: %v", err) + } +} + +// --- JWKS round-trip --- + +// TestJWKSRoundTrip marshals a full JWKS document containing all supported +// key types and verifies that tokens signed with each key are verifiable +// after unmarshal. +func TestJWKSRoundTrip(t *testing.T) { + edKey, err := jwt.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + ec256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + ec384, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + ec521, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatal(err) + } + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + keys := []*jwt.PrivateKey{ + edKey, + mustPK(t, ec256, "ec256"), + mustPK(t, ec384, "ec384"), + mustPK(t, ec521, "ec521"), + mustPK(t, rsaKey, "rsa"), + } + signer, err := jwt.NewSigner(keys) + if err != nil { + t.Fatal(err) + } + + // Serialize the JWKS (public keys only). + jwksData, err := json.Marshal(&signer) + if err != nil { + t.Fatalf("marshal JWKS: %v", err) + } + + // Parse it back. + var jwks jwt.WellKnownJWKs + if err := json.Unmarshal(jwksData, &jwks); err != nil { + t.Fatalf("unmarshal JWKS: %v", err) + } + if len(jwks.Keys) != 5 { + t.Fatalf("expected 5 keys, got %d", len(jwks.Keys)) + } + + verifier, _ := jwt.NewVerifier(jwks.Keys) + claims := testClaims("jwks-round-trip") + + // Sign with each key (round-robin) and verify all. + for i := range len(keys) { + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatalf("sign[%d]: %v", i, err) + } + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Errorf("verify[%d] after JWKS round-trip: %v", i, err) + } + } +} diff --git a/auth/jwt/tests/round-trip-jwx/round_trip_test.go b/auth/jwt/tests/round-trip-jwx/round_trip_test.go new file mode 100644 index 0000000..d28dc08 --- /dev/null +++ b/auth/jwt/tests/round-trip-jwx/round_trip_test.go @@ -0,0 +1,621 @@ +// Package jwxrt_test verifies interoperability between this library and +// github.com/lestrrat-go/jwx/v3 (JWA, JWK, JWS, JWT). It covers sign/verify, +// JWK serialization, thumbprint consistency, JWKS, audience, custom claims, +// NumericDate precision, and stress tests. +package jwxrt_test + +import ( + "crypto" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v3/jwa" + jwxjwk "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jws" + jwxjwt "github.com/lestrrat-go/jwx/v3/jwt" + + "github.com/therootcompany/golib/auth/jwt" + "github.com/therootcompany/golib/auth/jwt/tests/testkeys" +) + +var longTests = flag.Bool("long", false, "run extended stress tests (100 RSA iterations instead of 10)") + +// jwxAlg maps our algorithm name to a jwx v3 SignatureAlgorithm. +func jwxAlg(name string) jwa.SignatureAlgorithm { + switch name { + case "EdDSA": + return jwa.EdDSA() + case "ES256": + return jwa.ES256() + case "ES384": + return jwa.ES384() + case "ES512": + return jwa.ES512() + case "RS256": + return jwa.RS256() + } + panic("unknown alg: " + name) +} + +// --- helpers --- + +func assertOurSignJWXVerify(t *testing.T, ks testkeys.KeySet, sub string) { + t.Helper() + + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatalf("NewSigner: %v", err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims(sub)) + if err != nil { + t.Fatalf("SignToString: %v", err) + } + + // Verify at JWS level. + _, err = jws.Verify([]byte(tokenStr), jws.WithKey(jwxAlg(ks.AlgName), ks.RawPub)) + if err != nil { + t.Fatalf("jwx jws.Verify: %v", err) + } + + // Verify at JWT level and check claims. + tok, err := jwxjwt.Parse([]byte(tokenStr), jwxjwt.WithKey(jwxAlg(ks.AlgName), ks.RawPub)) + if err != nil { + t.Fatalf("jwx jwt.Parse: %v", err) + } + gotSub, ok := tok.Subject() + if !ok || gotSub != sub { + t.Errorf("sub: got %q (ok=%v), want %q", gotSub, ok, sub) + } + gotIss, ok := tok.Issuer() + if !ok || gotIss != "https://example.com" { + t.Errorf("iss: got %q (ok=%v), want %q", gotIss, ok, "https://example.com") + } +} + +func assertJWXSignOurVerify(t *testing.T, ks testkeys.KeySet, sub string) { + t.Helper() + + // Import raw key into jwx and set kid. + jwxKey, err := jwxjwk.Import(ks.RawPriv) + if err != nil { + t.Fatalf("jwk.Import: %v", err) + } + if err := jwxKey.Set(jwxjwk.KeyIDKey, ks.KID); err != nil { + t.Fatalf("set kid: %v", err) + } + + tok := jwxjwt.New() + tok.Set(jwxjwt.SubjectKey, sub) + tok.Set(jwxjwt.IssuerKey, "https://example.com") + tok.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + tok.Set(jwxjwt.IssuedAtKey, time.Now()) + + signed, err := jwxjwt.Sign(tok, jwxjwt.WithKey(jwxAlg(ks.AlgName), jwxKey)) + if err != nil { + t.Fatalf("jwx jwt.Sign: %v", err) + } + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, err := verifier.VerifyJWT(string(signed)) + if err != nil { + t.Fatalf("our verify: %v", err) + } + + var decoded jwt.TokenClaims + if err := verifiedJWS.UnmarshalClaims(&decoded); err != nil { + t.Fatalf("UnmarshalClaims: %v", err) + } + if decoded.Sub != sub { + t.Errorf("sub: got %q, want %q", decoded.Sub, sub) + } +} + +// --- Our sign, jwx verify (all algorithms) --- + +func TestOurSignJWXVerify_EdDSA(t *testing.T) { + assertOurSignJWXVerify(t, testkeys.GenerateEdDSA("k1"), "user-eddsa") +} + +func TestOurSignJWXVerify_ES256(t *testing.T) { + assertOurSignJWXVerify(t, testkeys.GenerateES256("k1"), "user-es256") +} + +func TestOurSignJWXVerify_ES384(t *testing.T) { + assertOurSignJWXVerify(t, testkeys.GenerateES384("k1"), "user-es384") +} + +func TestOurSignJWXVerify_ES512(t *testing.T) { + assertOurSignJWXVerify(t, testkeys.GenerateES512("k1"), "user-es512") +} + +func TestOurSignJWXVerify_RS256(t *testing.T) { + assertOurSignJWXVerify(t, testkeys.GenerateRS256("k1"), "user-rs256") +} + +// --- jwx sign, our verify (all algorithms) --- + +func TestJWXSignOurVerify_EdDSA(t *testing.T) { + assertJWXSignOurVerify(t, testkeys.GenerateEdDSA("k1"), "user-eddsa") +} + +func TestJWXSignOurVerify_ES256(t *testing.T) { + assertJWXSignOurVerify(t, testkeys.GenerateES256("k1"), "user-es256") +} + +func TestJWXSignOurVerify_ES384(t *testing.T) { + assertJWXSignOurVerify(t, testkeys.GenerateES384("k1"), "user-es384") +} + +func TestJWXSignOurVerify_ES512(t *testing.T) { + assertJWXSignOurVerify(t, testkeys.GenerateES512("k1"), "user-es512") +} + +func TestJWXSignOurVerify_RS256(t *testing.T) { + assertJWXSignOurVerify(t, testkeys.GenerateRS256("k1"), "user-rs256") +} + +// --- JWK serialization interop --- + +func TestJWKInterop_OurJSONToJWX(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + t.Run(ag.Name+"_Public", func(t *testing.T) { + ks := ag.Generate("jwk-" + ag.Name) + + // Marshal our public key to JSON. + ourJSON, err := json.Marshal(ks.PubKey) + if err != nil { + t.Fatalf("marshal our pubkey: %v", err) + } + + // Parse with jwx. + jwxKey, err := jwxjwk.ParseKey(ourJSON) + if err != nil { + t.Fatalf("jwx ParseKey from our JSON: %v", err) + } + + // Verify a token signed by us, using the jwx-parsed key. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims("jwk-interop")) + if err != nil { + t.Fatal(err) + } + + // Export the raw public key from the jwx Key. + var rawPub any + if err := jwxjwk.Export(jwxKey, &rawPub); err != nil { + t.Fatalf("jwx Export: %v", err) + } + _, err = jws.Verify([]byte(tokenStr), jws.WithKey(jwxAlg(ks.AlgName), rawPub)) + if err != nil { + t.Fatalf("jwx verify with our-JSON-parsed key: %v", err) + } + }) + + t.Run(ag.Name+"_Private", func(t *testing.T) { + ks := ag.Generate("jwk-priv-" + ag.Name) + + // Marshal our private key to JSON. + ourJSON, err := json.Marshal(ks.PrivKey) + if err != nil { + t.Fatalf("marshal our privkey: %v", err) + } + + // Parse with jwx. + jwxKey, err := jwxjwk.ParseKey(ourJSON) + if err != nil { + t.Fatalf("jwx ParseKey from our private JSON: %v", err) + } + + // Sign with the jwx-parsed key, verify with our lib. + if err := jwxKey.Set(jwxjwk.KeyIDKey, ks.KID); err != nil { + t.Fatal(err) + } + tok := jwxjwt.New() + tok.Set(jwxjwt.SubjectKey, "jwk-priv-interop") + tok.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + signed, err := jwxjwt.Sign(tok, jwxjwt.WithKey(jwxAlg(ks.AlgName), jwxKey)) + if err != nil { + t.Fatalf("jwx sign with our-JSON-parsed key: %v", err) + } + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + if _, err := verifier.VerifyJWT(string(signed)); err != nil { + t.Fatalf("our verify: %v", err) + } + }) + } +} + +func TestJWKInterop_JWXJSONToOur(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + t.Run(ag.Name, func(t *testing.T) { + ks := ag.Generate("jwx-to-our-" + ag.Name) + + // Create jwx key and serialize. + jwxKey, err := jwxjwk.Import(ks.RawPub) + if err != nil { + t.Fatal(err) + } + jwxKey.Set(jwxjwk.KeyIDKey, ks.KID) + jwxJSON, err := json.Marshal(jwxKey) + if err != nil { + t.Fatalf("marshal jwx key: %v", err) + } + + // Parse with our library. + var recovered jwt.PublicKey + if err := json.Unmarshal(jwxJSON, &recovered); err != nil { + t.Fatalf("our unmarshal of jwx JSON: %v", err) + } + + // Sign with our signer, verify with the recovered key. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims("jwx-json")) + if err != nil { + t.Fatal(err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{recovered}) + if _, err := verifier.VerifyJWT(tokenStr); err != nil { + t.Fatalf("verify with jwx-JSON-parsed key: %v", err) + } + }) + } +} + +// --- Thumbprint consistency (RFC 7638) --- + +func TestThumbprintConsistency_JWX(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + t.Run(ag.Name, func(t *testing.T) { + ks := ag.Generate("thumb-" + ag.Name) + + // Our thumbprint (returns base64url string). + ourThumb, err := ks.PubKey.Thumbprint() + if err != nil { + t.Fatalf("our Thumbprint: %v", err) + } + + // jwx thumbprint (returns raw bytes). + jwxKey, err := jwxjwk.Import(ks.RawPub) + if err != nil { + t.Fatal(err) + } + jwxRaw, err := jwxKey.Thumbprint(crypto.SHA256) + if err != nil { + t.Fatalf("jwx Thumbprint: %v", err) + } + jwxThumb := base64.RawURLEncoding.EncodeToString(jwxRaw) + + if ourThumb != jwxThumb { + t.Errorf("thumbprint mismatch:\n ours: %s\n jwx: %s", ourThumb, jwxThumb) + } + }) + } +} + +// --- JWKS interop --- + +func TestJWKSInterop_OurToJWX(t *testing.T) { + // Build a signer with all 5 key types. + var keys []*jwt.PrivateKey + var sets []testkeys.KeySet + for _, ag := range testkeys.AllAlgorithms() { + ks := ag.Generate("jwks-" + ag.Name) + keys = append(keys, ks.PrivKey) + sets = append(sets, ks) + } + signer, err := jwt.NewSigner(keys) + if err != nil { + t.Fatal(err) + } + + // Serialize our JWKS. + jwksData, err := json.Marshal(&signer) + if err != nil { + t.Fatal(err) + } + + // Parse with jwx. + jwxSet, err := jwxjwk.Parse(jwksData) + if err != nil { + t.Fatalf("jwx Parse JWKS: %v", err) + } + if jwxSet.Len() != 5 { + t.Fatalf("expected 5 keys, got %d", jwxSet.Len()) + } + + // Sign tokens with each key and verify with the jwx-parsed set. + for i, ks := range sets { + tokenStr, err := signer.SignToString(testkeys.TestClaims(fmt.Sprintf("jwks-%d", i))) + if err != nil { + t.Fatalf("sign[%d]: %v", i, err) + } + _, err = jws.Verify([]byte(tokenStr), jws.WithKeySet(jwxSet)) + if err != nil { + t.Errorf("jwx verify[%d] (%s) with parsed JWKS: %v", i, ks.AlgName, err) + } + } +} + +func TestJWKSInterop_JWXToOur(t *testing.T) { + // Build a jwx key set. + jwxSet := jwxjwk.NewSet() + var sets []testkeys.KeySet + for _, ag := range testkeys.AllAlgorithms() { + ks := ag.Generate("jwx-jwks-" + ag.Name) + sets = append(sets, ks) + jwxKey, err := jwxjwk.Import(ks.RawPub) + if err != nil { + t.Fatal(err) + } + jwxKey.Set(jwxjwk.KeyIDKey, ks.KID) + jwxKey.Set(jwxjwk.AlgorithmKey, jwxAlg(ks.AlgName)) + if err := jwxSet.AddKey(jwxKey); err != nil { + t.Fatal(err) + } + } + + // Serialize jwx JWKS. + jwksData, err := json.Marshal(jwxSet) + if err != nil { + t.Fatal(err) + } + + // Parse with our library. + var ourJWKS jwt.WellKnownJWKs + if err := json.Unmarshal(jwksData, &ourJWKS); err != nil { + t.Fatalf("our unmarshal of jwx JWKS: %v", err) + } + if len(ourJWKS.Keys) != 5 { + t.Fatalf("expected 5 keys, got %d", len(ourJWKS.Keys)) + } + + verifier, _ := jwt.NewVerifier(ourJWKS.Keys) + + // Sign tokens with jwx, verify with our library. + for _, ks := range sets { + jwxKey, _ := jwxjwk.Import(ks.RawPriv) + jwxKey.Set(jwxjwk.KeyIDKey, ks.KID) + tok := jwxjwt.New() + tok.Set(jwxjwt.SubjectKey, "jwx-to-our") + tok.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + signed, err := jwxjwt.Sign(tok, jwxjwt.WithKey(jwxAlg(ks.AlgName), jwxKey)) + if err != nil { + t.Fatalf("jwx sign %s: %v", ks.AlgName, err) + } + if _, err := verifier.VerifyJWT(string(signed)); err != nil { + t.Errorf("our verify %s from jwx JWKS: %v", ks.AlgName, err) + } + } +} + +// --- Audience interop --- + +func TestAudienceStringInterop_JWX(t *testing.T) { + ks := testkeys.GenerateEdDSA("aud-test") + + // Our library: single aud marshals as plain string "single-aud". + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatal(err) + } + claims := testkeys.ListishClaims("aud-str", jwt.Listish{"single-aud"}) + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + tok, err := jwxjwt.Parse([]byte(tokenStr), jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub)) + if err != nil { + t.Fatalf("jwx parse: %v", err) + } + aud, ok := tok.Audience() + if !ok || len(aud) != 1 || aud[0] != "single-aud" { + t.Errorf("aud: got %v (ok=%v), want [single-aud]", aud, ok) + } + + // Reverse: jwx signs with single aud, our library parses. + jwxKey, _ := jwxjwk.Import(ks.RawPriv) + jwxKey.Set(jwxjwk.KeyIDKey, ks.KID) + jwxTok := jwxjwt.New() + jwxTok.Set(jwxjwt.ListishKey, []string{"single-aud"}) + jwxTok.Set(jwxjwt.SubjectKey, "aud-str-rev") + jwxTok.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + signed, err := jwxjwt.Sign(jwxTok, jwxjwt.WithKey(jwa.EdDSA(), jwxKey)) + if err != nil { + t.Fatal(err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, err := verifier.VerifyJWT(string(signed)) + if err != nil { + t.Fatal(err) + } + var decoded jwt.TokenClaims + verifiedJWS.UnmarshalClaims(&decoded) + if len(decoded.Aud) == 0 || decoded.Aud[0] != "single-aud" { + t.Errorf("reverse aud: got %v, want [single-aud]", decoded.Aud) + } +} + +func TestAudienceArrayInterop_JWX(t *testing.T) { + ks := testkeys.GenerateEdDSA("aud-arr") + + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + claims := testkeys.ListishClaims("aud-arr", jwt.Listish{"aud1", "aud2"}) + tokenStr, _ := signer.SignToString(claims) + + tok, err := jwxjwt.Parse([]byte(tokenStr), jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub)) + if err != nil { + t.Fatalf("jwx parse: %v", err) + } + aud, _ := tok.Audience() + if len(aud) != 2 || aud[0] != "aud1" || aud[1] != "aud2" { + t.Errorf("aud: got %v, want [aud1 aud2]", aud) + } +} + +// --- Custom claims interop --- + +func TestCustomClaimsInterop_JWX(t *testing.T) { + ks := testkeys.GenerateEdDSA("custom") + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + claims := &testkeys.CustomClaims{ + TokenClaims: *testkeys.TestClaims("custom-user"), + Email: "user@example.com", + Roles: []string{"admin", "editor"}, + Metadata: map[string]string{"team": "platform"}, + } + tokenStr, err := signer.SignToString(claims) + if err != nil { + t.Fatal(err) + } + + tok, err := jwxjwt.Parse([]byte(tokenStr), jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub)) + if err != nil { + t.Fatalf("jwx parse: %v", err) + } + + var email string + if err := tok.Get("email", &email); err != nil { + t.Fatalf("get email: %v", err) + } + if email != "user@example.com" { + t.Errorf("email: got %q, want %q", email, "user@example.com") + } + + var roles []any + if err := tok.Get("roles", &roles); err != nil { + t.Fatalf("get roles: %v", err) + } + if len(roles) != 2 || fmt.Sprint(roles[0]) != "admin" { + t.Errorf("roles: got %v, want [admin editor]", roles) + } + + var meta map[string]any + if err := tok.Get("metadata", &meta); err != nil { + t.Fatalf("get metadata: %v", err) + } + if meta["team"] != "platform" { + t.Errorf("metadata.team: got %v, want %q", meta["team"], "platform") + } +} + +// --- NumericDate precision --- + +func TestNumericDatePrecision_JWX(t *testing.T) { + ks := testkeys.GenerateEdDSA("nd") + + // Use fixed future timestamps to test precision without triggering + // expiration validation. 2000000000 = 2033-05-18, well in the future. + var wantExp int64 = 2000000000 + var wantIat int64 = 1999999000 + claims := &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: "numdate", + Exp: wantExp, + IAt: wantIat, + } + signer, _ := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + tokenStr, _ := signer.SignToString(claims) + + // Disable validation - this test is about timestamp precision, not + // expiration checking. jwx rejects future iat by default. + tok, err := jwxjwt.Parse([]byte(tokenStr), + jwxjwt.WithKey(jwa.EdDSA(), ks.RawPub), + jwxjwt.WithValidate(false), + ) + if err != nil { + t.Fatal(err) + } + exp, ok := tok.Expiration() + if !ok || exp.Unix() != wantExp { + t.Errorf("exp: got %d (ok=%v), want %d", exp.Unix(), ok, wantExp) + } + iat, ok := tok.IssuedAt() + if !ok || iat.Unix() != wantIat { + t.Errorf("iat: got %d (ok=%v), want %d", iat.Unix(), ok, wantIat) + } + + // Reverse: jwx signs with specific times, our library reads. + var wantExp2 int64 = 2100000000 + var wantIat2 int64 = 2099999000 + jwxKey, _ := jwxjwk.Import(ks.RawPriv) + jwxKey.Set(jwxjwk.KeyIDKey, ks.KID) + jwxTok := jwxjwt.New() + jwxTok.Set(jwxjwt.SubjectKey, "numdate-rev") + jwxTok.Set(jwxjwt.ExpirationKey, time.Unix(wantExp2, 0)) + jwxTok.Set(jwxjwt.IssuedAtKey, time.Unix(wantIat2, 0)) + signed, _ := jwxjwt.Sign(jwxTok, jwxjwt.WithKey(jwa.EdDSA(), jwxKey)) + + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + verifiedJWS, _ := verifier.VerifyJWT(string(signed)) + var decoded jwt.TokenClaims + verifiedJWS.UnmarshalClaims(&decoded) + if decoded.Exp != wantExp2 { + t.Errorf("rev exp: got %d, want %d", decoded.Exp, wantExp2) + } + if decoded.IAt != wantIat2 { + t.Errorf("rev iat: got %d, want %d", decoded.IAt, wantIat2) + } +} + +// --- Stress tests --- + +func TestStress_JWX(t *testing.T) { + for _, ag := range testkeys.AllAlgorithms() { + ag := ag + t.Run(ag.Name, func(t *testing.T) { + t.Parallel() + n := 1000 + if ag.Name == "RS256" { + n = 10 + if *longTests { + n = 100 + } + } + for i := range n { + ks := ag.Generate(fmt.Sprintf("s%d", i)) + sub := fmt.Sprintf("stress-%d", i) + + // Our sign, jwx verify. + signer, err := jwt.NewSigner([]*jwt.PrivateKey{ks.PrivKey}) + if err != nil { + t.Fatalf("iter %d: NewSigner: %v", i, err) + } + tokenStr, err := signer.SignToString(testkeys.TestClaims(sub)) + if err != nil { + t.Fatalf("iter %d: SignToString: %v", i, err) + } + _, err = jws.Verify([]byte(tokenStr), jws.WithKey(jwxAlg(ks.AlgName), ks.RawPub)) + if err != nil { + t.Fatalf("iter %d: jwx verify: %v", i, err) + } + + // jwx sign, our verify. + jwxKey, _ := jwxjwk.Import(ks.RawPriv) + jwxKey.Set(jwxjwk.KeyIDKey, ks.KID) + jwxTok := jwxjwt.New() + jwxTok.Set(jwxjwt.SubjectKey, sub) + jwxTok.Set(jwxjwt.ExpirationKey, time.Now().Add(time.Hour)) + signed, err := jwxjwt.Sign(jwxTok, jwxjwt.WithKey(jwxAlg(ks.AlgName), jwxKey)) + if err != nil { + t.Fatalf("iter %d: jwx sign: %v", i, err) + } + verifier, _ := jwt.NewVerifier([]jwt.PublicKey{ks.PubKey}) + if _, err := verifier.VerifyJWT(string(signed)); err != nil { + t.Fatalf("iter %d: our verify: %v", i, err) + } + } + }) + } +} diff --git a/auth/jwt/tests/testkeys/testkeys.go b/auth/jwt/tests/testkeys/testkeys.go new file mode 100644 index 0000000..0fe9526 --- /dev/null +++ b/auth/jwt/tests/testkeys/testkeys.go @@ -0,0 +1,158 @@ +// Package testkeys provides shared key generation and test helpers for the +// JWT/JWS/JWK interop test suite. It is a regular (non-test) package so that +// each test subdirectory can import it. +package testkeys + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "io" + "time" + + "github.com/therootcompany/golib/auth/jwt" +) + +// TestClaims returns a fresh TokenClaims with iss, sub, exp, and iat set. +func TestClaims(sub string) *jwt.TokenClaims { + now := time.Now() + return &jwt.TokenClaims{ + Iss: "https://example.com", + Sub: sub, + Exp: now.Add(time.Hour).Unix(), + IAt: now.Unix(), + } +} + +// ListishClaims returns claims with the given audience. +func ListishClaims(sub string, aud jwt.Listish) *jwt.TokenClaims { + c := TestClaims(sub) + c.Aud = aud + return c +} + +// CustomClaims embeds TokenClaims and adds extra fields for testing +// cross-library custom claims extraction. +type CustomClaims struct { + jwt.TokenClaims + Email string `json:"email"` + Roles []string `json:"roles"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// DeterministicRand returns a deterministic io.Reader seeded from a string. +// Not cryptographically secure - used only for reproducible test key generation. +func DeterministicRand(seed string) io.Reader { + s := sha256.Sum256([]byte(seed)) + return &hashReader{state: s} +} + +type hashReader struct { + state [32]byte + pos int +} + +func (r *hashReader) Read(p []byte) (int, error) { + n := 0 + for n < len(p) { + if r.pos >= len(r.state) { + r.state = sha256.Sum256(r.state[:]) + r.pos = 0 + } + copied := copy(p[n:], r.state[r.pos:]) + n += copied + r.pos += copied + } + return n, nil +} + +// KeySet bundles a generated key in all the forms interop tests need: +// our library's wrappers, the raw Go crypto types, and metadata. +type KeySet struct { + PrivKey *jwt.PrivateKey // our library's key wrapper + PubKey jwt.PublicKey // our library's public key wrapper + RawPriv any // *ecdsa.PrivateKey | *rsa.PrivateKey | ed25519.PrivateKey + RawPub any // *ecdsa.PublicKey | *rsa.PublicKey | ed25519.PublicKey + KID string + AlgName string // "EdDSA", "ES256", "ES384", "ES512", "RS256" +} + +func mustPK(signer crypto.Signer, kid string) *jwt.PrivateKey { + pk, err := jwt.FromPrivateKey(signer, kid) + if err != nil { + panic("mustPK: " + err.Error()) + } + return pk +} + +// GenerateEdDSA generates an Ed25519 key set. +func GenerateEdDSA(kid string) KeySet { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + panic("GenerateEdDSA: " + err.Error()) + } + pub := priv.Public().(ed25519.PublicKey) + return KeySet{ + PrivKey: mustPK(priv, kid), + PubKey: jwt.PublicKey{Key: pub, KID: kid}, + RawPriv: priv, RawPub: pub, + KID: kid, AlgName: "EdDSA", + } +} + +// GenerateES256 generates an EC P-256 key set. +func GenerateES256(kid string) KeySet { return generateEC(kid, elliptic.P256(), "ES256") } + +// GenerateES384 generates an EC P-384 key set. +func GenerateES384(kid string) KeySet { return generateEC(kid, elliptic.P384(), "ES384") } + +// GenerateES512 generates an EC P-521 key set. +func GenerateES512(kid string) KeySet { return generateEC(kid, elliptic.P521(), "ES512") } + +func generateEC(kid string, curve elliptic.Curve, alg string) KeySet { + priv, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + panic("generateEC " + alg + ": " + err.Error()) + } + return KeySet{ + PrivKey: mustPK(priv, kid), + PubKey: jwt.PublicKey{Key: &priv.PublicKey, KID: kid}, + RawPriv: priv, RawPub: &priv.PublicKey, + KID: kid, AlgName: alg, + } +} + +// GenerateRS256 generates an RSA 2048-bit key set. +func GenerateRS256(kid string) KeySet { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic("GenerateRS256: " + err.Error()) + } + return KeySet{ + PrivKey: mustPK(priv, kid), + PubKey: jwt.PublicKey{Key: &priv.PublicKey, KID: kid}, + RawPriv: priv, RawPub: &priv.PublicKey, + KID: kid, AlgName: "RS256", + } +} + +// AlgGen pairs an algorithm name with its key generator. +type AlgGen struct { + Name string + Generate func(kid string) KeySet +} + +// AllAlgorithms returns generators for all 5 supported algorithms. +func AllAlgorithms() []AlgGen { + return []AlgGen{ + {"EdDSA", GenerateEdDSA}, + {"ES256", GenerateES256}, + {"ES384", GenerateES384}, + {"ES512", GenerateES512}, + {"RS256", GenerateRS256}, + } +}